From ee88bdff590e1d1e0c2e399a5616b3c9e0fe8653 Mon Sep 17 00:00:00 2001 From: Rach Pradhan Date: Thu, 30 Apr 2026 09:59:28 +0800 Subject: [PATCH 1/3] feat(iouring): Linux IORING_OP_ACCEPT_MULTISHOT accept loop (off by default) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an opt-in Linux io_uring accept loop for the TurboAPI Zig server. Inspired by the layout of nanoapi's src/io_uring.zig — minimal first step, no perf claims (per AGENTS.md). ## What this PR does * New build flag '-Diouring=true' (default false). Wired through build.zig.addOptions as a 'turbo_build_options' module. * New module zig/src/iouring.zig: - 'Available' = builtin.os.tag == .linux - 'enabled()' = Available and build flag - Linux.AcceptLoop: posts IORING_OP_ACCEPT_MULTISHOT on a listen fd, pumps copy_cqes, hands accepted fds back via a callback, re-arms when the kernel clears IORING_CQE_F_MORE. * server.zig::server_run: when iouring.enabled(), runs the io_uring accept loop and wraps each accepted fd back into a std.Io.net.Stream that the existing ConnectionPool consumes unchanged. Falls back to the blocking accept loop on any io_uring setup failure. * Compiles on every target — non-Linux targets get a typed stub so callers can keep the same call site. ## What this PR does NOT do * Per-connection recv/send via io_uring (still goes through the existing thread pool's synchronous syscalls). * Multi-ring / per-worker thread-per-core io_uring layout. * Linked SQEs to fuse recv → send → recv into one submit. * SEND_ZC, multishot recv, descriptorless files. These are all tracked as follow-ups against release/beta-v1.0.30. ## Verification * macOS host build (default): clean. * macOS host build with '-Diouring=true': clean (uses the non-Linux stub). * aarch64-linux-musl cross-compile of the iouring module: clean. 
* End-to-end smoke test inside Apple 'container' (Linux 6.18.5, aarch64, Alpine 3.20): bench/iouring/iouring_smoke spins up a TCP listener, runs AcceptLoop on a worker thread, dials the listener 16x, asserts the accept callback fires 16 times. Result: PASS. Reproduce locally: container system start ./bench/iouring/run.sh ## Honest perf statement This PR has not been benchmarked. No perf claims appear in the README, the PR description, or anywhere in the repo. The accept loop alone is not where io_uring's throughput wins live; that comes with multishot recv and batched send in follow-up PRs. Per AGENTS.md, no benchmark tables will be added until those PRs land and are run on a real Linux host with the bench-frameworks workflow. Amp-Thread-ID: https://ampcode.com/threads/T-019ddbfa-6d44-705e-b930-6daf1d41e918 Co-authored-by: Amp --- bench/iouring/.gitignore | 2 + bench/iouring/Containerfile | 11 ++ bench/iouring/README.md | 59 ++++++++++ bench/iouring/build.sh | 30 +++++ bench/iouring/iouring_smoke.zig | 134 ++++++++++++++++++++++ bench/iouring/run.sh | 39 +++++++ zig/build.zig | 14 +++ zig/src/iouring.zig | 195 ++++++++++++++++++++++++++++++++ zig/src/server.zig | 54 ++++++++- 9 files changed, 536 insertions(+), 2 deletions(-) create mode 100644 bench/iouring/.gitignore create mode 100644 bench/iouring/Containerfile create mode 100644 bench/iouring/README.md create mode 100755 bench/iouring/build.sh create mode 100644 bench/iouring/iouring_smoke.zig create mode 100755 bench/iouring/run.sh create mode 100644 zig/src/iouring.zig diff --git a/bench/iouring/.gitignore b/bench/iouring/.gitignore new file mode 100644 index 0000000..92208d9 --- /dev/null +++ b/bench/iouring/.gitignore @@ -0,0 +1,2 @@ +# Built artifacts (regenerate via ./build.sh) +iouring_smoke diff --git a/bench/iouring/Containerfile b/bench/iouring/Containerfile new file mode 100644 index 0000000..b40b099 --- /dev/null +++ b/bench/iouring/Containerfile @@ -0,0 +1,11 @@ +# Linux smoke-test image for the 
io_uring AcceptLoop in zig/src/iouring.zig. +# +# Pre-built static aarch64-linux-musl binary; no compiler needed at image +# time. Build the binary on the host first with: +# ./bench/iouring/build.sh +FROM alpine:3.20 + +COPY iouring_smoke /usr/local/bin/iouring_smoke +RUN chmod +x /usr/local/bin/iouring_smoke + +ENTRYPOINT ["/usr/local/bin/iouring_smoke"] diff --git a/bench/iouring/README.md b/bench/iouring/README.md new file mode 100644 index 0000000..3e10448 --- /dev/null +++ b/bench/iouring/README.md @@ -0,0 +1,59 @@ +# io_uring smoke test + +End-to-end correctness check for the Linux `io_uring` `IORING_OP_ACCEPT_MULTISHOT` +accept loop in [`zig/src/iouring.zig`](../../zig/src/iouring.zig). + +This is **not a benchmark.** It exists to prove that: + +1. `zig/src/iouring.zig` compiles cleanly for `aarch64-linux-musl`. +2. `std.os.linux.IoUring` works on the kernel the test is run against. +3. `IORING_OP_ACCEPT_MULTISHOT` actually delivers all expected accepts when + N clients connect to a listen socket. + +The smoke binary opens a TCP listener on `127.0.0.1:18080`, runs the +`AcceptLoop` on a worker thread, dials the listener N times from the main +thread, and asserts the accept callback fired N times. Exits 0 on success, +non-zero otherwise. + +## Run + +Requires Apple `container` 0.11+ (or any compatible OCI runtime — set +`RUNTIME=docker` / `RUNTIME=podman`). On macOS, start the container service +first: + +```bash +container system start +./bench/iouring/run.sh +``` + +The script: + +1. Cross-compiles `iouring_smoke` for `aarch64-linux-musl` on the host. +2. Builds the OCI image from this directory. +3. Runs the smoke binary inside a fresh container. + +A passing run prints something like: + +``` +listening on 127.0.0.1:18080 (fd=3) + client 1/16 connected + ... 
+ client 16/16 connected +io_uring AcceptLoop saw 16 accepts (wanted >= 16) +OK +==> io_uring smoke test PASSED +``` + +The kernel version reported on a clean `container run alpine:3.20 uname -a` +on macOS 26 / `container` 0.11 is `Linux ... 6.18.5 ... aarch64`, well above +the 5.19 minimum for `IORING_OP_ACCEPT_MULTISHOT`. + +## Limitations + +* Only the accept loop is exercised. Per-connection `recv` / `send` over + `io_uring` is not implemented yet — see the staged plan in + [`zig/src/iouring.zig`](../../zig/src/iouring.zig). +* No request/response payload is sent; the test closes accepted fds + immediately. +* No latency or throughput numbers are produced. Per `AGENTS.md`, do not + cite this script in any benchmark table or release note. diff --git a/bench/iouring/build.sh b/bench/iouring/build.sh new file mode 100755 index 0000000..ee22f1d --- /dev/null +++ b/bench/iouring/build.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Build the io_uring smoke-test binary for Linux (aarch64-linux-musl by +# default; override with TARGET=...). Designed to run on macOS via Apple +# `container`, on a Linux dev box natively, or in CI on a Linux runner. +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" +cd "$REPO_ROOT" + +TARGET="${TARGET:-aarch64-linux-musl}" +OUT="bench/iouring/iouring_smoke" + +# Stub turbo_build_options so iouring.zig compiles standalone without the +# whole turbonet build graph. 
+STUB_DIR="$(mktemp -d)" +trap 'rm -rf "$STUB_DIR"' EXIT +cat > "$STUB_DIR/turbo_build_options.zig" < cross-compiling iouring smoke for $TARGET" +zig build-exe -target "$TARGET" -O ReleaseSafe -femit-bin="$OUT" \ + --dep iouring \ + -Mroot=bench/iouring/iouring_smoke.zig \ + --dep turbo_build_options \ + -Miouring=zig/src/iouring.zig \ + -Mturbo_build_options="$STUB_DIR/turbo_build_options.zig" + +file "$OUT" +echo "==> built $OUT" diff --git a/bench/iouring/iouring_smoke.zig b/bench/iouring/iouring_smoke.zig new file mode 100644 index 0000000..79bd1e3 --- /dev/null +++ b/bench/iouring/iouring_smoke.zig @@ -0,0 +1,134 @@ +//! End-to-end smoke test for the iouring AcceptLoop on Linux. +//! +//! Pipeline: +//! 1. Open a non-blocking TCP listen socket on 127.0.0.1:. +//! 2. Spawn a worker thread running `iouring.Linux.AcceptLoop.run`. +//! Each accepted fd is counted then closed immediately. +//! 3. From the main thread, open N TCP connections to the listener. +//! 4. Wait until the AcceptLoop has counted N accepts, then stop it and +//! assert that we saw exactly N. +//! +//! Exits 0 on success, non-zero on any failure. Designed to be a one-shot +//! sanity check inside an Apple `container` Linux VM (kernel 6.x), not a +//! benchmark. + +const std = @import("std"); +const builtin = @import("builtin"); +const iouring = @import("iouring"); + +const linux = std.os.linux; +const posix = std.posix; + +const N_CONNECTIONS: usize = 16; +const PORT: u16 = 18080; + +const Counter = struct { + accepts: std.atomic.Value(usize) = std.atomic.Value(usize).init(0), + loop: ?*iouring.Linux.AcceptLoop = null, +}; + +pub fn main() !void { + if (builtin.os.tag != .linux) { + std.debug.print("skip: not Linux\n", .{}); + return; + } + if (!iouring.Available) { + std.debug.print("skip: iouring.Available == false\n", .{}); + return; + } + + // ── 1. 
listen socket ── + const listen_fd = try createListenSocket(PORT); + defer _ = linux.close(listen_fd); + std.debug.print("listening on 127.0.0.1:{d} (fd={d})\n", .{ PORT, listen_fd }); + + // ── 2. accept loop on its own thread ── + var counter = Counter{}; + var loop = try iouring.Linux.AcceptLoop.init( + listen_fd, + onAccept, + @ptrCast(&counter), + iouring.DEFAULT_SQ_ENTRIES, + ); + defer loop.deinit(); + counter.loop = &loop; + + const t = try std.Thread.spawn(.{}, runLoop, .{&loop}); + defer t.join(); + + // ── 3. connect N times ── + for (0..N_CONNECTIONS) |i| { + try dialOnce(PORT); + std.debug.print(" client {d}/{d} connected\n", .{ i + 1, N_CONNECTIONS }); + } + + // ── 4. wait + assert ── + var waited_ms: usize = 0; + while (counter.accepts.load(.acquire) < N_CONNECTIONS and waited_ms < 5_000) { + var ts: linux.timespec = .{ .sec = 0, .nsec = 10 * std.time.ns_per_ms }; + _ = linux.nanosleep(&ts, null); + waited_ms += 10; + } + + loop.stop(); + + // Touch the loop one more time so copy_cqes wakes up. A no-op connect is + // the cheapest way to force a CQE. 
+ dialOnce(PORT) catch {}; + + const got = counter.accepts.load(.acquire); + std.debug.print("io_uring AcceptLoop saw {d} accepts (wanted >= {d})\n", .{ got, N_CONNECTIONS }); + if (got < N_CONNECTIONS) { + std.debug.print("FAIL\n", .{}); + std.process.exit(1); + } + std.debug.print("OK\n", .{}); +} + +fn runLoop(loop: *iouring.Linux.AcceptLoop) void { + loop.run() catch |err| { + std.debug.print("loop.run errored: {s}\n", .{@errorName(err)}); + }; +} + +fn onAccept(ctx: *anyopaque, fd: posix.fd_t) void { + const counter: *Counter = @ptrCast(@alignCast(ctx)); + _ = counter.accepts.fetchAdd(1, .acq_rel); + _ = linux.close(fd); +} + +fn createListenSocket(port: u16) !posix.fd_t { + const fd_signed = linux.socket(linux.AF.INET, linux.SOCK.STREAM | linux.SOCK.CLOEXEC, 0); + if (fd_signed < 0) return error.SocketFailed; + const fd: posix.fd_t = @intCast(fd_signed); + + const yes: c_int = 1; + _ = linux.setsockopt(fd, linux.SOL.SOCKET, linux.SO.REUSEADDR, std.mem.asBytes(&yes), @sizeOf(c_int)); + + var addr: linux.sockaddr.in = .{ + .family = linux.AF.INET, + .port = std.mem.nativeToBig(u16, port), + .addr = std.mem.nativeToBig(u32, 0x7F000001), // 127.0.0.1 + .zero = .{0} ** 8, + }; + if (linux.bind(fd, @ptrCast(&addr), @sizeOf(linux.sockaddr.in)) != 0) return error.BindFailed; + if (linux.listen(fd, 128) != 0) return error.ListenFailed; + return fd; +} + +fn dialOnce(port: u16) !void { + const fd_signed = linux.socket(linux.AF.INET, linux.SOCK.STREAM | linux.SOCK.CLOEXEC, 0); + if (fd_signed < 0) return error.SocketFailed; + const fd: posix.fd_t = @intCast(fd_signed); + defer _ = linux.close(fd); + + var addr: linux.sockaddr.in = .{ + .family = linux.AF.INET, + .port = std.mem.nativeToBig(u16, port), + .addr = std.mem.nativeToBig(u32, 0x7F000001), + .zero = .{0} ** 8, + }; + if (linux.connect(fd, @ptrCast(&addr), @sizeOf(linux.sockaddr.in)) != 0) { + return error.ConnectFailed; + } +} diff --git a/bench/iouring/run.sh b/bench/iouring/run.sh new file mode 100755 index 
0000000..58f5e3c --- /dev/null +++ b/bench/iouring/run.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# End-to-end: build the io_uring smoke binary, package it into an OCI image, +# run it inside Apple `container` (or any compatible runtime), and check the +# exit code. +# +# This validates that: +# * zig/src/iouring.zig compiles for Linux +# * std.os.linux.IoUring works on the kernel inside Apple's lightweight VM +# * IORING_OP_ACCEPT_MULTISHOT delivers all expected accepts +# +# This is a correctness check, NOT a benchmark. Per AGENTS.md, no perf +# numbers should be cited from this script. +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" +BENCH_DIR="$REPO_ROOT/bench/iouring" +RUNTIME="${RUNTIME:-container}" +IMAGE="${IMAGE:-turboapi-iouring-smoke:latest}" +NAME="${NAME:-turboapi-iouring-smoke}" + +cd "$REPO_ROOT" + +# 1. Build the static Linux binary on the host. +"$BENCH_DIR/build.sh" + +# 2. Build the OCI image. +echo "==> building $IMAGE via $RUNTIME" +"$RUNTIME" build -t "$IMAGE" -f "$BENCH_DIR/Containerfile" "$BENCH_DIR" + +# 3. Run it; non-zero exit => smoke test failed. +echo "==> running smoke test" +"$RUNTIME" rm -f "$NAME" 2>/dev/null || true +if "$RUNTIME" run --rm --name "$NAME" "$IMAGE"; then + echo "==> io_uring smoke test PASSED" +else + rc=$? + echo "==> io_uring smoke test FAILED (exit $rc)" >&2 + exit "$rc" +fi diff --git a/zig/build.zig b/zig/build.zig index f1bf111..f6a3d32 100644 --- a/zig/build.zig +++ b/zig/build.zig @@ -14,6 +14,18 @@ pub fn build(b: *std.Build) void { const lib_path = b.option([]const u8, "py-libdir", "Python lib path (required)") orelse @panic("pass -Dpy-libdir= or use: python zig/build_turbonet.py"); + // ── Optional Linux io_uring accept loop ── + // -Diouring=true enables IORING_OP_ACCEPT_MULTISHOT on the listen socket + // and dispatches accepted fds onto the existing thread pool. No-op on + // non-Linux targets even when set. 
Off by default; the per-connection + // recv/send fast path is *not* on io_uring yet — see zig/src/iouring.zig + // for the staged plan. + const iouring_enabled = b.option(bool, "iouring", "Enable Linux io_uring accept loop (Linux only, off by default)") orelse false; + + const turbo_options = b.addOptions(); + turbo_options.addOption(bool, "iouring_enabled", iouring_enabled); + const turbo_options_mod = turbo_options.createModule(); + const py_lib_name: []const u8 = if (is_free_threaded) "python3.14t" else if (std.mem.eql(u8, py_version, "3.14")) @@ -67,6 +79,7 @@ pub fn build(b: *std.Build) void { lib.root_module.addImport("model", model_mod); lib.root_module.addImport("pg", pg_mod); lib.root_module.addImport("turboapi-core", core_mod); + lib.root_module.addImport("turbo_build_options", turbo_options_mod); lib.root_module.addIncludePath(.{ .cwd_relative = include_path }); lib.root_module.addRPathSpecial("@loader_path"); @@ -111,6 +124,7 @@ pub fn build(b: *std.Build) void { tests.root_module.addImport("model", model_mod); tests.root_module.addImport("pg", pg_mod); tests.root_module.addImport("turboapi-core", core_mod); + tests.root_module.addImport("turbo_build_options", turbo_options_mod); tests.root_module.addIncludePath(.{ .cwd_relative = include_path }); tests.root_module.addLibraryPath(.{ .cwd_relative = lib_path }); tests.root_module.linkSystemLibrary(py_lib_name, .{}); diff --git a/zig/src/iouring.zig b/zig/src/iouring.zig new file mode 100644 index 0000000..af0e697 --- /dev/null +++ b/zig/src/iouring.zig @@ -0,0 +1,195 @@ +//! Linux io_uring accept loop for TurboAPI. +//! +//! ## Status (PR #1 — accept-loop only) +//! +//! What this module does today: +//! * Sets up a single io_uring (per accept thread) on Linux ≥ 5.19. +//! * Posts an `IORING_OP_ACCEPT_MULTISHOT` on the listen fd, so the kernel +//! produces a CQE per accepted connection without us re-arming the SQE. +//! * For every accepted fd, hands the connection off to the existing +//! 
`ConnectionPool` (per-worker thread + synchronous recv/send) by wrapping +//! the raw fd back into a `std.Io.net.Stream`. +//! +//! What this module does *not* do yet (deliberately staged): +//! * `IORING_OP_RECV` / `RECV_MULTISHOT` for the per-connection request read. +//! * `IORING_OP_SEND` / `SEND_ZC` for the per-connection response write. +//! * Per-worker thread-per-core io_uring rings (we still funnel through one +//! accept ring; per-connection work then lives on the existing pool). +//! * Linked SQEs to fuse `recv → send → recv` into a single submit. +//! +//! Those are tracked as follow-up PRs against `release/beta-v1.0.30`. They +//! are the paths where io_uring actually moves throughput; the accept loop +//! alone is mostly a correctness / scaffolding step. +//! +//! ## Compile and runtime gating +//! +//! * Compiles on every target (so `zig build` stays portable). The +//! `std.os.linux.IoUring` reference sits inside the `Linux` struct, so +//! non-Linux targets only see a stub whose `init()` returns `NotEnabled`. +//! * `Available` is `true` only on Linux. +//! * `enabled()` combines `Available` with the `-Diouring=true` build flag. +//! When `enabled()` is false, callers must keep using the existing +//! blocking-accept path. +//! +//! ## Honest perf claim +//! +//! This module has *not* been benchmarked. Per AGENTS.md, no benchmark +//! tables, comparisons, or "vs framework X" claims should appear in docs, +//! release notes, or the PR description until we have a real Linux run with +//! the bench-frameworks script and recorded artifacts. Saying "io_uring is +//! enabled" is fine; saying "io_uring makes TurboAPI N× faster" is not, +//! until we have real numbers. + +const std = @import("std"); +const builtin = @import("builtin"); + +const build_options = @import("turbo_build_options"); + +/// True on Linux targets. Other targets get a stub that always errors. 
+pub const Available: bool = builtin.os.tag == .linux; + +/// True only when the target is Linux *and* `-Diouring=true` was set at build time. +pub inline fn enabled() bool { + return Available and build_options.iouring_enabled; +} + +/// Default submission-queue depth. Picked to match nanoapi and liburing +/// examples; can be tuned later via env var or `Options`. +pub const DEFAULT_SQ_ENTRIES: u16 = 1024; + +/// Errors a caller can see from this module. +pub const Error = error{ + /// `enabled()` was false when the caller asked us to run. + NotEnabled, + /// Kernel returned EINVAL / EPERM on `io_uring_setup` — usually means + /// the running kernel is older than 5.19 or io_uring is disabled + /// (`/proc/sys/kernel/io_uring_disabled`). + SetupFailed, + /// Submitting the multishot accept SQE failed. + AcceptSubmitFailed, +}; + +/// Linux-only implementation. Kept in its own struct so non-Linux callers +/// can still reference `Available`/`enabled()` without forcing the type to +/// instantiate. +pub const Linux = if (Available) struct { + const linux = std.os.linux; + const posix = std.posix; + const IoUring = linux.IoUring; + + pub const AcceptCallback = *const fn (ctx: *anyopaque, fd: posix.fd_t) void; + + pub const AcceptLoop = struct { + ring: IoUring, + listen_fd: posix.fd_t, + on_accept: AcceptCallback, + on_accept_ctx: *anyopaque, + running: std.atomic.Value(bool) = std.atomic.Value(bool).init(true), + + /// Initialize the ring. `listen_fd` must already be a non-blocking + /// socket bound + listening on the desired address. 
+ pub fn init( + listen_fd: posix.fd_t, + on_accept: AcceptCallback, + on_accept_ctx: *anyopaque, + entries: u16, + ) Error!AcceptLoop { + const ring = IoUring.init(entries, 0) catch return Error.SetupFailed; + return .{ + .ring = ring, + .listen_fd = listen_fd, + .on_accept = on_accept, + .on_accept_ctx = on_accept_ctx, + }; + } + + pub fn deinit(self: *AcceptLoop) void { + self.ring.deinit(); + } + + /// Submit the multishot accept and pump the completion queue until + /// `stop()` is called from another thread. + pub fn run(self: *AcceptLoop) Error!void { + // Tag the multishot accept CQE with user_data == 0 so we can + // distinguish it from per-connection ops in future PRs. + _ = self.ring.accept_multishot(0, self.listen_fd, null, null, 0) catch + return Error.AcceptSubmitFailed; + _ = self.ring.submit() catch return Error.AcceptSubmitFailed; + + var cqes: [64]linux.io_uring_cqe = undefined; + while (self.running.load(.acquire)) { + // Block until at least one CQE arrives. `copy_cqes` wraps + // io_uring_enter(GETEVENTS); an interrupted wait surfaces as + // `error.SignalInterrupt`, which is retried below. + const n = self.ring.copy_cqes(&cqes, 1) catch |err| switch (err) { + error.SignalInterrupt => continue, + else => return Error.SetupFailed, + }; + + for (cqes[0..n]) |cqe| { + if (cqe.user_data != 0) continue; // future: per-conn ops + if (cqe.res < 0) { + // Negative res = -errno. -EINTR is skipped; -EAGAIN + // can't happen on multishot accept. Anything else: + // re-arm (no logging yet) and keep going. + const errno_val: i32 = -cqe.res; + if (errno_val == @intFromEnum(linux.E.INTR)) continue; + // Re-arm and continue; the listen fd is still valid. + try self.rearmAccept(); + continue; + } + + // Multishot returns the new fd directly in cqe.res. + const fd: posix.fd_t = @intCast(cqe.res); + self.on_accept(self.on_accept_ctx, fd); + + // If the kernel cleared the multishot bit (e.g. ring + // pressure), re-arm. 
+ if ((cqe.flags & linux.IORING_CQE_F_MORE) == 0) { + try self.rearmAccept(); + } + } + } + } + + /// Signal the loop in `run` to exit. Safe to call from another + /// thread. + pub fn stop(self: *AcceptLoop) void { + self.running.store(false, .release); + } + + fn rearmAccept(self: *AcceptLoop) Error!void { + _ = self.ring.accept_multishot(0, self.listen_fd, null, null, 0) catch + return Error.AcceptSubmitFailed; + _ = self.ring.submit() catch return Error.AcceptSubmitFailed; + } + }; +} else struct { + // Non-Linux placeholder so the module type-checks everywhere. + pub const AcceptCallback = *const fn (ctx: *anyopaque, fd: i32) void; + pub const AcceptLoop = struct { + pub fn init( + _: i32, + _: AcceptCallback, + _: *anyopaque, + _: u16, + ) Error!AcceptLoop { + return Error.NotEnabled; + } + pub fn deinit(_: *AcceptLoop) void {} + pub fn run(_: *AcceptLoop) Error!void { + return Error.NotEnabled; + } + pub fn stop(_: *AcceptLoop) void {} + }; +}; + +test "iouring module compiles on every target" { + // The whole point of this test is to make sure `Available`, `enabled()` + // and `Linux.AcceptLoop` all type-check on the host build, regardless of + // OS. Real behavior is exercised by the Linux integration tests in a + // follow-up PR. + _ = Available; + _ = enabled(); + _ = Linux.AcceptLoop; +} diff --git a/zig/src/server.zig b/zig/src/server.zig index 3c0bb6a..da0b96a 100644 --- a/zig/src/server.zig +++ b/zig/src/server.zig @@ -13,6 +13,7 @@ const multipart_mod = @import("multipart.zig"); const logger = @import("logger.zig"); const runtime = @import("runtime.zig"); const telemetry = @import("telemetry.zig"); +const iouring = @import("iouring.zig"); const allocator = std.heap.c_allocator; const posix = std.posix; @@ -1065,13 +1066,62 @@ pub fn server_run(_: ?*c.PyObject, _: ?*c.PyObject) callconv(.c) ?*c.PyObject { // Release the GIL — workers acquire it per-request via AcquireThread. 
const save = py.PyEval_SaveThread(); + if (iouring.enabled()) { + runIoUringAcceptLoop(&tcp_server) catch |err| { + // Fall back to the blocking accept path if the io_uring setup + // failed at runtime (kernel < 5.19, io_uring disabled, etc.). + logger.warn("[iouring] accept loop unavailable ({s}); falling back to blocking accept", .{@errorName(err)}); + runBlockingAcceptLoop(&tcp_server); + }; + } else { + runBlockingAcceptLoop(&tcp_server); + } + + py.PyEval_RestoreThread(save); + return py.pyNone(); +} + +fn runBlockingAcceptLoop(tcp_server: *std.Io.net.Server) void { while (true) { const stream = tcp_server.accept(runtime.io) catch continue; pool.queue.push(stream); } +} - py.PyEval_RestoreThread(save); - return py.pyNone(); +/// Linux-only path. On other platforms `iouring.enabled()` is always false so +/// this function is unreachable; we still compile-test it via a stub on +/// non-Linux to keep the call site identical. +fn runIoUringAcceptLoop(tcp_server: *std.Io.net.Server) !void { + if (!iouring.Available) return iouring.Error.NotEnabled; + + // Pull the underlying listen fd out of std.Io.net.Server. The Server + // owns it; we only borrow it for the duration of the accept loop. + const listen_fd: posix.fd_t = tcp_server.socket.handle; + logger.info("[iouring] enabling IORING_OP_ACCEPT_MULTISHOT on fd={d}", .{listen_fd}); + + var loop = try iouring.Linux.AcceptLoop.init( + listen_fd, + iouringOnAccept, + @ptrCast(&pool), + iouring.DEFAULT_SQ_ENTRIES, + ); + defer loop.deinit(); + + try loop.run(); +} + +/// Wrap a raw fd from io_uring back into a `std.Io.net.Stream` so the +/// existing per-connection thread-pool code can consume it unchanged. 
+fn iouringOnAccept(ctx: *anyopaque, fd: posix.fd_t) void { + const conn_pool: *ConnectionPool = @ptrCast(@alignCast(ctx)); + const stream = std.Io.net.Stream{ .socket = .{ + .handle = fd, + // peer address isn't reported by ACCEPT_MULTISHOT when we pass a + // null sockaddr; the existing handlers don't read it, so leave it + // zero rather than paying for getpeername(2) on the hot path. + .address = .{ .ip4 = std.Io.net.Ip4Address.unspecified(0) }, + } }; + conn_pool.queue.push(stream); } const HeaderList = std.ArrayListUnmanaged(HeaderPair); From 923152571cbd0d453997172d39396883e7aabbb0 Mon Sep 17 00:00:00 2001 From: justrach <54503978+justrach@users.noreply.github.com> Date: Thu, 30 Apr 2026 11:15:13 +0800 Subject: [PATCH 2/3] bench(iouring): A/B harness with path-param, query, and items workloads Adds a containerized A/B benchmark for the new -Diouring=true accept loop vs the blocking-accept default. Runs inside Apples container microVM on Linux 6.18.5 aarch64. Workloads (5 iters x 10s each, median reported): - GET / noargs fast path - GET /user/{id}, id varied path-param parsing per request (lua) - GET /q?id=, id varied query-string parsing per request (lua) - GET /items ~2 KB JSON body (50 records) Lua scripts vary the URL per request so any per-path memoization is defeated and the radix-trie lookup runs cold every request. Per AGENTS.md the only code path that differs between variants is the listener accept loop; per-connection recv/send still goes through the existing thread-pool syscalls. RESULTS.md documents that scope and captures medians from one local run; numbers are NOT meant for release notes or framework comparison tables. 
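For anyone sanity-checking the reported medians, the reduction can be sketched in a few lines of Python. This is an illustrative sketch under stated assumptions, not the harness itself: `median_rps` and `delta_pct` are hypothetical helper names, and the per-iteration list is a placeholder rather than measured data. The two medians passed to `delta_pct` are the `noargs` values recorded in RESULTS.md.

```python
from statistics import median


def median_rps(samples):
    """Median requests/sec across the iterations of one (variant, workload) pair."""
    return median(samples)


def delta_pct(blocking_rps, iouring_rps):
    """Relative change of the io_uring build versus the blocking build, in percent."""
    return (iouring_rps - blocking_rps) / blocking_rps * 100.0


# Placeholder per-iteration values (NOT measured data), one per iteration.
placeholder_iters = [100.0, 104.0, 98.0, 101.0, 103.0]
print(median_rps(placeholder_iters))

# Medians for the noargs workload as recorded in RESULTS.md.
print(round(delta_pct(697_933, 713_240), 1))
```

With only five samples per pair the median is a single observed run, so small deltas computed this way should be read as noise-level.
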
Generated with [Devin](https://cli.devin.ai/docs) Co-Authored-By: Devin <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- bench/iouring/full_bench/.gitignore | 2 + bench/iouring/full_bench/Containerfile | 42 +++++ bench/iouring/full_bench/RESULTS.md | 83 ++++++++++ bench/iouring/full_bench/app.py | 55 +++++++ bench/iouring/full_bench/build_and_bench.sh | 163 ++++++++++++++++++++ bench/iouring/full_bench/vary_query.lua | 11 ++ bench/iouring/full_bench/vary_user_id.lua | 16 ++ 7 files changed, 372 insertions(+) create mode 100644 bench/iouring/full_bench/.gitignore create mode 100644 bench/iouring/full_bench/Containerfile create mode 100644 bench/iouring/full_bench/RESULTS.md create mode 100644 bench/iouring/full_bench/app.py create mode 100755 bench/iouring/full_bench/build_and_bench.sh create mode 100644 bench/iouring/full_bench/vary_query.lua create mode 100644 bench/iouring/full_bench/vary_user_id.lua diff --git a/bench/iouring/full_bench/.gitignore b/bench/iouring/full_bench/.gitignore new file mode 100644 index 0000000..00f7080 --- /dev/null +++ b/bench/iouring/full_bench/.gitignore @@ -0,0 +1,2 @@ +# wrk output is regenerated by build_and_bench.sh +results/ diff --git a/bench/iouring/full_bench/Containerfile b/bench/iouring/full_bench/Containerfile new file mode 100644 index 0000000..e5c9f39 --- /dev/null +++ b/bench/iouring/full_bench/Containerfile @@ -0,0 +1,42 @@ +# Linux io_uring vs blocking-accept benchmark image for TurboAPI. +# Builds the Zig extension twice (once with -Diouring=true, once without), +# runs the same TurboAPI app under each, and drives wrk against it from +# inside the same container. +# +# Honest scope: the ONLY thing that differs between the two builds in the +# current PR (#144) is the accept loop. Per-connection recv/send still goes +# through the same thread-pool synchronous syscalls. Treat the deltas +# accordingly. 
+ +FROM debian:bookworm-slim + +ENV DEBIAN_FRONTEND=noninteractive +ENV PATH="/root/.local/bin:/opt/zig:${PATH}" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates curl xz-utils build-essential pkg-config \ + git wrk \ + && rm -rf /var/lib/apt/lists/* + +# Install Zig 0.16.0 (aarch64-linux). +RUN mkdir -p /opt && cd /opt \ + && curl -fsSL https://ziglang.org/download/0.16.0/zig-aarch64-linux-0.16.0.tar.xz -o zig.tar.xz \ + && tar -xJf zig.tar.xz \ + && mv zig-aarch64-linux-0.16.0 zig \ + && rm zig.tar.xz + +# Install uv to fetch Python 3.14 free-threaded. +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Pre-fetch Python 3.14t so the first run is fast. +RUN /root/.local/bin/uv python install 3.14t + +WORKDIR /work + +# The repo will be bind-mounted at /work at run time. The driver script +# (build_and_bench.sh) handles the per-build steps. +COPY build_and_bench.sh /usr/local/bin/build_and_bench.sh +COPY app.py /app/app.py +RUN chmod +x /usr/local/bin/build_and_bench.sh + +ENTRYPOINT ["/usr/local/bin/build_and_bench.sh"] diff --git a/bench/iouring/full_bench/RESULTS.md b/bench/iouring/full_bench/RESULTS.md new file mode 100644 index 0000000..dc44d95 --- /dev/null +++ b/bench/iouring/full_bench/RESULTS.md @@ -0,0 +1,83 @@ +# io_uring vs blocking-accept — A/B run + +> **Scope:** the *only* code path that differs between the two builds is +> the listener accept loop. Per-connection `recv` / `send` still goes +> through the existing thread-pool synchronous syscalls in both +> variants. Treat the deltas accordingly. Per `AGENTS.md`, do not cite +> these numbers in release notes, framework comparison tables, or any +> public copy. 
+ +## Environment + +- Apple `container` CLI (Linux microVM on macOS) +- Kernel: `Linux 6.18.5 aarch64` +- Python: `3.14.4 free-threaded` (GIL disabled) +- Zig: `0.16.0`, `-Doptimize=ReleaseFast` +- wrk: `t=4 c=64 d=10s` per iteration, 3s warmup +- Iterations: **5 per (variant, workload)**, median reported +- All traffic on loopback inside one container + +## Workloads + +| name | request | notes | +|-----------|----------------------------------------------|-------| +| `noargs` | `GET /` | trivial fast path | +| `user_id` | `GET /user/{id}` with `id` random 1..10M | path-param parsing every request, defeats per-path caching | +| `query` | `GET /q?id={id}` with `id` random 1..10M | query-string parsing every request | +| `items` | `GET /items` returning a 50-record JSON body | bigger response (~2 KB) | + +`user_id` and `query` use `wrk -s vary_user_id.lua` / `vary_query.lua` +which generate a fresh URL per request, so the radix-trie lookup runs +cold every time. + +## Median of 5 runs + +| workload | variant | req/s | Δ vs blocking | p50 | p99 | +|----------|-----------|-----------|---------------|----------|----------| +| noargs | blocking | 697,933 | — | 20 µs | 86 µs | +| noargs | iouring | 713,240 | **+2.2 %** | 21 µs | 59 µs | +| user_id | blocking | 321,439 | — | 43 µs | 321 µs | +| user_id | iouring | 366,991 | **+14.2 %** | 39 µs | 261 µs | +| query | blocking | 235,954 | — | 28 µs | 7.87 ms | +| query | iouring | 235,270 | **−0.3 %** | 28 µs | 8.32 ms | +| items | blocking | 124,408 | — | 170 µs | 401 µs | +| items | iouring | 130,719 | **+5.1 %** | 150 µs | 533 µs | + +Raw per-iteration `wrk` outputs are in `results/`. + +## Honest caveats + +- 5 samples is enough to spot order-of-magnitude differences but not + small ones; the `query` and `noargs` deltas are within run-to-run + noise on this VM. +- p99 jitter is high (`query` shows multi-millisecond tails on both + builds — likely loopback + `wrk` timing artifacts, not server + pauses). 
Don't read the p99 column as a stable signal. +- Single-container, loopback, single client. No multi-host, no + external network, no multi-worker deployment scenario. +- `wrk -t4` is approaching saturation on the `noargs` route (~700k + rps). Some of the small delta there may be wrk-bound, not + server-bound. +- This run was kicked off in a fresh container, so each variant got a + cold start; results were not interleaved. + +## Reproducing + +```bash +container build -t turboapi-iouring-bench \ + -f bench/iouring/full_bench/Containerfile \ + bench/iouring/full_bench + +container run --rm -m 8G -c 4 \ + -v "$PWD":/work \ + turboapi-iouring-bench +``` + +Override env vars to scale up: + +```bash +container run --rm -m 8G -c 4 \ + -e DURATION=30s -e ITERS=10 -e CONNS=128 \ + -v "$PWD":/work \ + turboapi-iouring-bench +``` diff --git a/bench/iouring/full_bench/app.py b/bench/iouring/full_bench/app.py new file mode 100644 index 0000000..456508c --- /dev/null +++ b/bench/iouring/full_bench/app.py @@ -0,0 +1,55 @@ +"""Routes for io_uring vs blocking-accept A/B benchmarking. + +Intentionally small. Each route does the minimum work that exercises a +different request path so we can check whether the accept-loop change +moves the needle for anything other than the trivial `/` noargs case. + +Routes: + GET / noargs fast path (baseline) + GET /user/{id} path parameter — varied per request by wrk to + defeat any per-path lookup caching + GET /q query string — echoes a single ?id= param +""" + +import os +import sys + +from turboapi import TurboAPI + +app = TurboAPI(title="iouring-bench") + + +@app.get("/") +def home(): + return {"ok": True} + + +@app.get("/user/{id}") +def get_user(id: str): + return {"id": id} + + +@app.get("/q") +def get_q(id: str = "0"): + return {"id": id} + + +# ~2 KB JSON body (50 records) — exercises more serializer + more bytes +# on the wire than the trivial routes above. 
+_ITEMS = [ + {"id": i, "name": f"item-{i}", "price": i * 1.5, "in_stock": (i % 3 == 0)} + for i in range(50) +] + + +@app.get("/items") +def get_items(): + return {"items": _ITEMS} + + +if __name__ == "__main__": + host = os.environ.get("HOST", "0.0.0.0") + port = int(os.environ.get("PORT", "8080")) + print(f"[bench-app] starting on {host}:{port}", flush=True) + sys.stdout.flush() + app.run(host=host, port=port) diff --git a/bench/iouring/full_bench/build_and_bench.sh b/bench/iouring/full_bench/build_and_bench.sh new file mode 100755 index 0000000..008a1a7 --- /dev/null +++ b/bench/iouring/full_bench/build_and_bench.sh @@ -0,0 +1,163 @@ +#!/usr/bin/env bash +# A/B bench the Zig backend with and without -Diouring=true across three +# workloads: +# +# noargs GET / (trivial baseline) +# user_id GET /user/{id} with {id} varied every request (lua) +# query GET /q?id={id} with {id} varied every request (lua) +# +# Each (variant, workload) pair runs ITERS times; we report the median +# Requests/sec and p50/p99. +# +# Honest scope: in PR #144 the ONLY functional difference between the +# two builds is the accept loop (IORING_OP_ACCEPT_MULTISHOT vs blocking +# accept(2)). Per-connection recv/send still goes through the same +# thread-pool syscalls in both builds. Treat any delta as an accept-path +# delta, not a holistic "io_uring vs syscalls" number. 
+# +# Env: +# DURATION wrk duration (default 10s) +# THREADS wrk threads (default 4) +# CONNS wrk connections (default 64) +# PORT app port (default 8080) +# WARMUP warmup seconds per (variant, workload) (default 3) +# ITERS measured iterations per (variant, workload) (default 5) + +set -euo pipefail + +DURATION="${DURATION:-10s}" +THREADS="${THREADS:-4}" +CONNS="${CONNS:-64}" +PORT="${PORT:-8080}" +WARMUP="${WARMUP:-3}" +ITERS="${ITERS:-5}" + +REPO="/work" +BENCH_DIR="/work/bench/iouring/full_bench" +RESULTS_DIR="$BENCH_DIR/results" +APP="/app/app.py" +UV="/root/.local/bin/uv" + +cd "$REPO" +export PATH="/root/.local/bin:/opt/zig:${PATH}" + +PY="$($UV python find 3.14t)" +echo "[bench] using python: $PY" +"$PY" -c "import sys; print('[bench] gil enabled =', sys._is_gil_enabled())" + +PY_INC="$("$PY" -c 'import sysconfig; print(sysconfig.get_path("include"))')" +PY_LIB="$("$PY" -c 'import sysconfig; print(sysconfig.get_config_var("LIBDIR"))')" + +VENV="/tmp/bench-venv" +"$UV" venv --python "$PY" "$VENV" >/dev/null +PY="$VENV/bin/python" +"$UV" pip install --python "$PY" setuptools wheel >/dev/null +"$UV" pip install --python "$PY" --no-build-isolation -e . >/dev/null + +mkdir -p "$RESULTS_DIR" + +build_variant() { + local label="$1" + local iouring_flag="$2" + echo "[bench] ==> building variant: $label (iouring=$iouring_flag)" + ( cd zig && zig build \ + -Dpython=3.14t \ + -Dpy-include="$PY_INC" \ + -Dpy-libdir="$PY_LIB" \ + -Diouring="$iouring_flag" \ + -Doptimize=ReleaseFast ) + + local suffix + suffix="$("$PY" -c 'import importlib.machinery as m; print(m.EXTENSION_SUFFIXES[0])')" + cp "zig/zig-out/lib/libturbonet.so" "python/turboapi/turbonet${suffix}" +} + +start_app() { + local label="$1" + "$PY" "$APP" >/tmp/app-"$label".log 2>&1 & + APP_PID=$!
+ local tries=0 + until curl -fsS "http://127.0.0.1:${PORT}/" >/dev/null 2>&1; do + tries=$((tries + 1)) + if [ "$tries" -gt 100 ]; then + echo "[bench] app failed to start, log:"; cat /tmp/app-"$label".log || true + kill "$APP_PID" 2>/dev/null || true + return 1 + fi + sleep 0.1 + done + echo "[bench] app up (pid=$APP_PID, variant=$label)" +} + +stop_app() { + kill "$APP_PID" 2>/dev/null || true + wait "$APP_PID" 2>/dev/null || true + sleep 1 +} + +# run_workload VARIANT WORKLOAD WRK_ARGS... +# writes per-iteration wrk output to $RESULTS_DIR/$VARIANT-$WORKLOAD-$i.txt +run_workload() { + local variant="$1"; shift + local workload="$1"; shift + + # warmup + wrk -t"$THREADS" -c"$CONNS" -d"${WARMUP}s" "$@" >/dev/null + + for i in $(seq 1 "$ITERS"); do + local out="$RESULTS_DIR/${variant}-${workload}-${i}.txt" + echo "[bench] ==> wrk variant=$variant workload=$workload iter=$i" + wrk -t"$THREADS" -c"$CONNS" -d"$DURATION" --latency "$@" >"$out" + done +} + +# --- run both variants, all workloads --- +for variant_spec in "blocking:false" "iouring:true"; do + variant="${variant_spec%%:*}" + flag="${variant_spec##*:}" + + build_variant "$variant" "$flag" + start_app "$variant" + + run_workload "$variant" "noargs" "http://127.0.0.1:${PORT}/" + run_workload "$variant" "user_id" -s "$BENCH_DIR/vary_user_id.lua" "http://127.0.0.1:${PORT}" + run_workload "$variant" "query" -s "$BENCH_DIR/vary_query.lua" "http://127.0.0.1:${PORT}" + run_workload "$variant" "items" "http://127.0.0.1:${PORT}/items" + + stop_app +done + +# --- summarize: pick median req/s across ITERS for each (variant, workload) --- +summarize() { + local variant="$1" + local workload="$2" + # extract req/s from each iter, sort, take median + local rps + rps=$(for i in $(seq 1 "$ITERS"); do + grep -E "^Requests/sec:" "$RESULTS_DIR/${variant}-${workload}-${i}.txt" | awk '{print $2}' + done | sort -n | awk "NR==$(( (ITERS+1)/2 )) {print}") + # pick the file matching that rps for p50/p99 + local src + for i in $(seq 
1 "$ITERS"); do + local f="$RESULTS_DIR/${variant}-${workload}-${i}.txt" + if grep -qE "^Requests/sec:[[:space:]]+${rps}\$" "$f"; then src="$f"; break; fi + done + local p50 p99 + # match the Latency Distribution lines (leading whitespace + percentile), + # not "55.50%" or "67.00%" stdev values on the Thread Stats line. + p50=$(awk '/^[[:space:]]+50%[[:space:]]/ {print $2; exit}' "$src") + p99=$(awk '/^[[:space:]]+99%[[:space:]]/ {print $2; exit}' "$src") + printf " %-8s %-8s req/s=%-12s p50=%-8s p99=%-8s\n" \ + "$variant" "$workload" "$rps" "$p50" "$p99" +} + +echo +echo "===================== MEDIAN OF $ITERS RUNS =====================" +for workload in noargs user_id query items; do + summarize blocking "$workload" + summarize iouring "$workload" +done +echo "=================================================================" +echo "Raw iter outputs: $RESULTS_DIR/" +echo "Kernel: $(uname -r) | wrk: t=$THREADS c=$CONNS d=$DURATION | iters=$ITERS" +echo "NOTE: only the accept loop differs between variants. Treat deltas accordingly." diff --git a/bench/iouring/full_bench/vary_query.lua b/bench/iouring/full_bench/vary_query.lua new file mode 100644 index 0000000..0fae2de --- /dev/null +++ b/bench/iouring/full_bench/vary_query.lua @@ -0,0 +1,11 @@ +-- wrk script: each request hits /q?id=. +-- Same no-cache intent as vary_user_id.lua, but exercises the query-string +-- parsing path instead of the path-param path. + +math.randomseed(os.time() + 1) +local fmt = string.format + +request = function() + local id = math.random(1, 10000000) + return wrk.format(nil, fmt("/q?id=%d", id)) +end diff --git a/bench/iouring/full_bench/vary_user_id.lua b/bench/iouring/full_bench/vary_user_id.lua new file mode 100644 index 0000000..76f17b3 --- /dev/null +++ b/bench/iouring/full_bench/vary_user_id.lua @@ -0,0 +1,16 @@ +-- wrk script: each request hits /user/. 
+-- Intent: defeat any per-path memoization (noargs cache keys on the URL +-- in TurboAPI, routers may also cache last-matched path, etc.) so the +-- radix trie lookup + param extraction runs every request. +-- +-- Space is intentionally larger than any reasonable LRU cache size. + +math.randomseed(os.time()) + +-- Cache string.format in a local so request() skips the table lookup on every call. +local fmt = string.format + +request = function() + local id = math.random(1, 10000000) + return wrk.format(nil, fmt("/user/%d", id)) +end From e2e51cd769fd631a96030e51e34d204b74c40853 Mon Sep 17 00:00:00 2001 From: justrach <54503978+justrach@users.noreply.github.com> Date: Thu, 30 Apr 2026 13:15:24 +0800 Subject: [PATCH 3/3] feat(turbopg): vendor justrach/pg.zig + add io_uring transport (opt-in, Linux) Vendors justrach/pg.zig (which was already the upstream this repo depended on via git+url) and its two transitive deps into the repo so we can stage an io_uring transport without bouncing edits through a separate fork repo. Adds an opt-in per-connection io_uring path behind a new build flag, and a driver-only A/B bench against Postgres 18 running in apple/container. Vendored paths: zig/pg/ -- justrach/pg.zig @ 7605502 (was a git+url dep at this same SHA, now .path) zig/pg-deps-buffer/ -- karlseguin/buffer.zig @ 30f9512 (pg.zig transitive dep, identical to upstream) zig/pg-deps-metrics/ -- karlseguin/metrics.zig @ 13d8706 + a one-line Zig 0.16 compat fix on metric.zig:368 (@Type(.{.int=...}) -> @Int(bits, signedness)) New build flag in zig/pg/build.zig: -Diouring=true -- switch pg.Conn transport to io_uring on Linux. Default false, no-op elsewhere.
Transport (zig/pg/src/stream.zig): * Per-connection std.os.linux.IoUring, 8 SQEs, no SQPOLL * writeAll: IORING_OP_SEND single-shot, submit + copy_cqe * read: IORING_OP_RECV single-shot, submit + copy_cqe * connect path reuses PlainStream.connect (blocking getaddrinfo) * TLS path unchanged; iouring is skipped when has_openssl=true Bench harness (bench/turbopg/): * bench.zig -- N worker threads, each owns one pg.Conn, hot loop of either 'SELECT 1' or a 50-row generate_series, measures rps for BENCH_DURATION seconds * run.sh -- builds blocking + iouring variants, runs N iters each, prints median rps * Containerfile -- debian:bookworm-slim + Zig 0.16.0 aarch64 * RESULTS.md -- captured medians from one local A/B run Results (one local run, apple/container, Linux 6.18.5, 4 threads, 10s per iter, median of 3-5): query=SELECT 1 blocking=14,967 rps iouring=15,178 rps (+1.4%) query=50-row blocking=13,903 rps iouring=13,849 rps (-0.4%) Both deltas are within run-to-run noise. This is expected for a per-op submit+wait: we trade one recv() syscall for two io_uring_enter syscalls, so on a loopback TCP path the per-op cost is roughly even. The real win requires SQPOLL, batched SEND, and a cooperative scheduler so one ring drives many connections -- all explicitly out of scope for this PR (see RESULTS.md). Per AGENTS.md these numbers are not suitable for release notes, framework comparison tables, or marketing copy. 
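The two percentages quoted in the results above are plain relative differences of the median rps values. A minimal Python sketch of that arithmetic, for anyone re-deriving them; `delta_pct` is a hypothetical helper name used only for illustration and is not part of the repo:

```python
def delta_pct(baseline: float, candidate: float) -> float:
    """Relative change of candidate vs. baseline, in percent (hypothetical helper)."""
    return (candidate - baseline) / baseline * 100.0


# Median rps values as quoted in the results above (blocking, iouring).
select_1 = delta_pct(14_967, 15_178)
fifty_row = delta_pct(13_903, 13_849)

print(f"SELECT 1: {select_1:+.1f}%")   # prints "SELECT 1: +1.4%"
print(f"50-row:   {fifty_row:+.1f}%")  # prints "50-row:   -0.4%"
```

Feeding in medians from a fresh run gives the corresponding deltas; with only 3-5 iterations per variant they should still be read as noise-level, as stated above.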
Generated with [Devin](https://cli.devin.ai/docs) Co-Authored-By: Devin <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- bench/turbopg/.gitignore | 5 + bench/turbopg/Containerfile | 25 + bench/turbopg/RESULTS.md | 125 + bench/turbopg/bench.zig | 167 ++ bench/turbopg/build.zig | 33 + bench/turbopg/build.zig.zon | 11 + bench/turbopg/run.sh | 92 + zig/build.zig.zon | 6 +- zig/pg-deps-buffer/.gitignore | 2 + zig/pg-deps-buffer/LICENSE | 19 + zig/pg-deps-buffer/Makefile | 5 + zig/pg-deps-buffer/build.zig | 22 + zig/pg-deps-buffer/readme.md | 60 + zig/pg-deps-buffer/src/buffer.zig | 568 ++++ zig/pg-deps-buffer/src/pool.zig | 152 ++ zig/pg-deps-buffer/src/t.zig | 17 + zig/pg-deps-buffer/test_runner.zig | 294 ++ zig/pg-deps-metrics/.gitignore | 2 + zig/pg-deps-metrics/LICENSE | 19 + zig/pg-deps-metrics/Makefile | 5 + zig/pg-deps-metrics/build.zig | 48 + zig/pg-deps-metrics/build.zig.zon | 6 + zig/pg-deps-metrics/example/lib/lib.zig | 21 + zig/pg-deps-metrics/example/lib/metrics.zig | 49 + zig/pg-deps-metrics/example/main.zig | 38 + zig/pg-deps-metrics/readme.md | 321 +++ zig/pg-deps-metrics/src/counter.zig | 471 ++++ zig/pg-deps-metrics/src/gauge.zig | 523 ++++ zig/pg-deps-metrics/src/histogram.zig | 626 +++++ zig/pg-deps-metrics/src/metric.zig | 670 +++++ zig/pg-deps-metrics/src/metrics.zig | 156 ++ zig/pg-deps-metrics/src/registry.zig | 27 + zig/pg-deps-metrics/src/t.zig | 24 + zig/pg-deps-metrics/test_runner.zig | 294 ++ zig/pg/.gitignore | 4 + zig/pg/LICENSE | 19 + zig/pg/Makefile | 20 + zig/pg/build.zig | 98 + zig/pg/build.zig.zon | 17 + zig/pg/example/build.zig | 27 + zig/pg/example/build.zig.zon | 11 + zig/pg/example/main.zig | 271 ++ zig/pg/readme.md | 695 +++++ zig/pg/src/auth.zig | 387 +++ zig/pg/src/conn.zig | 2370 +++++++++++++++++ zig/pg/src/lib.zig | 335 +++ zig/pg/src/listener.zig | 269 ++ zig/pg/src/metrics.zig | 56 + zig/pg/src/pg.zig | 43 + zig/pg/src/pool.zig | 616 +++++ zig/pg/src/proto.zig | 19 + 
.../src/proto/AuthenticationSASLContinue.zig | 51 + zig/pg/src/proto/AuthenticationSASLFinal.zig | 54 + zig/pg/src/proto/CommandComplete.zig | 77 + zig/pg/src/proto/Describe.zig | 64 + zig/pg/src/proto/Error.zig | 127 + zig/pg/src/proto/Execute.zig | 59 + zig/pg/src/proto/NotificationResponse.zig | 32 + zig/pg/src/proto/Parse.zig | 62 + zig/pg/src/proto/PasswordMessage.zig | 38 + zig/pg/src/proto/Query.zig | 38 + zig/pg/src/proto/SASLInitialResponse.zig | 48 + zig/pg/src/proto/SASLResponse.zig | 41 + zig/pg/src/proto/StartupMessage.zig | 76 + zig/pg/src/proto/Sync.zig | 24 + zig/pg/src/proto/_proto.zig | 85 + zig/pg/src/proto/authentication_request.zig | 143 + zig/pg/src/reader.zig | 727 +++++ zig/pg/src/result.zig | 2180 +++++++++++++++ zig/pg/src/stmt.zig | 407 +++ zig/pg/src/stream.zig | 338 +++ zig/pg/src/t.zig | 241 ++ zig/pg/src/types.zig | 1643 ++++++++++++ zig/pg/src/types/cidr.zig | 41 + zig/pg/src/types/numeric.zig | 423 +++ zig/pg/src/types/vector.zig | 205 ++ zig/pg/test_runner.zig | 294 ++ zig/pg/tests/client.crt | 22 + zig/pg/tests/client.key | 28 + zig/pg/tests/compose.yml | 25 + zig/pg/tests/init_ssl.sql | 4 + zig/pg/tests/pg_hba.conf | 6 + zig/pg/tests/postgresql.conf | 13 + zig/pg/tests/root.crt | 80 + zig/pg/tests/root.srl | 1 + 85 files changed, 17855 insertions(+), 2 deletions(-) create mode 100644 bench/turbopg/.gitignore create mode 100644 bench/turbopg/Containerfile create mode 100644 bench/turbopg/RESULTS.md create mode 100644 bench/turbopg/bench.zig create mode 100644 bench/turbopg/build.zig create mode 100644 bench/turbopg/build.zig.zon create mode 100755 bench/turbopg/run.sh create mode 100644 zig/pg-deps-buffer/.gitignore create mode 100644 zig/pg-deps-buffer/LICENSE create mode 100644 zig/pg-deps-buffer/Makefile create mode 100644 zig/pg-deps-buffer/build.zig create mode 100644 zig/pg-deps-buffer/readme.md create mode 100644 zig/pg-deps-buffer/src/buffer.zig create mode 100644 zig/pg-deps-buffer/src/pool.zig create mode 100644 
zig/pg-deps-buffer/src/t.zig create mode 100644 zig/pg-deps-buffer/test_runner.zig create mode 100644 zig/pg-deps-metrics/.gitignore create mode 100644 zig/pg-deps-metrics/LICENSE create mode 100644 zig/pg-deps-metrics/Makefile create mode 100644 zig/pg-deps-metrics/build.zig create mode 100644 zig/pg-deps-metrics/build.zig.zon create mode 100644 zig/pg-deps-metrics/example/lib/lib.zig create mode 100644 zig/pg-deps-metrics/example/lib/metrics.zig create mode 100644 zig/pg-deps-metrics/example/main.zig create mode 100644 zig/pg-deps-metrics/readme.md create mode 100644 zig/pg-deps-metrics/src/counter.zig create mode 100644 zig/pg-deps-metrics/src/gauge.zig create mode 100644 zig/pg-deps-metrics/src/histogram.zig create mode 100644 zig/pg-deps-metrics/src/metric.zig create mode 100644 zig/pg-deps-metrics/src/metrics.zig create mode 100644 zig/pg-deps-metrics/src/registry.zig create mode 100644 zig/pg-deps-metrics/src/t.zig create mode 100644 zig/pg-deps-metrics/test_runner.zig create mode 100644 zig/pg/.gitignore create mode 100644 zig/pg/LICENSE create mode 100644 zig/pg/Makefile create mode 100644 zig/pg/build.zig create mode 100644 zig/pg/build.zig.zon create mode 100644 zig/pg/example/build.zig create mode 100644 zig/pg/example/build.zig.zon create mode 100644 zig/pg/example/main.zig create mode 100644 zig/pg/readme.md create mode 100644 zig/pg/src/auth.zig create mode 100644 zig/pg/src/conn.zig create mode 100644 zig/pg/src/lib.zig create mode 100644 zig/pg/src/listener.zig create mode 100644 zig/pg/src/metrics.zig create mode 100644 zig/pg/src/pg.zig create mode 100644 zig/pg/src/pool.zig create mode 100644 zig/pg/src/proto.zig create mode 100644 zig/pg/src/proto/AuthenticationSASLContinue.zig create mode 100644 zig/pg/src/proto/AuthenticationSASLFinal.zig create mode 100644 zig/pg/src/proto/CommandComplete.zig create mode 100644 zig/pg/src/proto/Describe.zig create mode 100644 zig/pg/src/proto/Error.zig create mode 100644 zig/pg/src/proto/Execute.zig create 
mode 100644 zig/pg/src/proto/NotificationResponse.zig create mode 100644 zig/pg/src/proto/Parse.zig create mode 100644 zig/pg/src/proto/PasswordMessage.zig create mode 100644 zig/pg/src/proto/Query.zig create mode 100644 zig/pg/src/proto/SASLInitialResponse.zig create mode 100644 zig/pg/src/proto/SASLResponse.zig create mode 100644 zig/pg/src/proto/StartupMessage.zig create mode 100644 zig/pg/src/proto/Sync.zig create mode 100644 zig/pg/src/proto/_proto.zig create mode 100644 zig/pg/src/proto/authentication_request.zig create mode 100644 zig/pg/src/reader.zig create mode 100644 zig/pg/src/result.zig create mode 100644 zig/pg/src/stmt.zig create mode 100644 zig/pg/src/stream.zig create mode 100644 zig/pg/src/t.zig create mode 100644 zig/pg/src/types.zig create mode 100644 zig/pg/src/types/cidr.zig create mode 100644 zig/pg/src/types/numeric.zig create mode 100644 zig/pg/src/types/vector.zig create mode 100644 zig/pg/test_runner.zig create mode 100644 zig/pg/tests/client.crt create mode 100644 zig/pg/tests/client.key create mode 100644 zig/pg/tests/compose.yml create mode 100644 zig/pg/tests/init_ssl.sql create mode 100644 zig/pg/tests/pg_hba.conf create mode 100644 zig/pg/tests/postgresql.conf create mode 100644 zig/pg/tests/root.crt create mode 100644 zig/pg/tests/root.srl diff --git a/bench/turbopg/.gitignore b/bench/turbopg/.gitignore new file mode 100644 index 0000000..912b503 --- /dev/null +++ b/bench/turbopg/.gitignore @@ -0,0 +1,5 @@ +# Built artifacts and per-iter bench outputs +zig-out/ +.zig-cache/ +zig-pkg/ +results/ diff --git a/bench/turbopg/Containerfile b/bench/turbopg/Containerfile new file mode 100644 index 0000000..fbf0a39 --- /dev/null +++ b/bench/turbopg/Containerfile @@ -0,0 +1,25 @@ +# Build + run the turbopg driver-only A/B bench against a Postgres 18 +# container. Per AGENTS.md: results from this container are NOT suitable +# for release notes or comparison tables. 
+FROM debian:bookworm-slim + +ENV DEBIAN_FRONTEND=noninteractive +ENV PATH="/opt/zig:${PATH}" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates curl xz-utils postgresql-client python3 \ + && rm -rf /var/lib/apt/lists/* + +# Zig 0.16.0 aarch64-linux (matches host arch in apple/container on M-series). +RUN mkdir -p /opt && cd /opt \ + && curl -fsSL https://ziglang.org/download/0.16.0/zig-aarch64-linux-0.16.0.tar.xz -o zig.tar.xz \ + && tar -xJf zig.tar.xz \ + && mv zig-aarch64-linux-0.16.0 zig \ + && rm zig.tar.xz + +WORKDIR /work + +COPY run.sh /usr/local/bin/run.sh +RUN chmod +x /usr/local/bin/run.sh + +ENTRYPOINT ["/usr/local/bin/run.sh"] diff --git a/bench/turbopg/RESULTS.md b/bench/turbopg/RESULTS.md new file mode 100644 index 0000000..71bfbdc --- /dev/null +++ b/bench/turbopg/RESULTS.md @@ -0,0 +1,125 @@ +# turbopg / pg.zig — blocking vs io_uring transport A/B + +> **Scope:** the only code path that differs between the two builds is +> the `Stream` struct in `zig/pg/src/stream.zig`. Everything else +> (wire-protocol codec, `Conn`, `Pool`, result decoding) is identical. +> The io_uring path is **one ring per connection**, single in-flight +> SQE per `read` / `writeAll` (submit + `copy_cqe`). **No scheduler, no +> SQPOLL, no multi-shot, no registered fds.** The real async win +> needs all of those on top. Per `AGENTS.md`, do not cite these +> numbers in release notes, framework comparison tables, or +> marketing copy. 
+ +## Environment + +- Apple `container` CLI (two Linux microVMs on macOS, shared vnic) +- DB container: `postgres:18`, trust auth, default shm, `192.168.64.14` +- Bench container: `debian:bookworm-slim` + Zig 0.16.0 aarch64-linux +- Kernel (both VMs): `Linux 6.18.5 aarch64` +- pg.zig build: `-Doptimize=ReleaseFast` +- Workload: N worker threads, each owns one `pg.Conn`, hot loop of a + single query shape for the duration +- Network: container-to-container via the default `container` network + +## Variants + +| label | build flag | transport used | +|------------|--------------------|--------------------------| +| `blocking` | `-Diouring=false` | `PlainStream` (read / send via libc syscalls) | +| `iouring` | `-Diouring=true` | `IoUringStream` (one ring per conn, `IORING_OP_SEND` / `IORING_OP_RECV`, submit + `copy_cqe`) | + +Both variants share the same connect path, auth path, and `Reader`. + +## Workloads + +| id | SQL | notes | +|----|----------------------------------------------|-------| +| 1 | `SELECT 1` | smallest round trip | +| 2 | `SELECT id FROM generate_series(1, 50) AS id` | 50 rows back per query, ~300 B response | + +## Results + +4 worker threads, 10 s per run. + +### query=1 (`SELECT 1`), median of 3 + +| variant | median rps | min | max | +|----------|-----------:|--------:|--------:| +| blocking | 14,967.12 | 14,878 | 15,140 | +| iouring | 15,178.03 | 15,127 | 15,230 | + +Δ: **+1.4 %**, well within run-to-run noise with n=3. + +### query=2 (`SELECT` generate_series(1,50)), median of 5 + +| variant | median rps | min | max | +|----------|-----------:|--------:|--------:| +| blocking | 13,903.37 | 13,770 | 14,001 | +| iouring | 13,849.02 | 13,752 | 13,929 | + +Δ: **−0.4 %**, again within noise. + +## What this tells us + +1. Per-connection single-SQE io_uring is **roughly a wash** on a + driver that was already blocking-sync. 
Expected: we trade one + `recv()` syscall for one `io_uring_enter(submit)` + + `io_uring_enter(wait_cqe)`, so the per-op syscall cost is + approximately even. Kernel fastpath for small receives on a local + TCP loop is already very fast. +2. On q=2 (bigger response, more bytes per recv) io_uring trends + slightly slower — consistent with the extra ring bookkeeping + overhead showing up once the per-op cost matters at all. +3. No regressions, no query errors. The abstraction and the ring + plumbing work correctly for the full `Conn` lifetime (connect, + startup, simple query, extended query, close). + +## What would actually move the needle + +The next items (deliberately **not** in this PR): + +- **SQPOLL** so submitting no longer needs an `io_uring_enter` + syscall in the common case. +- **Batched send**: queue up the parse/bind/describe/execute/sync + packets into one `IORING_OP_SEND` instead of the current per-packet + writes. +- **A cooperative scheduler** so one ring drives N connections + concurrently and a waiting query yields the thread rather than + blocking on `copy_cqe`. This is the real win and turns this from a + neutral change into an actual throughput improvement. + +## Caveats (read these) + +- 3–5 iterations, 10 s each, one client, one DB. Enough to catch + big regressions, not enough to publish percentage claims. +- `postgres:18` with default config, no tuning, `trust` auth. +- Apple `container` runs each container in its own microVM; cross-VM + network adds a real-ish TCP path but results will not match a + co-located production setup. +- No TLS. The io_uring path is plaintext-only in this PR; + `-Dopenssl_lib_name=...` still picks TLS + the old blocking socket. + +## Reproducing + +```bash +# 1. Start Postgres 18 +container run -d --name pg18 -e POSTGRES_HOST_AUTH_METHOD=trust postgres:18 + +# 2. Build the bench image once +container build -t turbopg-bench \ + -f bench/turbopg/Containerfile bench/turbopg + +# 3. 
Find the pg18 IP (field ADDR in `container ls`) +PG_IP=$(container ls | awk '$1=="pg18" {print $6}' | cut -d/ -f1) + +# 4. Run +container run --rm -m 4G -c 4 \ + -e PGHOST="$PG_IP" \ + -e BENCH_QUERY=1 \ + -e BENCH_ITERS=5 \ + -v "$PWD":/work \ + turbopg-bench +``` + +Override `BENCH_QUERY` (1 or 2), `BENCH_THREADS`, `BENCH_DURATION`, +`BENCH_ITERS` as needed. diff --git a/bench/turbopg/bench.zig b/bench/turbopg/bench.zig new file mode 100644 index 0000000..af4bf47 --- /dev/null +++ b/bench/turbopg/bench.zig @@ -0,0 +1,167 @@ +//! Throughput bench for the vendored pg.zig driver. +//! +//! Spawns N worker threads, each running a hot loop of either: +//! * `SELECT 1` +//! * `SELECT generate_series(1,50) AS id` +//! +//! against a single Postgres instance. Each thread owns its own +//! `pg.Conn` (no pool contention, no scheduler), so the only thing +//! varying between the blocking-build and the iouring-build is the +//! stream transport in `zig/pg/src/stream.zig`. +//! +//! Output is plain text: +//! +//! transport= threads=N duration=Ds query= +//! queries=Q rows=R rps=X.YZ +//! +//! Per AGENTS.md, do NOT cite these numbers in release notes. 
+ +const std = @import("std"); +const pg = @import("pg"); + +const Args = struct { + host: []const u8 = "127.0.0.1", + port: u16 = 5432, + user: []const u8 = "postgres", + database: []const u8 = "postgres", + threads: usize = 4, + duration_s: u64 = 10, + query_id: u8 = 1, + label: []const u8 = "blocking", + + fn fromEnv(_: std.mem.Allocator) !Args { + var a: Args = .{}; + if (getenv("PGHOST")) |v| a.host = v; + if (getenv("PGPORT")) |v| a.port = try std.fmt.parseInt(u16, v, 10); + if (getenv("PGUSER")) |v| a.user = v; + if (getenv("PGDATABASE")) |v| a.database = v; + if (getenv("BENCH_THREADS")) |v| a.threads = try std.fmt.parseInt(usize, v, 10); + if (getenv("BENCH_DURATION")) |v| a.duration_s = try std.fmt.parseInt(u64, v, 10); + if (getenv("BENCH_QUERY")) |v| a.query_id = try std.fmt.parseInt(u8, v, 10); + if (getenv("BENCH_LABEL")) |v| a.label = v; + return a; + } +}; + +fn nowNs() i128 { + var ts: std.posix.timespec = undefined; + _ = std.posix.system.clock_gettime(.MONOTONIC, &ts); + return @as(i128, ts.sec) * 1_000_000_000 + @as(i128, ts.nsec); +} + +fn sleepSeconds(secs: u64) void { + var ts: std.posix.timespec = .{ .sec = @intCast(secs), .nsec = 0 }; + _ = std.posix.system.nanosleep(&ts, null); +} + +fn getenv(key: []const u8) ?[]const u8 { + // libc getenv; each worker thread only reads it, never mutates. 
+ var buf: [128]u8 = undefined; + const key_z = std.fmt.bufPrintZ(&buf, "{s}", .{key}) catch return null; + const raw = std.c.getenv(key_z.ptr) orelse return null; + return std.mem.span(raw); +} + +const ThreadStats = struct { + queries: u64 = 0, + rows: u64 = 0, + err_count: u64 = 0, +}; + +fn workerLoop( + args: *const Args, + stop_flag: *std.atomic.Value(bool), + stats: *ThreadStats, +) !void { + const allocator = std.heap.smp_allocator; + + var conn = try pg.Conn.open(allocator, .{ + .host = args.host, + .port = args.port, + }); + defer conn.deinit(); + + try conn.auth(.{ + .username = args.user, + .database = args.database, + .timeout = 10_000, + }); + + const sql = switch (args.query_id) { + 1 => "SELECT 1", + 2 => "SELECT id FROM generate_series(1, 50) AS id", + else => "SELECT 1", + }; + + while (!stop_flag.load(.acquire)) { + var result = conn.query(sql, .{}) catch |err| { + stats.err_count += 1; + if (stats.err_count > 10) return err; + continue; + }; + defer result.deinit(); + + while (result.next() catch null) |_| { + stats.rows += 1; + } + stats.queries += 1; + } +} + +fn workerEntry( + args: *const Args, + stop_flag: *std.atomic.Value(bool), + stats: *ThreadStats, +) void { + workerLoop(args, stop_flag, stats) catch |err| { + std.debug.print("worker error: {s}\n", .{@errorName(err)}); + }; +} + +pub fn main() !void { + const allocator = std.heap.smp_allocator; + + const args = try Args.fromEnv(allocator); + + std.debug.print( + "[bench] transport={s} threads={d} duration={d}s query={d} host={s}:{d}\n", + .{ args.label, args.threads, args.duration_s, args.query_id, args.host, args.port }, + ); + + var stop_flag = std.atomic.Value(bool).init(false); + + const stats = try allocator.alloc(ThreadStats, args.threads); + @memset(stats, .{}); + defer allocator.free(stats); + + const threads = try allocator.alloc(std.Thread, args.threads); + defer allocator.free(threads); + + const t_start = nowNs(); + for (threads, 0..) 
|*t, i| { + t.* = try std.Thread.spawn(.{}, workerEntry, .{ &args, &stop_flag, &stats[i] }); + } + + sleepSeconds(args.duration_s); + + stop_flag.store(true, .release); + for (threads) |t| t.join(); + const t_end = nowNs(); + + var total_q: u64 = 0; + var total_r: u64 = 0; + var total_e: u64 = 0; + for (stats) |s| { + total_q += s.queries; + total_r += s.rows; + total_e += s.err_count; + } + const elapsed_s: f64 = @as(f64, @floatFromInt(t_end - t_start)) / 1e9; + const rps: f64 = @as(f64, @floatFromInt(total_q)) / elapsed_s; + const rows_ps: f64 = @as(f64, @floatFromInt(total_r)) / elapsed_s; + + std.debug.print( + "[bench] result transport={s} queries={d} rows={d} errors={d} elapsed={d:.2}s rps={d:.2} rows_ps={d:.2}\n", + .{ args.label, total_q, total_r, total_e, elapsed_s, rps, rows_ps }, + ); +} diff --git a/bench/turbopg/build.zig b/bench/turbopg/build.zig new file mode 100644 index 0000000..b50e9f9 --- /dev/null +++ b/bench/turbopg/build.zig @@ -0,0 +1,33 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + const iouring = b.option(bool, "iouring", "Use io_uring transport") orelse false; + + const pg_dep = b.dependency("pg", .{ + .target = target, + .optimize = optimize, + .iouring = iouring, + }); + + const exe = b.addExecutable(.{ + .name = "turbopg-bench", + .root_module = b.createModule(.{ + .target = target, + .optimize = optimize, + .root_source_file = b.path("bench.zig"), + .link_libc = true, + .imports = &.{ + .{ .name = "pg", .module = pg_dep.module("pg") }, + }, + }), + }); + + b.installArtifact(exe); + const run_cmd = b.addRunArtifact(exe); + run_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| run_cmd.addArgs(args); + const run_step = b.step("run", "Run the bench"); + run_step.dependOn(&run_cmd.step); +} diff --git a/bench/turbopg/build.zig.zon b/bench/turbopg/build.zig.zon new file mode 100644 index 0000000..3a296fc7 --- 
/dev/null +++ b/bench/turbopg/build.zig.zon @@ -0,0 +1,11 @@ +.{ + .name = .turbopg_bench, + .version = "0.0.0", + .fingerprint = 0x86359a9f291096a0, + .paths = .{""}, + .dependencies = .{ + .pg = .{ + .path = "../../zig/pg", + }, + }, +} diff --git a/bench/turbopg/run.sh b/bench/turbopg/run.sh new file mode 100755 index 0000000..b8d79fc --- /dev/null +++ b/bench/turbopg/run.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash +# Runs the turbopg A/B bench against a PG18 instance reachable at +# $PGHOST:$PGPORT. Builds both variants in-place inside this container +# (using the bind-mounted repo at /work). +# +# Env: +# PGHOST, PGPORT, PGUSER, PGDATABASE +# BENCH_THREADS, BENCH_DURATION, BENCH_QUERY, BENCH_ITERS +set -euo pipefail + +PGHOST="${PGHOST:-127.0.0.1}" +PGPORT="${PGPORT:-5432}" +PGUSER="${PGUSER:-postgres}" +PGDATABASE="${PGDATABASE:-postgres}" +BENCH_THREADS="${BENCH_THREADS:-4}" +BENCH_DURATION="${BENCH_DURATION:-10}" +BENCH_QUERY="${BENCH_QUERY:-1}" +BENCH_ITERS="${BENCH_ITERS:-3}" + +export PGHOST PGPORT PGUSER PGDATABASE BENCH_THREADS BENCH_DURATION BENCH_QUERY + +echo "[bench] waiting for postgres at $PGHOST:$PGPORT ..." +for i in $(seq 1 60); do + if pg_isready -h "$PGHOST" -p "$PGPORT" -U "$PGUSER" -q; then + echo "[bench] postgres up after ${i}s" + break + fi + sleep 1 +done +if ! 
pg_isready -h "$PGHOST" -p "$PGPORT" -U "$PGUSER" -q; then + echo "[bench] postgres never came up" >&2 + exit 1 +fi + +cd /work/bench/turbopg + +build_variant() { + local label="$1" + local flag="$2" + echo "[bench] ==> building variant: $label (iouring=$flag)" + rm -rf zig-out .zig-cache + zig build -Doptimize=ReleaseFast -Diouring="$flag" +} + +run_variant() { + local label="$1" + echo "[bench] ==> running $BENCH_ITERS iters: variant=$label threads=$BENCH_THREADS duration=${BENCH_DURATION}s query=$BENCH_QUERY" + mkdir -p /work/bench/turbopg/results + for i in $(seq 1 "$BENCH_ITERS"); do + BENCH_LABEL="$label" \ + ./zig-out/bin/turbopg-bench \ + 2> "/work/bench/turbopg/results/${label}-q${BENCH_QUERY}-${i}.txt" \ + || { cat "/work/bench/turbopg/results/${label}-q${BENCH_QUERY}-${i}.txt"; exit 1; } + grep "^\[bench\] result" "/work/bench/turbopg/results/${label}-q${BENCH_QUERY}-${i}.txt" + done +} + +summarize() { + local label="$1" + local q="$2" + # Pull rps from each iteration, sort, take median + python3 - <= pos) { + self.pos = 0; + return; + } + self.pos = pos - n; + } + + pub fn skip(self: *Buffer, n: usize) !View { + try self.ensureUnusedCapacity(n); + const pos = self.pos; + self.pos = pos + n; + return .{ + .pos = pos, + .buf = self, + }; + } + + pub fn writeByte(self: *Buffer, b: u8) !void { + try self.ensureUnusedCapacity(1); + self.writeByteAssumeCapacity(b); + } + + pub fn writeByteAssumeCapacity(self: *Buffer, b: u8) void { + const pos = self.pos; + writeByteInto(self.buf, pos, b); + self.pos = pos + 1; + } + + pub fn writeByteNTimes(self: *Buffer, b: u8, n: usize) !void { + try self.ensureUnusedCapacity(n); + const pos = self.pos; + writeByteNTimesInto(self.buf, pos, b, n); + self.pos = pos + n; + } + + pub fn write(self: *Buffer, data: []const u8) !void { + try self.ensureUnusedCapacity(data.len); + self.writeAssumeCapacity(data); + } + + pub fn writeAssumeCapacity(self: *Buffer, data: []const u8) void { + const pos = self.pos; + 
writeInto(self.buf, pos, data); + self.pos = pos + data.len; + } + + // unsafe + pub fn writeAt(self: *Buffer, data: []const u8, pos: usize) void { + @memcpy(self.buf[pos .. pos + data.len], data); + } + + pub fn writeU16Little(self: *Buffer, value: u16) !void { + return self.writeIntT(u16, value, .little); + } + + pub fn writeU32Little(self: *Buffer, value: u32) !void { + return self.writeIntT(u32, value, .little); + } + + pub fn writeU64Little(self: *Buffer, value: u64) !void { + return self.writeIntT(u64, value, .little); + } + + pub fn writeIntLittle(self: *Buffer, comptime T: type, value: T) !void { + return self.writeIntT(T, value, .little); + } + + pub fn writeU16Big(self: *Buffer, value: u16) !void { + return self.writeIntT(u16, value, .big); + } + + pub fn writeU32Big(self: *Buffer, value: u32) !void { + return self.writeIntT(u32, value, .big); + } + + pub fn writeU64Big(self: *Buffer, value: u64) !void { + return self.writeIntT(u64, value, .big); + } + + pub fn writeIntBig(self: *Buffer, comptime T: type, value: T) !void { + return self.writeIntT(T, value, .big); + } + + pub fn writeIntT(self: *Buffer, comptime T: type, value: T, endian: Endian) !void { + const l = @divExact(@typeInfo(T).int.bits, 8); + const pos = self.pos; + try self.ensureUnusedCapacity(l); + writeIntInto(T, self.buf, pos, value, l, endian); + self.pos = pos + l; + } + + pub fn ensureUnusedCapacity(self: *Buffer, n: usize) !void { + return self.ensureTotalCapacity(self.pos + n); + } + + pub fn ensureTotalCapacity(self: *Buffer, required_capacity: usize) !void { + const buf = self.buf; + if (required_capacity <= buf.len) { + return; + } + + // from std.ArrayList + var new_capacity = buf.len; + while (true) { + new_capacity +|= new_capacity / 2 + 8; + if (new_capacity >= required_capacity) break; + } + + const allocator = self._da orelse self._a; + if (buf.ptr == self.static.ptr or !allocator.resize(buf, new_capacity)) { + const new_buffer = try allocator.alloc(u8, new_capacity); + 
@memcpy(new_buffer[0..buf.len], buf); + + if (self.dynamic) |dyn| { + allocator.free(dyn); + } + + self.buf = new_buffer; + self.dynamic = new_buffer; + } else { + const new_buffer = buf.ptr[0..new_capacity]; + self.buf = new_buffer; + self.dynamic = new_buffer; + } + } + + pub fn copy(self: Buffer, allocator: Allocator) ![]const u8 { + const pos = self.pos; + const c = try allocator.alloc(u8, pos); + @memcpy(c, self.buf[0..pos]); + return c; + } +}; + +pub const View = struct { + pos: usize, + buf: *Buffer, + + pub fn writeByte(self: *View, b: u8) void { + const pos = self.pos; + writeByteInto(self.buf.buf, pos, b); + self.pos = pos + 1; + } + + pub fn writeByteNTimes(self: *View, b: u8, n: usize) void { + const pos = self.pos; + writeByteNTimesInto(self.buf.buf, pos, b, n); + self.pos = pos + n; + } + + pub fn write(self: *View, data: []const u8) void { + const pos = self.pos; + writeInto(self.buf.buf, pos, data); + self.pos = pos + data.len; + } + + pub fn writeU16(self: *View, value: u16) void { + return self.writeIntT(u16, value, self.endian); + } + + pub fn writeI16(self: *View, value: i16) void { + return self.writeIntT(i16, value, self.endian); + } + + pub fn writeU32(self: *View, value: u32) void { + return self.writeIntT(u32, value, self.endian); + } + + pub fn writeI32(self: *View, value: i32) void { + return self.writeIntT(i32, value, self.endian); + } + + pub fn writeU64(self: *View, value: u64) void { + return self.writeIntT(u64, value, self.endian); + } + + pub fn writeI64(self: *View, value: i64) void { + return self.writeIntT(i64, value, self.endian); + } + + pub fn writeU16Little(self: *View, value: u16) void { + return self.writeIntT(u16, value, .little); + } + + pub fn writeI16Little(self: *View, value: i16) void { + return self.writeIntT(i16, value, .little); + } + + pub fn writeU32Little(self: *View, value: u32) void { + return self.writeIntT(u32, value, .little); + } + + pub fn writeI32Little(self: *View, value: i32) void { + return 
self.writeIntT(i32, value, .little); + } + + pub fn writeU64Little(self: *View, value: u64) void { + return self.writeIntT(u64, value, .little); + } + + pub fn writeI64Little(self: *View, value: i64) void { + return self.writeIntT(i64, value, .little); + } + + pub fn writeIntLittle(self: *View, comptime T: type, value: T) void { + self.writeIntT(T, value, .little); + } + + pub fn writeU16Big(self: *View, value: u16) void { + return self.writeIntT(u16, value, .big); + } + + pub fn writeI16Big(self: *View, value: i16) void { + return self.writeIntT(i16, value, .big); + } + + pub fn writeU32Big(self: *View, value: u32) void { + return self.writeIntT(u32, value, .big); + } + + pub fn writeI32Big(self: *View, value: i32) void { + return self.writeIntT(i32, value, .big); + } + + pub fn writeU64Big(self: *View, value: u64) void { + return self.writeIntT(u64, value, .big); + } + + pub fn writeI64Big(self: *View, value: i64) void { + return self.writeIntT(i64, value, .big); + } + + pub fn writeIntBig(self: *View, comptime T: type, value: T) void { + self.writeIntT(T, value, .big); + } + + pub fn writeIntT(self: *View, comptime T: type, value: T, endian: Endian) void { + const l = @divExact(@typeInfo(T).int.bits, 8); + const pos = self.pos; + writeIntInto(T, self.buf.buf, pos, value, l, endian); + self.pos = pos + l; + } +}; + +// Functions that write for either a *StringBuilder or a *View +inline fn writeInto(buf: []u8, pos: usize, data: []const u8) void { + const end_pos = pos + data.len; + @memcpy(buf[pos..end_pos], data); +} + +inline fn writeByteInto(buf: []u8, pos: usize, b: u8) void { + buf[pos] = b; +} + +inline fn writeByteNTimesInto(buf: []u8, pos: usize, b: u8, n: usize) void { + for (0..n) |offset| { + buf[pos + offset] = b; + } +} + +inline fn writeIntInto(comptime T: type, buf: []u8, pos: usize, value: T, l: usize, endian: Endian) void { + const end_pos = pos + l; + std.mem.writeInt(T, buf[pos..end_pos][0..l], value, endian); +} + +const t = @import("t.zig"); 
+test { + std.testing.refAllDecls(@This()); +} + +test "growth" { + var w = try Buffer.init(t.allocator, 10); + defer w.deinit(); + + // we reset at the end of the loop, and things should work the exact same + // after a reset + for (0..5) |_| { + try t.expectEqual(0, w.len()); + try w.writeByte('o'); + try t.expectEqual(1, w.len()); + try t.expectString("o", w.string()); + try t.expectEqual(null, w.dynamic); + + // stays in static + try w.write("ver 9000!"); + try t.expectEqual(10, w.len()); + try t.expectString("over 9000!", w.string()); + try t.expectEqual(null, w.dynamic); + + // grows into dynamic + try w.write("!!!"); + try t.expectEqual(13, w.len()); + try t.expectString("over 9000!!!!", w.string()); + try t.expectEqual(false, w.dynamic == null); + + try w.write("If you were to run this code, you'd almost certainly see a segmentation fault (aka, segfault). We create a Response which involves creating an ArenaAllocator and from that, an Allocator. This allocator is then used to format our string. For the purpose of this example, we create a 2nd response and immediately free it. We need this for the same reason that warning1 in our first example printed an almost ok value: we want to re-initialize the memory in our init function stack."); + try t.expectEqual(492, w.len()); + try t.expectString("over 9000!!!!If you were to run this code, you'd almost certainly see a segmentation fault (aka, segfault). We create a Response which involves creating an ArenaAllocator and from that, an Allocator. This allocator is then used to format our string. For the purpose of this example, we create a 2nd response and immediately free it. 
We need this for the same reason that warning1 in our first example printed an almost ok value: we want to re-initialize the memory in our init function stack.", w.string()); + + w.reset(); + } +} + +test "growth with int" { + var w = try Buffer.init(t.allocator, 10); + defer w.deinit(); + + try w.writeU64Big(9000); + try w.writeU64Big(10000); + try t.expectSlice(u8, &.{ 0, 0, 0, 0, 0, 0, 0x23, 0x28 }, w.string()[0..8]); + try t.expectSlice(u8, &.{ 0, 0, 0, 0, 0, 0, 0x27, 0x10 }, w.string()[8..16]); +} + +test "truncate" { + var w = try Buffer.init(t.allocator, 10); + defer w.deinit(); + + w.truncate(100); + try t.expectEqual(0, w.len()); + + try w.write("hello world!1"); + + w.truncate(0); + try t.expectEqual(13, w.len()); + try t.expectString("hello world!1", w.string()); + + w.truncate(1); + try t.expectEqual(12, w.len()); + try t.expectString("hello world!", w.string()); + + w.truncate(5); + try t.expectEqual(7, w.len()); + try t.expectString("hello w", w.string()); +} + +test "reset without clear" { + var w = try Buffer.init(t.allocator, 5); + defer w.deinit(); + + try w.write("hello world!1"); + try t.expectString("hello world!1", w.string()); + + w.resetRetainingCapacity(); + try t.expectEqual(0, w.len()); + try t.expectEqual(false, w.dynamic == null); + try w.write("over 9000"); + try w.write("over 9000"); +} + +test "fuzz" { + var control: std.ArrayList(u8) = .empty; + defer control.deinit(t.allocator); + + var r = t.getRandom(); + const random = r.random(); + + var arena = std.heap.ArenaAllocator.init(t.allocator); + defer arena.deinit(); + + const aa = arena.allocator(); + + for (1..100) |_| { + var w = try Buffer.init(t.allocator, random.uintAtMost(u16, 1000) + 1); + defer w.deinit(); + + for (1..100) |_| { + const input = testString(aa, random); + try w.write(input); + try control.appendSlice(t.allocator, input); + try t.expectString(control.items, w.string()); + } + w.reset(); + control.clearRetainingCapacity(); + _ = arena.reset(.free_all); + } +} + 
+test "writer" { + var w = try Buffer.init(t.allocator, 10); + defer w.deinit(); + + try std.json.Stringify.value(.{ .over = 9000, .spice = "must flow", .ok = true }, .{}, &w.interface); + try t.expectString("{\"over\":9000,\"spice\":\"must flow\",\"ok\":true}", w.string()); +} + +test "copy" { + var w = try Buffer.init(t.allocator, 10); + defer w.deinit(); + + try w.write("hello!!"); + const c = try w.copy(t.allocator); + defer t.allocator.free(c); + try t.expectString("hello!!", c); +} + +test "write little" { + var w = try Buffer.init(t.allocator, 20); + defer w.deinit(); + try w.writeU64Little(11234567890123456789); + try t.expectSlice(u8, &[_]u8{ 21, 129, 209, 7, 249, 51, 233, 155 }, w.string()); + + try w.writeU32Little(3283856184); + try t.expectSlice(u8, &[_]u8{ 21, 129, 209, 7, 249, 51, 233, 155, 56, 171, 187, 195 }, w.string()); + + try w.writeU16Little(15000); + try t.expectSlice(u8, &[_]u8{ 21, 129, 209, 7, 249, 51, 233, 155, 56, 171, 187, 195, 152, 58 }, w.string()); +} + +test "write big" { + var w = try Buffer.init(t.allocator, 20); + defer w.deinit(); + try w.writeU64Big(11234567890123456789); + try t.expectSlice(u8, &[_]u8{ 155, 233, 51, 249, 7, 209, 129, 21 }, w.string()); + + try w.writeU32Big(3283856184); + try t.expectSlice(u8, &[_]u8{ 155, 233, 51, 249, 7, 209, 129, 21, 195, 187, 171, 56 }, w.string()); + + try w.writeU16Big(15000); + try t.expectSlice(u8, &[_]u8{ 155, 233, 51, 249, 7, 209, 129, 21, 195, 187, 171, 56, 58, 152 }, w.string()); +} + +test "skip & view" { + var w = try Buffer.init(t.allocator, 10); + defer w.deinit(); + + var view = try w.skip(4); + try w.write("hello world!!"); + + view.writeU32Big(@intCast(w.len() - 4)); + + try w.writeByte('\n'); + try t.expectSlice(u8, &[_]u8{ 0, 0, 0, 13, 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', '!', '\n' }, w.string()); +} + +test "writeAt" { + var w = try Buffer.init(t.allocator, 200); + defer w.deinit(); + + try w.write("hello"); + try w.write(&.{ 0, 0, 0, 0, 0 }); + try 
w.write("world"); + + w.writeAt(" ", 5); + w.writeAt("123 ", 6); + try t.expectString("hello 123 world", w.string()); +} + +test "cString" { + var w = try Buffer.init(t.allocator, 10); + defer w.deinit(); + try t.expectString("", try w.cString()); + + try w.write("123456789"); + try t.expectString("123456789", try w.cString()); + + try w.write("0"); + try t.expectString("1234567890", try w.cString()); +} + +fn testString(allocator: Allocator, random: std.Random) []const u8 { + var s = allocator.alloc(u8, random.uintAtMost(u8, 100) + 1) catch unreachable; + for (0..s.len) |i| { + s[i] = random.uintAtMost(u8, 90) + 32; + } + return s; +} diff --git a/zig/pg-deps-buffer/src/pool.zig b/zig/pg-deps-buffer/src/pool.zig new file mode 100644 index 0000000..4b08518 --- /dev/null +++ b/zig/pg-deps-buffer/src/pool.zig @@ -0,0 +1,152 @@ +const std = @import("std"); +const builtin = @import("builtin"); + +const Buffer = @import("buffer.zig").Buffer; + +const Mutex = std.Thread.Mutex; +const Allocator = std.mem.Allocator; + +pub const Pool = struct { + mutex: Mutex, + available: usize, + allocator: Allocator, + buffer_size: usize, + buffers: []*Buffer, + + pub fn init(allocator: Allocator, pool_size: u16, buffer_size: usize) !Pool { + const buffers = try allocator.alloc(*Buffer, pool_size); + + for (0..pool_size) |i| { + const sb = try allocator.create(Buffer); + sb.* = try Buffer.init(allocator, buffer_size); + buffers[i] = sb; + } + + return .{ .mutex = .{}, .buffers = buffers, .allocator = allocator, .available = pool_size, .buffer_size = buffer_size }; + } + + pub fn deinit(self: *Pool) void { + const allocator = self.allocator; + for (self.buffers) |sb| { + sb.deinit(); + allocator.destroy(sb); + } + allocator.free(self.buffers); + } + + pub fn acquire(self: *Pool) !*Buffer { + return self.acquireWithAllocator(self.allocator); + } + + pub fn acquireWithAllocator(self: *Pool, dyn_allocator: Allocator) !*Buffer { + const buffers = self.buffers; + + self.mutex.lock(); + const 
available = self.available; + if (available == 0) { + // dont hold the lock over factory + self.mutex.unlock(); + + const allocator = self.allocator; + const sb = try allocator.create(Buffer); + sb.* = try Buffer.init(allocator, self.buffer_size); + sb._da = dyn_allocator; + return sb; + } + const index = available - 1; + const sb = buffers[index]; + self.available = index; + self.mutex.unlock(); + sb._da = dyn_allocator; + return sb; + } + + pub fn release(self: *Pool, sb: *Buffer) void { + sb.reset(); + self.mutex.lock(); + + var buffers = self.buffers; + const available = self.available; + if (available == buffers.len) { + self.mutex.unlock(); + const allocator = self.allocator; + sb.deinit(); + allocator.destroy(sb); + return; + } + buffers[available] = sb; + self.available = available + 1; + self.mutex.unlock(); + } +}; + +const t = @import("t.zig"); +test "pool: acquire and release" { + var p = try Pool.init(t.allocator, 2, 100); + defer p.deinit(); + + const sb1a = p.acquire() catch unreachable; + const sb2a = p.acquire() catch unreachable; + const sb3a = p.acquire() catch unreachable; // this should be dynamically generated + + try t.expectEqual(false, sb1a == sb2a); + try t.expectEqual(false, sb2a == sb3a); + + p.release(sb1a); + + const sb1b = p.acquire() catch unreachable; + try t.expectEqual(true, sb1a == sb1b); + + p.release(sb3a); + p.release(sb2a); + p.release(sb1b); +} + +test "pool: dynamic allocator" { + var p = try Pool.init(t.allocator, 2, 5); + defer p.deinit(); + + var arena = std.heap.ArenaAllocator.init(t.allocator); + defer arena.deinit(); + + var sb = p.acquireWithAllocator(arena.allocator()) catch unreachable; + try sb.write("hello world how's it going?"); + try sb.write("he"); + try sb.write("hello world"); + try sb.write("are you doing well? 
I hope so, I don't love how this is being implemented, but I think the feature is worthwhile"); + p.release(sb); +} + +test "pool: threadsafety" { + var p = try Pool.init(t.allocator, 3, 20); + defer p.deinit(); + + // initialize this to 0 since we're asserting that it's 0 + for (p.buffers) |sb| { + sb.buf[0] = 0; + } + + const t1 = try std.Thread.spawn(.{}, testPool, .{&p}); + const t2 = try std.Thread.spawn(.{}, testPool, .{&p}); + const t3 = try std.Thread.spawn(.{}, testPool, .{&p}); + + t1.join(); + t2.join(); + t3.join(); +} + +fn testPool(p: *Pool) void { + var r = t.getRandom(); + const random = r.random(); + + for (0..5000) |_| { + var sb = p.acquire() catch unreachable; + // no other thread should have set this to 255 + std.debug.assert(sb.buf[0] == 0); + + sb.buf[0] = 255; + std.Thread.sleep(random.uintAtMost(u32, 100000)); + sb.buf[0] = 0; + p.release(sb); + } +} diff --git a/zig/pg-deps-buffer/src/t.zig b/zig/pg-deps-buffer/src/t.zig new file mode 100644 index 0000000..b7e603d --- /dev/null +++ b/zig/pg-deps-buffer/src/t.zig @@ -0,0 +1,17 @@ +const std = @import("std"); +pub const allocator = std.testing.allocator; + +// std.testing.expectEqual won't coerce expected to actual, which is a problem +// when expected is frequently a comptime. 
+// https://github.com/ziglang/zig/issues/4437 +pub fn expectEqual(expected: anytype, actual: anytype) !void { + try std.testing.expectEqual(@as(@TypeOf(actual), expected), actual); +} +pub const expectString = std.testing.expectEqualStrings; +pub const expectSlice = std.testing.expectEqualSlices; + +pub fn getRandom() std.Random.DefaultPrng { + var seed: u64 = undefined; + std.posix.getrandom(std.mem.asBytes(&seed)) catch unreachable; + return std.Random.DefaultPrng.init(seed); +} diff --git a/zig/pg-deps-buffer/test_runner.zig b/zig/pg-deps-buffer/test_runner.zig new file mode 100644 index 0000000..00f457f --- /dev/null +++ b/zig/pg-deps-buffer/test_runner.zig @@ -0,0 +1,294 @@ +// in your build.zig, you can specify a custom test runner: +// const tests = b.addTest(.{ +// .root_module = $MODULE_BEING_TESTED, +// .test_runner = .{ .path = b.path("test_runner.zig"), .mode = .simple }, +// }); + +const std = @import("std"); +const builtin = @import("builtin"); + +const Allocator = std.mem.Allocator; + +const BORDER = "=" ** 80; + +// use in custom panic handler +var current_test: ?[]const u8 = null; + +pub fn main() !void { + var mem: [8192]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&mem); + + const allocator = fba.allocator(); + + const env = Env.init(allocator); + defer env.deinit(allocator); + + var slowest = SlowTracker.init(allocator, 5); + defer slowest.deinit(); + + var pass: usize = 0; + var fail: usize = 0; + var skip: usize = 0; + var leak: usize = 0; + + Printer.fmt("\r\x1b[0K", .{}); // beginning of line and clear to end of line + + for (builtin.test_functions) |t| { + if (isSetup(t)) { + t.func() catch |err| { + Printer.status(.fail, "\nsetup \"{s}\" failed: {}\n", .{ t.name, err }); + return err; + }; + } + } + + for (builtin.test_functions) |t| { + if (isSetup(t) or isTeardown(t)) { + continue; + } + + var status = Status.pass; + slowest.startTiming(); + + const is_unnamed_test = isUnnamed(t); + if (env.filter) |f| { + if 
(!is_unnamed_test and std.mem.indexOf(u8, t.name, f) == null) { + continue; + } + } + + const friendly_name = blk: { + const name = t.name; + var it = std.mem.splitScalar(u8, name, '.'); + while (it.next()) |value| { + if (std.mem.eql(u8, value, "test")) { + const rest = it.rest(); + break :blk if (rest.len > 0) rest else name; + } + } + break :blk name; + }; + + current_test = friendly_name; + std.testing.allocator_instance = .{}; + const result = t.func(); + current_test = null; + + const ns_taken = slowest.endTiming(friendly_name); + + if (std.testing.allocator_instance.deinit() == .leak) { + leak += 1; + Printer.status(.fail, "\n{s}\n\"{s}\" - Memory Leak\n{s}\n", .{ BORDER, friendly_name, BORDER }); + } + + if (result) |_| { + pass += 1; + } else |err| switch (err) { + error.SkipZigTest => { + skip += 1; + status = .skip; + }, + else => { + status = .fail; + fail += 1; + Printer.status(.fail, "\n{s}\n\"{s}\" - {s}\n{s}\n", .{ BORDER, friendly_name, @errorName(err), BORDER }); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + if (env.fail_first) { + break; + } + }, + } + + if (env.verbose) { + const ms = @as(f64, @floatFromInt(ns_taken)) / 1_000_000.0; + Printer.status(status, "{s} ({d:.2}ms)\n", .{ friendly_name, ms }); + } else { + Printer.status(status, ".", .{}); + } + } + + for (builtin.test_functions) |t| { + if (isTeardown(t)) { + t.func() catch |err| { + Printer.status(.fail, "\nteardown \"{s}\" failed: {}\n", .{ t.name, err }); + return err; + }; + } + } + + const total_tests = pass + fail; + const status = if (fail == 0) Status.pass else Status.fail; + Printer.status(status, "\n{d} of {d} test{s} passed\n", .{ pass, total_tests, if (total_tests != 1) "s" else "" }); + if (skip > 0) { + Printer.status(.skip, "{d} test{s} skipped\n", .{ skip, if (skip != 1) "s" else "" }); + } + if (leak > 0) { + Printer.status(.fail, "{d} test{s} leaked\n", .{ leak, if (leak != 1) "s" else "" }); + } + Printer.fmt("\n", .{}); + try 
+ slowest.display(); + Printer.fmt("\n", .{}); + std.posix.exit(if (fail == 0) 0 else 1); +} + +const Printer = struct { + fn fmt(comptime format: []const u8, args: anytype) void { + std.debug.print(format, args); + } + + fn status(s: Status, comptime format: []const u8, args: anytype) void { + switch (s) { + .pass => std.debug.print("\x1b[32m", .{}), + .fail => std.debug.print("\x1b[31m", .{}), + .skip => std.debug.print("\x1b[33m", .{}), + else => {}, + } + std.debug.print(format ++ "\x1b[0m", args); + } +}; + +const Status = enum { + pass, + fail, + skip, + text, +}; + +const SlowTracker = struct { + const SlowestQueue = std.PriorityDequeue(TestInfo, void, compareTiming); + max: usize, + slowest: SlowestQueue, + timer: std.time.Timer, + + fn init(allocator: Allocator, count: u32) SlowTracker { + const timer = std.time.Timer.start() catch @panic("failed to start timer"); + var slowest = SlowestQueue.init(allocator, {}); + slowest.ensureTotalCapacity(count) catch @panic("OOM"); + return .{ + .max = count, + .timer = timer, + .slowest = slowest, + }; + } + + const TestInfo = struct { + ns: u64, + name: []const u8, + }; + + fn deinit(self: SlowTracker) void { + self.slowest.deinit(); + } + + fn startTiming(self: *SlowTracker) void { + self.timer.reset(); + } + + fn endTiming(self: *SlowTracker, test_name: []const u8) u64 { + var timer = self.timer; + const ns = timer.lap(); + + var slowest = &self.slowest; + + if (slowest.count() < self.max) { + // Capacity is fixed to the # of slow tests we want to track + // If we've tracked fewer tests than this capacity, then always add + slowest.add(TestInfo{ .ns = ns, .name = test_name }) catch @panic("failed to track test timing"); + return ns; + } + + { + // Optimization to avoid shifting the dequeue for the common case + // where the test isn't one of our slowest.
+ const fastest_of_the_slow = slowest.peekMin() orelse unreachable; + if (fastest_of_the_slow.ns > ns) { + // the test was faster than our fastest slow test, don't add + return ns; + } + } + + // the previous fastest of our slow tests, has been pushed off. + _ = slowest.removeMin(); + slowest.add(TestInfo{ .ns = ns, .name = test_name }) catch @panic("failed to track test timing"); + return ns; + } + + fn display(self: *SlowTracker) !void { + var slowest = self.slowest; + const count = slowest.count(); + Printer.fmt("Slowest {d} test{s}: \n", .{ count, if (count != 1) "s" else "" }); + while (slowest.removeMinOrNull()) |info| { + const ms = @as(f64, @floatFromInt(info.ns)) / 1_000_000.0; + Printer.fmt(" {d:.2}ms\t{s}\n", .{ ms, info.name }); + } + } + + fn compareTiming(context: void, a: TestInfo, b: TestInfo) std.math.Order { + _ = context; + return std.math.order(a.ns, b.ns); + } +}; + +const Env = struct { + verbose: bool, + fail_first: bool, + filter: ?[]const u8, + + fn init(allocator: Allocator) Env { + return .{ + .verbose = readEnvBool(allocator, "TEST_VERBOSE", true), + .fail_first = readEnvBool(allocator, "TEST_FAIL_FIRST", false), + .filter = readEnv(allocator, "TEST_FILTER"), + }; + } + + fn deinit(self: Env, allocator: Allocator) void { + if (self.filter) |f| { + allocator.free(f); + } + } + + fn readEnv(allocator: Allocator, key: []const u8) ?[]const u8 { + const v = std.process.getEnvVarOwned(allocator, key) catch |err| { + if (err == error.EnvironmentVariableNotFound) { + return null; + } + std.log.warn("failed to get env var {s} due to err {}", .{ key, err }); + return null; + }; + return v; + } + + fn readEnvBool(allocator: Allocator, key: []const u8, deflt: bool) bool { + const value = readEnv(allocator, key) orelse return deflt; + defer allocator.free(value); + return std.ascii.eqlIgnoreCase(value, "true"); + } +}; + +pub const panic = std.debug.FullPanic(struct { + pub fn panicFn(msg: []const u8, first_trace_addr: ?usize) noreturn { + if 
(current_test) |ct| { + std.debug.print("\x1b[31m{s}\npanic running \"{s}\"\n{s}\x1b[0m\n", .{ BORDER, ct, BORDER }); + } + std.debug.defaultPanic(msg, first_trace_addr); + } +}.panicFn); + +fn isUnnamed(t: std.builtin.TestFn) bool { + const marker = ".test_"; + const test_name = t.name; + const index = std.mem.indexOf(u8, test_name, marker) orelse return false; + _ = std.fmt.parseInt(u32, test_name[index + marker.len ..], 10) catch return false; + return true; +} + +fn isSetup(t: std.builtin.TestFn) bool { + return std.mem.endsWith(u8, t.name, "tests:beforeAll"); +} + +fn isTeardown(t: std.builtin.TestFn) bool { + return std.mem.endsWith(u8, t.name, "tests:afterAll"); +} diff --git a/zig/pg-deps-metrics/.gitignore b/zig/pg-deps-metrics/.gitignore new file mode 100644 index 0000000..dca1103 --- /dev/null +++ b/zig/pg-deps-metrics/.gitignore @@ -0,0 +1,2 @@ +zig-out/ +.zig-cache/ diff --git a/zig/pg-deps-metrics/LICENSE b/zig/pg-deps-metrics/LICENSE new file mode 100644 index 0000000..011dfad --- /dev/null +++ b/zig/pg-deps-metrics/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2024 Karl Seguin. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/zig/pg-deps-metrics/Makefile b/zig/pg-deps-metrics/Makefile new file mode 100644 index 0000000..be17f4d --- /dev/null +++ b/zig/pg-deps-metrics/Makefile @@ -0,0 +1,5 @@ +F= + +.PHONY: t +t: + TEST_FILTER="${F}" zig build test -freference-trace --summary all diff --git a/zig/pg-deps-metrics/build.zig b/zig/pg-deps-metrics/build.zig new file mode 100644 index 0000000..07f4bd4 --- /dev/null +++ b/zig/pg-deps-metrics/build.zig @@ -0,0 +1,48 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) !void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + const metrics_module = b.addModule("metrics", .{ + .root_source_file = b.path("src/metrics.zig"), + .target = target, + .optimize = optimize, + }); + + { + // setup example + const example_module = b.createModule(.{ + .root_source_file = b.path("example/main.zig"), + .target = target, + .optimize = optimize, + }); + const example = b.addExecutable(.{ + .name = "metrics demo", + .root_module = example_module, + }); + example.root_module.addImport("metrics", metrics_module); + b.installArtifact(example); + + const run_example_cmd = b.addRunArtifact(example); + run_example_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_example_cmd.addArgs(args); + } + const run_step = b.step("run", "Run the app"); + run_step.dependOn(&run_example_cmd.step); + } + + { + // setup tests + const lib_test = b.addTest(.{ + .root_module = metrics_module, + .test_runner = .{ .path = b.path("test_runner.zig"), .mode = .simple }, + }); + const run_test = b.addRunArtifact(lib_test); + run_test.has_side_effects = true; + + const test_step = b.step("test", "Run tests"); + 
test_step.dependOn(&run_test.step); + } +} diff --git a/zig/pg-deps-metrics/build.zig.zon b/zig/pg-deps-metrics/build.zig.zon new file mode 100644 index 0000000..cddfc54 --- /dev/null +++ b/zig/pg-deps-metrics/build.zig.zon @@ -0,0 +1,6 @@ +.{ + .name = .metrics, + .paths = .{""}, + .version = "0.0.0", + .fingerprint = 0x228aaae778b8b15b, +} diff --git a/zig/pg-deps-metrics/example/lib/lib.zig b/zig/pg-deps-metrics/example/lib/lib.zig new file mode 100644 index 0000000..3860055 --- /dev/null +++ b/zig/pg-deps-metrics/example/lib/lib.zig @@ -0,0 +1,21 @@ +// This folder simulates a 3rd party library that the application, main.zig +// is using. + +const metrics = @import("metrics.zig"); + +// Expose initializeMetrics to give control to the application over whether or +// not the library metrics should be enabled. +pub const initializeMetrics = metrics.initialize; + +// Expose writeMetrics to the application +pub const writeMetrics = metrics.write; + +// We want to collect metrics about this +pub fn doSomething() !void { + metrics.active(10); + + // vectored metrics can fail, hence the try + try metrics.hit(.{.status = 200, .path = "/robots.txt"}); + + try metrics.latency(.{.path = "/"}, 3.2); +} diff --git a/zig/pg-deps-metrics/example/lib/metrics.zig b/zig/pg-deps-metrics/example/lib/metrics.zig new file mode 100644 index 0000000..b0072dc --- /dev/null +++ b/zig/pg-deps-metrics/example/lib/metrics.zig @@ -0,0 +1,49 @@ +const std = @import("std"); +const m = @import("metrics"); + +const Allocator = std.mem.Allocator; + +// public to be exposed to other files within this library, not to be exposed +// directly to the application. 
+var metrics = m.initializeNoop(Metrics); + +const HitLabel = struct{status: u16, path: []const u8}; +const LatencyLabel = struct{path: []const u8}; + +const Metrics = struct { + hits: Hits, + active: Active, + latency: Latency, + + const Hits = m.CounterVec(u32, HitLabel); + const Active = m.Gauge(u32); + const Latency = m.HistogramVec( + f64, + LatencyLabel, + &.{0.05, 0.10, 0.50, 1, 2.5, 5, 10} + ); +}; + +pub fn hit(labels: HitLabel) !void { + return metrics.hits.incr(labels); +} + +pub fn active(value: u32) void { + metrics.active.set(value); +} + +pub fn latency(labels: LatencyLabel, value: f64) !void { + return metrics.latency.observe(labels, value); +} + +pub fn initialize(allocator: Allocator, comptime opts: m.RegistryOpts) !void { + metrics = .{ + .hits = try Metrics.Hits.init(allocator, "lib_hits", .{}, opts), + .active = Metrics.Active.init("lib_active", .{}, opts), + .latency = try Metrics.Latency.init(allocator, "lib_latency", .{}, opts), + }; +} + +pub fn write(writer: anytype) !void { + return m.write(&metrics, writer); +} diff --git a/zig/pg-deps-metrics/example/main.zig b/zig/pg-deps-metrics/example/main.zig new file mode 100644 index 0000000..46ddd7a --- /dev/null +++ b/zig/pg-deps-metrics/example/main.zig @@ -0,0 +1,38 @@ +// This example attempts to demonstrate how both library developers and +// application developers can use this library. + +// The "lib" subfolder emulates a library. + +// Application developers can also define their own "metric" for their own +// application-specific metrics + +const std = @import("std"); +const m = @import("metrics"); + +// simulates a library that has metrics +const lib = @import("lib/lib.zig"); + +pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + + // The application initializes the metrics for all the libraries it wishes + // to get metrics from.
Optionally, the application can force a metric + // name prefix and can exclude specific metrics + try lib.initializeMetrics(allocator, .{ + .prefix = "", // defaults to "" + .exclude = null, // defaults to null + }); + + // this will use some of the library's metrics + try lib.doSomething(); + + // the application can output the library metrics to a writer + var buffer: [1024]u8 = undefined; + var stdout = std.fs.File.stdout().writerStreaming(&buffer); + try lib.writeMetrics(&stdout.interface); + try stdout.interface.flush(); + // try anotherLib.writeMetrics(writer); +} + + diff --git a/zig/pg-deps-metrics/readme.md b/zig/pg-deps-metrics/readme.md new file mode 100644 index 0000000..b36683c --- /dev/null +++ b/zig/pg-deps-metrics/readme.md @@ -0,0 +1,321 @@ +# Prometheus Metric Library for Zig +This library is designed for both library and application developers. I do hope to streamline setup when comptime allocations are allowed. + +It supports counters, gauges and histograms, and the labeled variant of each. + +Please see the example project. It demonstrates how a library developer and an application developer can each initialize and output metrics. + +## Metric Setup +Setup is a bit tedious, and I welcome suggestions for improvement. + +Let's start with a basic example. While the metrics within this library can be used directly, I believe that each library/application should create its own `Metrics` struct that encapsulates all metrics. A global instance of this struct can be created and initialized at comptime into a "noop" state.
+ +```zig +const m = @import("metrics"); + +// defaults to noop metrics, making this safe to use +// whether or not initializeMetrics is called +var metrics = m.initializeNoop(Metrics); + +const Metrics = struct { + // counter can be an unsigned integer or a float + hits: m.Counter(u32), + + // gauge can be an integer or float + connected: m.Gauge(u16), +}; + +// meant to be called within the application +pub fn hit() void { + metrics.hits.incr(); +} + +// meant to be called within the application +pub fn connected(value: u16) void { + metrics.connected.set(value); +} + +// meant to be called once on application startup +pub fn initializeMetrics(comptime opts: m.RegistryOpts) !void { + metrics = .{ + .hits = m.Counter(u32).init("hits", .{}, opts), + .connected = m.Gauge(u16).init("connected", .{}, opts), + }; +} + +// thread safe +pub fn writeMetrics(writer: *std.io.Writer) !void { + return m.write(&metrics, writer); +} +``` + +The call to `m.initializeNoop(Metrics)` creates a `Metrics` and initializes each metric (`hits` and `connected`) to a "noop" implementation (tagged unions are used). `initializeMetrics` is called on application startup and switches these metrics to their real implementations. + +For library developers, this means their global metrics are always safe to use (all methods call noop). For application developers, it gives them control over which metrics to enable. + +All metrics take a name and **two options**. Why two options? The first is designed for library developers, the second is designed to give application developers additional control. + +Currently the first option has a single field: +* `help: ?[]const u8 = null` - Used to generate the `# HELP $HELP` output line + +The second option has two fields: +* `prefix: []const u8 = ""` - Prepends `prefix` to each metric name. +* `exclude: ?[]const []const u8 = null` - A list of metric names to exclude (not including the prefix).
+ +`CounterVec`, `GaugeVec` and `HistogramVec` also require an allocator. + +### Note for Library Developers +Library developers are free to change the above as needed. However, having libraries consistently expose an `initializeMetrics` and `writeMetrics` should help application developers. + +Library developers should ask their users to call `try initializeMetrics(allocator, .{})` on startup and `try writeMetrics(writer)` to generate the metrics. + +The `RegistryOpts` parameter should be supplied by the application and passed to each metric-initializer as-is. + +### Labels (vector-metrics) +Every metric type supports a vectored variant. This allows labels to be attached to metrics. These metrics require an `std.mem.Allocator` and, as you'll see in the metric API section, most of their methods can fail. + +```zig +var metrics = m.initializeNoop(Metrics); + +const Metrics = struct { + hits: m.CounterVec(u32, struct{status: u16, name: []const u8}), +}; + +// All labeled metrics require an allocator +pub fn initializeMetrics(allocator: Allocator, comptime opts: m.RegistryOpts) !void { + metrics = .{ + .hits = try m.CounterVec(u32, struct{status: u16, name: []const u8}).init(allocator, "hits", .{}, opts), + }; +} +``` + +The labels are strongly typed. Valid label types are: `ErrorSet`, `Enum`, `Type`, `Bool`, `Int` and `[]const u8` + +The `CounterVec(u32, ...)` has to be typed twice: once in the definition of `Metrics` and once in `initializeMetrics`. This can be improved slightly. + +```zig +var metrics = m.initializeNoop(Metrics); + +const Metrics = struct { + hits: Hits, + + const Hits = m.CounterVec(u32, struct{status: u16, name: []const u8}); +}; + +pub fn initializeMetrics(allocator: Allocator, comptime opts: m.RegistryOpts) !void { + metrics = .{ + .hits = try Metrics.Hits.init(allocator, "hits", .{}, opts), + }; +} + +// Labels are compile-time checked.
Using "anytype" here +// is just lazy so we don't have to declare the label structure +pub fn hit(labels: anytype) !void { + return metrics.hits.incr(labels); +} +``` + +The above would be called as: + +```zig +// import your metrics file +const metrics = @import("metrics.zig"); +try metrics.hit(.{.status = 200, .name = "/about.txt"}); +``` + +Internally, every metric is a union between a "noop" and an actual implementation. This allows metrics to be globally initialized as noop and then enabled on startup. The benefit of this approach is that library developers can safely and easily use their metrics whether or not the application has enabled them. + +### Histograms +Histograms are set up like `Counter` and `Gauge`, and have a vectored variant, but they require a comptime list of buckets: + +```zig +const Metrics = struct { + latency: Latency, + + const Latency = m.Histogram(f32, &.{0.005, 0.01, 0.05, 0.1, 0.25, 1, 5, 10}); +}; + +pub fn initializeMetrics(comptime opts: m.RegistryOpts) !void { + metrics = .{ + .latency = Metrics.Latency.init("hits", .{}, opts), + }; +} +``` + +The `HistogramVec` is even more verbose, requiring both the label struct and the bucket list. Like all vectored metrics, it requires an `std.mem.Allocator` and can fail: + +```zig +var metrics = m.initializeNoop(Metrics); + +const Metrics = struct { + latency: Latency, + + const Latency = m.HistogramVec( + u32, + struct{path: []const u8}, + &.{5, 10, 25, 50, 100, 250, 500, 1000} + ); +}; + +pub fn initializeMetrics(allocator: Allocator, comptime opts: m.RegistryOpts) !void { + metrics = .{ + .latency = try Metrics.Latency.init(allocator, "hits", .{}, opts), + }; +} + +// Labels are compile-time checked.
Using "anytype" here +// is just lazy so we don't have to declare the label structure +// Would be called as: +// try @import("metrics.zig").recordLatency(.{.path = "robots.txt"}, 2); +pub fn recordLatency(labels: anytype, value: u32) !void { + return metrics.latency.observe(labels, value); +} +``` + +## Metrics + +### Utility +The package exposes the following utility functions. + +#### `initializeNoop(T) T` +Creates and initializes a `T` with a `noop` implementation of every metric field. `T` should contain only metrics (`Counter`, `Gauge`, `Histogram` or their vectored variants) and primitive fields (int, bool, []const u8, enum, float). + +`initializeNoop(T)` will set any non-metric field to its default value. + +This method is designed to allow a global "metrics" instance to exist and be safe to use within libraries. + +#### `write(metrics: anytype, writer: *std.io.Writer) !void` +Calls the `write(writer) !void` method on every metric field within `metrics`. + +Library developers are expected to wrap this method in a `writeMetrics(writer: *std.io.Writer) !void` function. This function requires a pointer to your metrics. + +### Counter(T) +A `Counter(T)` is used for incrementing values. `T` can be an unsigned integer or a float. Its two main methods are `incr()` and `incrBy(value: T)`. `incr()` is a short version of `incrBy(1)`. + +#### `init(comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) Counter(T)` +Initializes the counter. + +Opts is: +* `help: ?[]const u8 = null` - optional help text to include in the prometheus output + + +#### `incr(self: *Counter(T)) void` +Increments the counter by 1. + +#### `incrBy(self: *Counter(T), value: T) void` +Increments the counter by `value`. + +#### `write(self: *Counter(T), writer: *std.io.Writer) !void` +Writes the counter to `writer`. + +### CounterVec(T, L) +A `CounterVec(T, L)` is used for incrementing values with labels. `T` can be an unsigned integer or a float.
`L` must be a struct where the field names and types will define the labels. Its two main methods are `incr(labels: L)` and `incrBy(labels: L, value: T)`. `incr(L)` is a short version of `incrBy(L, 1)`. + +#### `init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) !CounterVec(T, L)` +Initializes the counter. Name must be given at comptime. + +Opts is: +* `help: ?[]const u8 = null` - optional help text to include in the prometheus output + +#### `deinit(self: *CounterVec(T, L)) void` +Deallocates the counter + +#### `incr(self: *CounterVec(T, L), labels: L) !void` +Increments the counter by 1. Vectored metrics can fail. + +#### `incrBy(self: *CounterVec(T, L), labels: L, value: T) !void` +Increments the counter by `value`. Vectored metrics can fail. + +#### `remove(self: *CounterVec(T, L), labels: L) void` +Removes the labeled value from the counter. Safe to call if `labels` is not an existing label. + +#### `write(self: *CounterVec(T, L), writer: *std.io.Writer) !void` +Writes the counter to `writer`. + +### Gauge(T) +A `Gauge(T)` is used for setting values. `T` can be an integer or a float. Its main methods are `incr()`, `incrBy(value: T)` and `set(value: T)`. `incr()` is a short version of `incrBy(1)`. + +#### `init(comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) Gauge(T)` +Initializes the gauge. Name must be given at comptime. + +Opts is: +* `help: ?[]const u8 = null` - optional help text to include in the prometheus output + +#### `incr(self: *Gauge(T)) void` +Increments the gauge by 1. + +#### `incrBy(self: *Gauge(T), value: T) void` +Increments the gauge by `value`. + +#### `set(self: *Gauge(T), value: T) void` +Sets the gauge to `value`. + +#### `write(self: *Gauge(T), writer: *std.io.Writer) !void` +Writes the gauge to `writer`. + +### GaugeVec(T, L) +A `GaugeVec(T, L)` is used for setting values with labels. `T` can be an integer or a float.
`L` must be a struct where the field names and types will define the labels. Its main methods are `incr(labels: L)`, `incrBy(labels: L, value: T)` and `set(labels: L, value: T)`. `incr(L)` is a short version of `incrBy(L, 1)`. + +#### `init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) !GaugeVec(T, L)` +Initializes the gauge. Name must be given at comptime. + +Opts is: +* `help: ?[]const u8 = null` - optional help text to include in the prometheus output + +#### `deinit(self: *GaugeVec(T, L)) void` +Deallocates the gauge + +#### `incr(self: *GaugeVec(T, L), labels: L) !void` +Increments the gauge by 1. Vectored metrics can fail. + +#### `incrBy(self: *GaugeVec(T, L), labels: L, value: T) !void` +Increments the gauge by `value`. Vectored metrics can fail. + +#### `set(self: *GaugeVec(T, L), labels: L, value: T) !void` +Sets the gauge to `value`. Vectored metrics can fail. + +#### `remove(self: *GaugeVec(T, L), labels: L) void` +Removes the labeled value from the gauge. Safe to call if `labels` is not an existing label. + +#### `write(self: *GaugeVec(T, L), writer: *std.io.Writer) !void` +Writes the gauge to `writer`. + +### Histogram(T, []T) +A `Histogram(T, []T)` is used to track the size and frequency of events. `T` can be an unsigned integer or a float. Its main method is `observe(T)`. + +Observed values will fall within one of the provided buckets, `[]T`. The buckets must be in ascending order. A final "infinite" bucket *should not* be provided. + +#### `init(comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) Histogram(T, []T)` +Initializes the histogram. Name must be given at comptime. + +Opts is: +* `help: ?[]const u8 = null` - optional help text to include in the prometheus output + +#### `observe(self: *Histogram(T, []T), value: T) void` +Observes `value`, bucketing it based on the provided comptime buckets.
+ +#### `write(self: *Histogram(T, []T), writer: *std.io.Writer) !void` +Writes the histogram to `writer`. + +### HistogramVec(T, L, []T) +A `HistogramVec(T, L, []T)` is used to track the size and frequency of events with labels. `T` can be an unsigned integer or a float. `L` must be a struct where the field names and types will define the labels. Its main method is `observe(labels: L, value: T)`. + +Observed values will fall within one of the provided buckets, `[]T`. The buckets must be in ascending order. A final "infinite" bucket *should not* be provided. + +#### `init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) !HistogramVec(T, L, []T)` +Initializes the histogram. Name must be given at comptime. + +Opts is: +* `help: ?[]const u8 = null` - optional help text to include in the prometheus output + +#### `deinit(self: *HistogramVec(T, L, []T)) void` +Deallocates the histogram + +#### `observe(self: *HistogramVec(T, L, []T), labels: L, value: T) !void` +Observes `value` for `labels`, bucketing it based on the provided comptime buckets. Vectored metrics can fail. + +#### `remove(self: *HistogramVec(T, L, []T), labels: L) void` +Removes the labeled value from the histogram. Safe to call if `labels` is not an existing label. + +#### `write(self: *HistogramVec(T, L, []T), writer: *std.io.Writer) !void` +Writes the histogram to `writer`.
diff --git a/zig/pg-deps-metrics/src/counter.zig b/zig/pg-deps-metrics/src/counter.zig new file mode 100644 index 0000000..af301fa --- /dev/null +++ b/zig/pg-deps-metrics/src/counter.zig @@ -0,0 +1,471 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; + +const m = @import("metric.zig"); +const MetricVec = m.MetricVec; + +const RegistryOpts = @import("registry.zig").Opts; + +const Opts = struct { + help: ?[]const u8 = null, +}; + +pub fn Counter(comptime V: type) type { + assertCounterType(V); + return union(enum) { + noop: void, + impl: Impl, + + const Self = @This(); + + pub fn init(comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) Self { + switch (ropts.shouldExclude(name)) { + true => return .{ .noop = {} }, + false => return .{ .impl = Impl.init(ropts.prefix ++ name, opts) }, + } + } + + pub fn incr(self: *Self) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.incr(), + } + } + + pub fn incrBy(self: *Self, count: V) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.incrBy(count), + } + } + + pub fn write(self: *Self, writer: *std.io.Writer) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.write(writer), + } + } + + pub const Impl = struct { + count: V, + preamble: []const u8, + + pub fn init(comptime name: []const u8, comptime opts: Opts) Impl { + return .{ + .count = 0, + .preamble = comptime m.preamble(name, .counter, true, opts.help), + }; + } + + pub fn incr(self: *Impl) void { + self.incrBy(1); + } + + pub fn incrBy(self: *Impl, count: V) void { + _ = @atomicRmw(V, &self.count, .Add, count, .monotonic); + } + + pub fn write(self: *const Impl, writer: *std.io.Writer) !void { + try writer.writeAll(self.preamble); + const count = @atomicLoad(V, &self.count, .monotonic); + try m.write(count, writer); + return writer.writeByte('\n'); + } + }; + }; +} + +// Counter with labels +pub fn CounterVec(comptime V: type, comptime L: type) type { + 
assertCounterType(V); + return union(enum) { + noop: void, + impl: Impl, + + const Self = @This(); + + pub fn init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) !Self { + switch (ropts.shouldExclude(name)) { + true => return .{ .noop = {} }, + false => return .{ .impl = try Impl.init(allocator, ropts.prefix ++ name, opts) }, + } + } + + pub fn deinit(self: *Self) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.deinit(), + } + } + + pub fn incr(self: *Self, labels: L) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.incr(labels), + } + } + + pub fn incrBy(self: *Self, labels: L, count: V) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.incrBy(labels, count), + } + } + + pub fn remove(self: *Self, labels: L) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.remove(labels), + } + } + + pub fn write(self: *Self, writer: *std.io.Writer) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.write(writer), + } + } + + pub const Impl = struct { + vec: MetricVec(L), + preamble: []const u8, + allocator: Allocator, + lock: std.Thread.RwLock, + values: MetricVec(L).HashMap(Value), + + pub const Value = struct { + count: V, + attributes: []const u8, + }; + + pub fn init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts) !Impl { + return .{ + .lock = .{}, + .allocator = allocator, + .vec = try MetricVec(L).init(name), + .values = MetricVec(L).HashMap(Value){}, + .preamble = comptime m.preamble(name, .counter, false, opts.help), + }; + } + + pub fn deinit(self: *Impl) void { + const allocator = self.allocator; + + var it = self.values.iterator(); + while (it.next()) |kv| { + MetricVec(L).free(allocator, kv.key_ptr.*); + allocator.free(kv.value_ptr.attributes); + } + self.values.deinit(allocator); + } + + pub fn incr(self: *Impl, labels: L) !void { + return self.incrBy(labels, 1); + } + + pub fn 
incrBy(self: *Impl, labels: L, count: V) !void { + const allocator = self.allocator; + + { + self.lock.lockShared(); + defer self.lock.unlockShared(); + if (self.values.getPtr(labels)) |existing| { + _ = @atomicRmw(V, &existing.count, .Add, count, .monotonic); + return; + } + } + + // It's possible that another thread will come in and create this + // missing label, and we'll check for that, but we'll assume not and + // do our allocations here, outside of any locks. + const attributes = try MetricVec(L).buildAttributes(allocator, labels); + errdefer allocator.free(attributes); + + const owned_labels = try MetricVec(L).dupe(allocator, labels); + errdefer MetricVec(L).free(allocator, owned_labels); + + const counter = Value{ + .count = count, + .attributes = attributes, + }; + + self.lock.lock(); + defer self.lock.unlock(); + + const gop = try self.values.getOrPut(allocator, owned_labels); + if (gop.found_existing) { + MetricVec(L).free(allocator, owned_labels); + allocator.free(attributes); + gop.value_ptr.count += count; + return; + } + + gop.value_ptr.* = counter; + } + + pub fn remove(self: *Impl, labels: L) void { + const kv = blk: { + self.lock.lock(); + defer self.lock.unlock(); + break :blk self.values.fetchRemove(labels) orelse return; + }; + + const allocator = self.allocator; + MetricVec(L).free(allocator, kv.key); + allocator.free(kv.value.attributes); + } + + pub fn write(self: *Impl, writer: *std.io.Writer) !void { + try writer.writeAll(self.preamble); + + const name = self.vec.name; + + self.lock.lockShared(); + defer self.lock.unlockShared(); + + var it = self.values.iterator(); + while (it.next()) |kv| { + try writer.writeAll(name); + + const value = kv.value_ptr.*; + try writer.writeAll(value.attributes); + try m.write(value.count, writer); + try writer.writeByte('\n'); + } + } + }; + }; +} + +fn assertCounterType(comptime T: type) void { + switch (@typeInfo(T)) { + .float => return, + .int => |int| { + if (int.signedness == .unsigned) return; + }, 
+ else => {}, + } + @compileError("Counter metric must be an unsigned integer or a float, got: " ++ @typeName(T)); +} + +const t = @import("t.zig"); +test "Counter: noop incr/incrBy" { + // these should just not crash + var c = Counter(u32){ .noop = {} }; + c.incr(); + c.incrBy(10); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + try c.write(&writer.writer); + try t.expectEqual(0, writer.writer.end); +} + +test "Counter: incr/incrBy" { + var c = Counter(u32).init("t1", .{}, .{}); + c.incr(); + try t.expectEqual(1, c.impl.count); + c.incrBy(10); + try t.expectEqual(11, c.impl.count); +} + +test "Counter: write" { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + var c = Counter(u32).init("metric_cnt_1_x", .{}, .{ .exclude = &.{"t_ex"} }); + + { + c.incr(); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_cnt_1_x counter\nmetric_cnt_1_x 1\n", buf); + } + + { + writer.clearRetainingCapacity(); + c.incrBy(399929123); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_cnt_1_x counter\nmetric_cnt_1_x 399929124\n", buf); + } +} + +test "Counter: exclude" { + var c = Counter(u32).init("t_ex", .{}, .{ .exclude = &.{"t_ex"} }); + c.incr(); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectEqual(0, buf.len); +} + +test "Counter: prefix" { + var c = Counter(u32).init("t1_p", .{}, .{ .prefix = "hello_" }); + c.incr(); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE hello_t1_p counter\nhello_t1_p 1\n", buf); +} + +test "Counter: float incr/incrBy" { + var c = Counter(f32).init("t1", .{}, .{}); + c.incr(); + try t.expectEqual(1, 
c.impl.count); + c.incrBy(12.1); + try t.expectEqual(13.1, c.impl.count); +} + +test "Counter: float write" { + var writer: std.io.Writer.Allocating = .init( + t.allocator, + ); + defer writer.deinit(); + + var c = Counter(f64).init("metric_cnt_2_x", .{}, .{}); + + { + c.incr(); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_cnt_2_x counter\nmetric_cnt_2_x 1\n", buf); + } + + { + writer.clearRetainingCapacity(); + c.incrBy(123.991); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_cnt_2_x counter\nmetric_cnt_2_x 124.991\n", buf); + } +} + +test "CounterVec: noop incr/incrBy" { + // these should just not crash + var c = CounterVec(u32, struct { id: u32 }){ .noop = {} }; + defer c.deinit(); + try c.incr(.{ .id = 3 }); + try c.incrBy(.{ .id = 10 }, 20); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectEqual(0, buf.len); +} + +test "CounterVec: incr/incrBy + write" { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + const preamble = "# HELP counter_vec_1 h1\n# TYPE counter_vec_1 counter\n"; + + // these should just not crash + var c = try CounterVec(u64, struct { id: []const u8 }).init(t.allocator, "counter_vec_1", .{ .help = "h1" }, .{}); + defer c.deinit(); + + { + try c.incr(.{ .id = "a" }); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "counter_vec_1{id=\"a\"} 1\n", buf); + } + + { + writer.clearRetainingCapacity(); + try c.incr(.{ .id = "b" }); + try c.incr(.{ .id = "a" }); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "counter_vec_1{id=\"b\"} 1\ncounter_vec_1{id=\"a\"} 2\n", buf); + } + + { + writer.clearRetainingCapacity(); + try c.incrBy(.{ .id = "a" }, 20); + try 
c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "counter_vec_1{id=\"b\"} 1\ncounter_vec_1{id=\"a\"} 22\n", buf); + } + + { + writer.clearRetainingCapacity(); + c.remove(.{ .id = "not_found" }); + c.remove(.{ .id = "a" }); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "counter_vec_1{id=\"b\"} 1\n", buf); + } +} + +test "CounterVec: float incr/incrBy + write" { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + const preamble = "# HELP counter_vec_xx_2 h1\n# TYPE counter_vec_xx_2 counter\n"; + + // these should just not crash + var c = try CounterVec(f32, struct { id: []const u8 }).init(t.allocator, "counter_vec_xx_2", .{ .help = "h1" }, .{}); + defer c.deinit(); + + { + try c.incr(.{ .id = "a" }); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "counter_vec_xx_2{id=\"a\"} 1\n", buf); + } + + { + writer.clearRetainingCapacity(); + try c.incr(.{ .id = "b" }); + try c.incr(.{ .id = "a" }); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "counter_vec_xx_2{id=\"b\"} 1\ncounter_vec_xx_2{id=\"a\"} 2\n", buf); + } + + { + writer.clearRetainingCapacity(); + try c.incrBy(.{ .id = "a" }, 0.25); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "counter_vec_xx_2{id=\"b\"} 1\ncounter_vec_xx_2{id=\"a\"} 2.25\n", buf); + } +} + +test "Counter: concurrent create" { + const EquitiesCounter = CounterVec(u64, struct { + symbol: []const u8, + type: []const u8, + }); + + const preamble = "# TYPE counter_vec_concurrent counter\n"; + + const run = struct { + fn run(c: *EquitiesCounter) void { + c.incrBy(.{ .symbol = "AAPL", .type = "trade" }, 1) catch {}; + } + }.run; + + for (1..100) |_| { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + 
var c = try EquitiesCounter.init(t.allocator, "counter_vec_concurrent", .{}, .{}); + defer c.deinit(); + + var th1 = try std.Thread.spawn(.{}, run, .{&c}); + var th2 = try std.Thread.spawn(.{}, run, .{&c}); + th2.join(); + th1.join(); + + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "counter_vec_concurrent{symbol=\"AAPL\",type=\"trade\"} 2\n", buf); + } +} diff --git a/zig/pg-deps-metrics/src/gauge.zig b/zig/pg-deps-metrics/src/gauge.zig new file mode 100644 index 0000000..3a75b2f --- /dev/null +++ b/zig/pg-deps-metrics/src/gauge.zig @@ -0,0 +1,523 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; + +const m = @import("metric.zig"); +const Metric = m.Metric; +const MetricVec = m.MetricVec; + +const RegistryOpts = @import("registry.zig").Opts; + +const Opts = struct { + help: ?[]const u8 = null, +}; + +pub fn Gauge(comptime V: type) type { + assertGaugeType(V); + return union(enum) { + noop: void, + impl: Impl, + + const Self = @This(); + + pub fn init(comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) Self { + switch (ropts.shouldExclude(name)) { + true => return .{ .noop = {} }, + false => return .{ .impl = Impl.init(ropts.prefix ++ name, opts) }, + } + } + + pub fn incr(self: *Self) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.incr(), + } + } + + pub fn set(self: *Self, value: V) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.set(value), + } + } + + pub fn incrBy(self: *Self, value: V) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.incrBy(value), + } + } + + pub fn write(self: *Self, writer: *std.io.Writer) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.write(writer), + } + } + + pub const Impl = struct { + value: V, + preamble: []const u8, + + pub fn init(comptime name: []const u8, comptime opts: Opts) Impl { + return .{ + .value = 0, + .preamble = comptime m.preamble(name, 
.gauge, true, opts.help), + }; + } + + pub fn incr(self: *Impl) void { + self.incrBy(1); + } + + pub fn incrBy(self: *Impl, value: V) void { + _ = @atomicRmw(V, &self.value, .Add, value, .monotonic); + } + + pub fn set(self: *Impl, value: V) void { + @atomicStore(V, &self.value, value, .monotonic); + } + + pub fn write(self: *const Impl, writer: *std.io.Writer) !void { + try writer.writeAll(self.preamble); + try m.write(@atomicLoad(V, &self.value, .monotonic), writer); + return writer.writeByte('\n'); + } + }; + }; +} + +// Gauge with labels +pub fn GaugeVec(comptime V: type, comptime L: type) type { + assertGaugeType(V); + return union(enum) { + noop: void, + impl: Impl, + + const Self = @This(); + + pub fn init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) !Self { + switch (ropts.shouldExclude(name)) { + true => return .{ .noop = {} }, + false => return .{ .impl = try Impl.init(allocator, ropts.prefix ++ name, opts) }, + } + } + + pub fn deinit(self: *Self) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.deinit(), + } + } + + pub fn incr(self: *Self, labels: L) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.incr(labels), + } + } + + pub fn incrBy(self: *Self, labels: L, value: V) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.incrBy(labels, value), + } + } + + pub fn set(self: *Self, labels: L, value: V) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.set(labels, value), + } + } + + pub fn remove(self: *Self, labels: L) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.remove(labels), + } + } + + pub fn write(self: *Self, writer: *std.io.Writer) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.write(writer), + } + } + + pub const Impl = struct { + vec: MetricVec(L), + preamble: []const u8, + allocator: Allocator, + lock: std.Thread.RwLock, + values: 
MetricVec(L).HashMap(Value), + + const Value = struct { + value: V, + attributes: []const u8, + }; + + pub fn init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts) !Impl { + return .{ + .lock = .{}, + .allocator = allocator, + .vec = try MetricVec(L).init(name), + .values = MetricVec(L).HashMap(Value){}, + .preamble = comptime m.preamble(name, .gauge, false, opts.help), + }; + } + + pub fn deinit(self: *Impl) void { + const allocator = self.allocator; + + var it = self.values.iterator(); + while (it.next()) |kv| { + MetricVec(L).free(allocator, kv.key_ptr.*); + allocator.free(kv.value_ptr.attributes); + } + self.values.deinit(allocator); + } + + pub fn incr(self: *Impl, labels: L) !void { + return self.incrBy(labels, 1); + } + + pub fn incrBy(self: *Impl, labels: L, value: V) !void { + return self.withValue(labels, value, atomicIncrCallback, incrCallback); + } + + fn atomicIncrCallback(value: V, entry: *Value) void { + _ = @atomicRmw(V, &entry.value, .Add, value, .monotonic); + } + + fn incrCallback(value: V, entry: *Value) void { + entry.value += value; + } + + pub fn set(self: *Impl, labels: L, value: V) !void { + return self.withValue(labels, value, atomicSetCallback, setCallback); + } + + fn setCallback(value: V, entry: *Value) void { + entry.value = value; + } + + fn atomicSetCallback(value: V, entry: *Value) void { + @atomicStore(V, &entry.value, value, .monotonic); + } + + pub fn remove(self: *Impl, labels: L) void { + const kv = blk: { + self.lock.lock(); + defer self.lock.unlock(); + break :blk self.values.fetchRemove(labels) orelse return; + }; + + const allocator = self.allocator; + MetricVec(L).free(allocator, kv.key); + allocator.free(kv.value.attributes); + } + + pub fn write(self: *Impl, writer: *std.io.Writer) !void { + try writer.writeAll(self.preamble); + + const name = self.vec.name; + + self.lock.lockShared(); + defer self.lock.unlockShared(); + + var it = self.values.iterator(); + while (it.next()) |kv| { + try writer.writeAll(name); + + const value = kv.value_ptr.*; + try
writer.writeAll(value.attributes); + try m.write(value.value, writer); + try writer.writeByte('\n'); + } + } + + // value is only used if this is the first time we've seen this label. + // if we've already seen this label, and thus have an existing entry in + // our map, then fa() or f() is executed. fa() is called when an atomic + // update is necessary, and f() is called when an atomic update isn't ( + // because a mutex is being held). + fn withValue(self: *Impl, labels: L, value: V, comptime fa: fn (V, *Value) void, comptime f: fn (V, *Value) void) !void { + const allocator = self.allocator; + + { + self.lock.lockShared(); + defer self.lock.unlockShared(); + if (self.values.getPtr(labels)) |existing| { + fa(value, existing); + return; + } + } + + // It's possible that another thread will come in and create this + // missing label, and we'll check for that, but we'll assume not and + // do our allocations here, outside of any locks. + const attributes = try MetricVec(L).buildAttributes(allocator, labels); + errdefer allocator.free(attributes); + + const owned_labels = try MetricVec(L).dupe(allocator, labels); + errdefer MetricVec(L).free(allocator, owned_labels); + + const gauge = Value{ + .value = value, + .attributes = attributes, + }; + + self.lock.lock(); + defer self.lock.unlock(); + + const gop = try self.values.getOrPut(allocator, owned_labels); + if (gop.found_existing) { + MetricVec(L).free(allocator, owned_labels); + allocator.free(attributes); + f(value, gop.value_ptr); + return; + } + gop.value_ptr.* = gauge; + } + }; + }; +} + +fn assertGaugeType(comptime T: type) void { + switch (@typeInfo(T)) { + .float, .int => return, + else => {}, + } + @compileError("Gauge metric must be an integer or float, got: " ++ @typeName(T)); +} + +const t = @import("t.zig"); +test "Gauge: noop incr/incrBy/set" { + // these should just not crash + var c = Gauge(u32){ .noop = {} }; + c.incr(); + c.incrBy(10); + c.set(100); + + var writer: std.io.Writer.Allocating = 
.init(t.allocator); + defer writer.deinit(); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectEqual(0, buf.len); +} + +test "Gauge: incr/incrBy/set" { + var g = Gauge(i32).init("t1", .{}, .{}); + + g.incr(); + try t.expectEqual(1, g.impl.value); + + g.incrBy(10); + try t.expectEqual(11, g.impl.value); + + g.incrBy(-2); + try t.expectEqual(9, g.impl.value); + + g.set(-10); + try t.expectEqual(-10, g.impl.value); +} + +test "Gauge: write" { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + var g = Gauge(i32).init("metric_grp_1_x", .{}, .{}); + + { + g.incr(); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_grp_1_x gauge\nmetric_grp_1_x 1\n", buf); + } + + { + writer.clearRetainingCapacity(); + g.incrBy(399929123); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_grp_1_x gauge\nmetric_grp_1_x 399929124\n", buf); + } + + { + writer.clearRetainingCapacity(); + g.set(-329); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_grp_1_x gauge\nmetric_grp_1_x -329\n", buf); + } +} + +test "Gauge: float incr/incrBy/set" { + var c = Gauge(f32).init("t1", .{}, .{}); + c.incr(); + try t.expectEqual(1, c.impl.value); + c.incrBy(-3.9); + try t.expectEqual(-2.9, c.impl.value); + c.set(99.9); + try t.expectEqual(99.9, c.impl.value); +} + +test "Gauge: float write" { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + var c = Gauge(f64).init("metric_g_2_x", .{}, .{}); + + { + c.incr(); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_g_2_x gauge\nmetric_g_2_x 1\n", buf); + } + + { + writer.clearRetainingCapacity(); + c.incrBy(-9.2); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_g_2_x 
gauge\nmetric_g_2_x -8.2\n", buf); + } + + { + writer.clearRetainingCapacity(); + c.set(8.888); + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString("# TYPE metric_g_2_x gauge\nmetric_g_2_x 8.888\n", buf); + } +} + +test "GaugeVec: noop incr/incrBy/set" { + // these should just not crash + var g = GaugeVec(u32, struct { id: u32 }){ .noop = {} }; + defer g.deinit(); + try g.incr(.{ .id = 3 }); + try g.incrBy(.{ .id = 10 }, 20); + try g.set(.{ .id = 3 }, 11); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectEqual(0, buf.len); +} + +test "GaugeVec: incr/incrBy/set + write" { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + const preamble = "# HELP gauge_vec_1 h1\n# TYPE gauge_vec_1 gauge\n"; + + // these should just not crash + var g = try GaugeVec(i64, struct { id: []const u8 }).init(t.allocator, "gauge_vec_1", .{ .help = "h1" }, .{}); + defer g.deinit(); + + { + try g.incr(.{ .id = "a" }); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "gauge_vec_1{id=\"a\"} 1\n", buf); + } + + { + writer.clearRetainingCapacity(); + try g.incr(.{ .id = "b" }); + try g.incr(.{ .id = "a" }); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "gauge_vec_1{id=\"b\"} 1\ngauge_vec_1{id=\"a\"} 2\n", buf); + } + + { + writer.clearRetainingCapacity(); + try g.incrBy(.{ .id = "a" }, 20); + try g.set(.{ .id = "c" }, 5); + try g.set(.{ .id = "b" }, -33); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "gauge_vec_1{id=\"b\"} -33\ngauge_vec_1{id=\"a\"} 22\ngauge_vec_1{id=\"c\"} 5\n", buf); + } + + { + writer.clearRetainingCapacity(); + g.remove(.{ .id = "not_found" }); + g.remove(.{ .id = "a" }); + try g.write(&writer.writer); + 
const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "gauge_vec_1{id=\"b\"} -33\ngauge_vec_1{id=\"c\"} 5\n", buf); + } +} + +test "GaugeVec: float incr/incrBy/set + write" { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + const preamble = "# HELP gauge_vec_xx_2 h1\n# TYPE gauge_vec_xx_2 gauge\n"; + + // these should just not crash + var g = try GaugeVec(f64, struct { id: []const u8 }).init(t.allocator, "gauge_vec_xx_2", .{ .help = "h1" }, .{}); + defer g.deinit(); + + { + try g.incr(.{ .id = "a" }); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "gauge_vec_xx_2{id=\"a\"} 1\n", buf); + } + + { + writer.clearRetainingCapacity(); + try g.incr(.{ .id = "b" }); + try g.incr(.{ .id = "a" }); + try g.set(.{ .id = "c\nc" }, 0.011); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "gauge_vec_xx_2{id=\"b\"} 1\ngauge_vec_xx_2{id=\"a\"} 2\ngauge_vec_xx_2{id=\"c\\nc\"} 0.011\n", buf); + } + + { + writer.clearRetainingCapacity(); + try g.incrBy(.{ .id = "a" }, 0.25); + g.remove(.{ .id = "c\nc" }); + try g.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "gauge_vec_xx_2{id=\"b\"} 1\ngauge_vec_xx_2{id=\"a\"} 2.25\n", buf); + } +} + +test "Gauge: concurrent create" { + const EquitiesGauge = GaugeVec(u64, struct { + symbol: []const u8, + type: []const u8, + }); + + const preamble = "# TYPE gauge_vec_concurrent gauge\n"; + + const run = struct { + fn run(c: *EquitiesGauge) void { + c.set(.{ .symbol = "AAPL", .type = "trade" }, 1) catch {}; + } + }.run; + + for (0..100) |_| { + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + var c = try EquitiesGauge.init(t.allocator, "gauge_vec_concurrent", .{}, .{}); + defer c.deinit(); + + var th1 = try std.Thread.spawn(.{}, run, .{&c}); + var th2 = try std.Thread.spawn(.{}, run, .{&c}); + 
th2.join(); + th1.join(); + + try c.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString(preamble ++ "gauge_vec_concurrent{symbol=\"AAPL\",type=\"trade\"} 1\n", buf); + } +} diff --git a/zig/pg-deps-metrics/src/histogram.zig b/zig/pg-deps-metrics/src/histogram.zig new file mode 100644 index 0000000..cda5e1b --- /dev/null +++ b/zig/pg-deps-metrics/src/histogram.zig @@ -0,0 +1,626 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; + +const m = @import("metric.zig"); +const Metric = m.Metric; +const MetricVec = m.MetricVec; + +const RegistryOpts = @import("registry.zig").Opts; + +const Opts = struct { + help: ?[]const u8 = null, +}; + +pub fn Histogram(comptime V: type, comptime upper_bounds: []const V) type { + assertHistogramType(V); + assertUpperBounds(upper_bounds); + + return union(enum) { + noop: void, + impl: Impl, + + const Self = @This(); + + pub fn init(comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) Self { + switch (ropts.shouldExclude(name)) { + true => return .{ .noop = {} }, + false => return .{ .impl = comptime Impl.init(ropts.prefix ++ name, opts) }, + } + } + + pub fn observe(self: *Self, value: V) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| impl.observe(value), + } + } + + pub fn write(self: *Self, writer: *std.io.Writer) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.write(writer), + } + } + + pub const Impl = struct { + sum: V, + count: usize, + preamble: []const u8, + buckets: [upper_bounds.len]V, + output_sum_prefix: []const u8, + output_count_prefix: []const u8, + output_bucket_prefixes: [upper_bounds.len][]const u8, + output_bucket_inf_prefix: []const u8, + + pub fn init(comptime name: []const u8, comptime opts: Opts) Impl { + comptime { + const output_sum_prefix = std.fmt.comptimePrint("\n{s}_sum ", .{name}); + const output_count_prefix = std.fmt.comptimePrint("\n{s}_count ", .{name}); + + const output_bucket_inf_prefix = 
std.fmt.comptimePrint("{s}_bucket{{le=\"+Inf\"}} ", .{name}); + var output_bucket_prefixes: [upper_bounds.len][]const u8 = undefined; + + for (upper_bounds, 0..) |upper, i| { + output_bucket_prefixes[i] = std.fmt.comptimePrint("{s}_bucket{{le=\"{d}\"}} ", .{ name, upper }); + } + + return .{ + .sum = 0, + .count = 0, + .preamble = m.preamble(name, .histogram, false, opts.help), + .output_sum_prefix = output_sum_prefix, + .output_count_prefix = output_count_prefix, + .output_bucket_prefixes = output_bucket_prefixes, + .output_bucket_inf_prefix = output_bucket_inf_prefix, + .buckets = std.mem.zeroes([upper_bounds.len]V), + }; + } + } + + pub fn observe(self: *Impl, value: V) void { + _ = @atomicRmw(usize, &self.count, .Add, 1, .monotonic); + _ = @atomicRmw(V, &self.sum, .Add, value, .monotonic); + + const idx = blk: { + for (upper_bounds, 0..) |upper, i| { + if (value < upper) { + break :blk i; + } + } + // this is our implicit bucket to +Inf. Implicit because the count + // and sum, updated above, will contain this entry + return; + }; + + _ = @atomicRmw(V, &self.buckets[idx], .Add, 1, .monotonic); + } + + pub fn write(self: *Impl, writer: *std.io.Writer) !void { + try writer.writeAll(self.preamble); + + var sum: V = 0; + for (self.output_bucket_prefixes, 0..) 
|prefix, i| { + sum += @atomicRmw(V, &self.buckets[i], .Xchg, 0, .monotonic); + try writer.writeAll(prefix); + try m.write(sum, writer); + try writer.writeByte('\n'); + } + + const total_count = @atomicRmw(usize, &self.count, .Xchg, 0, .monotonic); + { + // write +Inf + try writer.writeAll(self.output_bucket_inf_prefix); + try writer.printInt(total_count, 10, .lower, .{}); + } + + { + //write sum + // this includes a leading newline, hence we didn't need to write + // it after our output_bucket_inf_prefix + try writer.writeAll(self.output_sum_prefix); + try m.write(@atomicRmw(V, &self.sum, .Xchg, 0, .monotonic), writer); + } + + { + //write count + // this includes a leading newline, hence we didn't need to write + // it after our output_sum_prefix + try writer.writeAll(self.output_count_prefix); + try writer.printInt(total_count, 10, .lower, .{}); + try writer.writeByte('\n'); + } + } + }; + }; +} + +pub fn HistogramVec(comptime V: type, comptime L: type, comptime upper_bounds: []const V) type { + assertHistogramType(V); + assertUpperBounds(upper_bounds); + + return union(enum) { + noop: void, + impl: Impl, + + const Self = @This(); + + pub fn init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts, comptime ropts: RegistryOpts) !Self { + switch (ropts.shouldExclude(name)) { + true => return .{ .noop = {} }, + false => return .{ .impl = try Impl.init(allocator, ropts.prefix ++ name, opts) }, + } + } + + pub fn observe(self: *Self, labels: L, value: V) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.observe(labels, value), + } + } + + pub fn write(self: *Self, writer: *std.io.Writer) !void { + switch (self.*) { + .noop => {}, + .impl => |*impl| return impl.write(writer), + } + } + + // could get the allocator from impl.allocator, but taking it as a parameter + // makes the API the same between Histogram and HistogramVec + pub fn deinit(self: *Self) void { + switch (self.*) { + .noop => {}, + .impl => |*impl| 
impl.deinit(), + } + } + + pub const Impl = struct { + vec: MetricVec(L), + preamble: []const u8, + allocator: Allocator, + lock: std.Thread.RwLock, + values: MetricVec(L).HashMap(Value), + output_sum_prefix: []const u8, + output_count_prefix: []const u8, + output_bucket_prefixes: [upper_bounds.len][]const u8, + output_bucket_inf_prefix: []const u8, + + const Value = struct { + sum: V, + count: usize, + mutex: std.Thread.Mutex, + buckets: [upper_bounds.len]V, + // this gets glued to our output_bucket_prefixes + attributes: []const u8, + + fn observe(self: *Value, value: V, idx: ?usize) void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.observeLocked(value, idx); + } + + fn observeLocked(self: *Value, value: V, idx: ?usize) void { + self.sum += value; + self.count += 1; + if (idx) |idx_| { + self.buckets[idx_] += 1; + } + } + + fn getIndex(value: V) ?usize { + for (upper_bounds, 0..) |upper, i| { + if (value < upper) { + return i; + } + } + return null; + } + }; + + pub fn init(allocator: Allocator, comptime name: []const u8, comptime opts: Opts) !Impl { + const vec = try MetricVec(L).init(name); + + const output_sum_prefix = try std.fmt.allocPrint(allocator, "\n{s}_sum", .{name}); + errdefer allocator.free(output_sum_prefix); + + const output_count_prefix = try std.fmt.allocPrint(allocator, "\n{s}_count", .{name}); + errdefer allocator.free(output_count_prefix); + + const output_bucket_inf_prefix = try std.fmt.allocPrint(allocator, "{s}_bucket{{le=\"+Inf\",", .{name}); + errdefer allocator.free(output_bucket_inf_prefix); + + var output_bucket_prefixes: [upper_bounds.len][]const u8 = undefined; + var initialized: usize = 0; + errdefer { + for (0..initialized) |i| { + allocator.free(output_bucket_prefixes[i]); + } + } + + for (upper_bounds, 0..) 
|upper, i| { + output_bucket_prefixes[i] = try std.fmt.allocPrint(allocator, "{s}_bucket{{le=\"{d}\",", .{ name, upper }); + initialized += 1; + } + + return .{ + .vec = vec, + .lock = .{}, + .allocator = allocator, + .values = MetricVec(L).HashMap(Value){}, + .output_sum_prefix = output_sum_prefix, + .output_count_prefix = output_count_prefix, + .output_bucket_prefixes = output_bucket_prefixes, + .output_bucket_inf_prefix = output_bucket_inf_prefix, + .preamble = comptime m.preamble(name, .histogram, false, opts.help), + }; + } + + pub fn deinit(self: *Impl) void { + const allocator = self.allocator; + allocator.free(self.output_sum_prefix); + allocator.free(self.output_count_prefix); + allocator.free(self.output_bucket_inf_prefix); + for (self.output_bucket_prefixes) |obf| { + allocator.free(obf); + } + + var it = self.values.iterator(); + while (it.next()) |kv| { + MetricVec(L).free(allocator, kv.key_ptr.*); + allocator.free(kv.value_ptr.attributes); + } + self.values.deinit(allocator); + } + + pub fn observe(self: *Impl, labels: L, value: V) !void { + const allocator = self.allocator; + + // do this outside any lock + const idx: ?usize = blk: { + for (upper_bounds, 0..) |upper, i| { + if (value < upper) { + break :blk i; + } + } + // this is our implicit bucket to +Inf. Implicit because the count + // and sum, updated above, will contain this entry + break :blk null; + }; + + { + self.lock.lockShared(); + defer self.lock.unlockShared(); + if (self.values.getPtr(labels)) |existing| { + existing.observe(value, idx); + return; + } + } + + // It's possible that another thread will come in and create this + // missing label, and we'll check for that, but we'll assume not and + // do our allocations here, outside of any locks. 
+ const attributes = try MetricVec(L).buildAttributes(allocator, labels); + errdefer allocator.free(attributes); + + const owned_labels = try MetricVec(L).dupe(allocator, labels); + errdefer MetricVec(L).free(allocator, owned_labels); + + const histogram = Value{ + .sum = 0, + .count = 0, + .mutex = .{}, + .attributes = attributes, + .buckets = std.mem.zeroes([upper_bounds.len]V), + }; + + self.lock.lock(); + defer self.lock.unlock(); + + const gop = try self.values.getOrPut(allocator, owned_labels); + if (gop.found_existing) { + MetricVec(L).free(allocator, owned_labels); + allocator.free(attributes); + } else { + gop.value_ptr.* = histogram; + } + + // since we've taken out a write lock on the entire histogram + // we can observe this value without taking an inner value lock. + gop.value_ptr.observeLocked(value, idx); + } + + pub fn remove(self: *Impl, labels: L) void { + const kv = blk: { + self.lock.lock(); + defer self.lock.unlock(); + break :blk self.values.fetchRemove(labels) orelse return; + }; + + const allocator = self.allocator; + MetricVec(L).free(allocator, kv.key); + allocator.free(kv.value.attributes); + } + + pub fn write(self: *Impl, writer: *std.io.Writer) !void { + try writer.writeAll(self.preamble); + + const output_sum_prefix = self.output_sum_prefix; + const output_count_prefix = self.output_count_prefix; + const output_bucket_inf_prefix = self.output_bucket_inf_prefix; + + self.lock.lockShared(); + defer self.lock.unlockShared(); + + var it = self.values.iterator(); + while (it.next()) |kv| { + var value = kv.value_ptr; + var bucket_counts: [upper_bounds.len]V = undefined; + + // copy our sum/count/bucket_counts out of Value into local variables + // to minimize our lock duration + value.mutex.lock(); + var buckets = &value.buckets; + const value_sum = value.sum; + const value_count = value.count; + for (buckets, 0..)
|bucket_count, i| { + bucket_counts[i] = bucket_count; + buckets[i] = 0; + } + value.sum = 0; + value.count = 0; + value.mutex.unlock(); + + const attributes = value.attributes; + // attributes contains the opening and closing braces: {k="v"} + // but for the bucket values, we're appending the attribute to the + // pre-generated prefix, which already contains "{le="$bucket". + // So we strip out the leading "{" from our attribute so that we can + // glue it to our pre-generated prefix. + const append_attributes = attributes[1..]; + + var sum: V = 0; + for (self.output_bucket_prefixes, bucket_counts) |prefix, bucket_count| { + sum += bucket_count; + try writer.writeAll(prefix); + try writer.writeAll(append_attributes); + try m.write(sum, writer); + try writer.writeByte('\n'); + } + + { + // write +Inf + try writer.writeAll(output_bucket_inf_prefix); + try writer.writeAll(append_attributes); + try writer.printInt(value_count, 10, .lower, .{}); + } + + { + //write sum + // this includes a leading newline, hence we didn't need to write + // it after our output_bucket_inf_prefix + try writer.writeAll(output_sum_prefix); + try writer.writeAll(attributes); + try m.write(value_sum, writer); + } + + { + //write count + // this includes a leading newline, hence we didn't need to write + // it after our output_sum_prefix + try writer.writeAll(output_count_prefix); + try writer.writeAll(attributes); + try writer.printInt(value_count, 10, .lower, .{}); + try writer.writeByte('\n'); + } + } + } + }; + }; +} + +fn assertHistogramType(comptime T: type) void { + switch (@typeInfo(T)) { + .float => return, + .int => |int| { + if (int.signedness == .unsigned) return; + }, + else => {}, + } + @compileError("Histogram metric must be an unsigned integer or a float, got: " ++ @typeName(T)); +} + +fn assertUpperBounds(upper_bounds: anytype) void { + if (upper_bounds.len == 0) { + @compileError("Histogram upper bound cannot be empty"); + } + var last = upper_bounds[0]; + for
(upper_bounds[1..]) |value| { + if (value < last) { + @compileError("Histogram upper bounds must be in ascending order"); + } + last = value; + } +} + +const t = @import("t.zig"); +test "Histogram: noop " { + // these should just not crash + var h = Histogram(u32, &.{0}){ .noop = {} }; + h.observe(2); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + try h.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectEqual(0, buf.len); +} + +test "Histogram: simple" { + var h = Histogram(f64, &.{ 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10 }).init("hst_1", .{}, .{}); + + var i: f64 = 0.001; + for (0..1000) |_| { + i = i + i / 100; + h.observe(i); + } + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + { + try h.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString( + \\# TYPE hst_1 histogram + \\hst_1_bucket{le="0.005"} 161 + \\hst_1_bucket{le="0.01"} 231 + \\hst_1_bucket{le="0.025"} 323 + \\hst_1_bucket{le="0.05"} 393 + \\hst_1_bucket{le="0.1"} 462 + \\hst_1_bucket{le="0.25"} 554 + \\hst_1_bucket{le="0.5"} 624 + \\hst_1_bucket{le="1"} 694 + \\hst_1_bucket{le="2.5"} 786 + \\hst_1_bucket{le="5"} 855 + \\hst_1_bucket{le="10"} 925 + \\hst_1_bucket{le="+Inf"} 1000 + \\hst_1_sum 2116.7737194191777 + \\hst_1_count 1000 + \\ + , buf); + } + + { + writer.clearRetainingCapacity(); + h.observe(2.8); + try h.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString( + \\# TYPE hst_1 histogram + \\hst_1_bucket{le="0.005"} 0 + \\hst_1_bucket{le="0.01"} 0 + \\hst_1_bucket{le="0.025"} 0 + \\hst_1_bucket{le="0.05"} 0 + \\hst_1_bucket{le="0.1"} 0 + \\hst_1_bucket{le="0.25"} 0 + \\hst_1_bucket{le="0.5"} 0 + \\hst_1_bucket{le="1"} 0 + \\hst_1_bucket{le="2.5"} 0 + \\hst_1_bucket{le="5"} 1 + \\hst_1_bucket{le="10"} 1 + \\hst_1_bucket{le="+Inf"} 1 + \\hst_1_sum 2.8 + \\hst_1_count 1 + \\ + , buf); + } +} + +test "HistogramVec: 
noop " { + // these should just not crash + var h = HistogramVec(u32, struct { status: u16 }, &.{0}){ .noop = {} }; + defer h.deinit(); + try h.observe(.{ .status = 200 }, 2); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + try h.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectEqual(0, buf.len); +} + +test "HistogramVec" { + var h = try HistogramVec(f64, struct { status: u16 }, &.{ 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10 }).init(t.allocator, "hst_1", .{}, .{}); + defer h.deinit(); + + var i: f64 = 0.001; + for (0..1000) |_| { + i = i + i / 100; + try h.observe(.{ .status = 200 }, i); + } + + i = 0.02; + for (0..100) |_| { + i = i + i / 50; + try h.observe(.{ .status = 400 }, i); + } + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + + { + try h.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString( + \\# TYPE hst_1 histogram + \\hst_1_bucket{le="0.005",status="200"} 161 + \\hst_1_bucket{le="0.01",status="200"} 231 + \\hst_1_bucket{le="0.025",status="200"} 323 + \\hst_1_bucket{le="0.05",status="200"} 393 + \\hst_1_bucket{le="0.1",status="200"} 462 + \\hst_1_bucket{le="0.25",status="200"} 554 + \\hst_1_bucket{le="0.5",status="200"} 624 + \\hst_1_bucket{le="1",status="200"} 694 + \\hst_1_bucket{le="2.5",status="200"} 786 + \\hst_1_bucket{le="5",status="200"} 855 + \\hst_1_bucket{le="10",status="200"} 925 + \\hst_1_bucket{le="+Inf",status="200"} 1000 + \\hst_1_sum{status="200"} 2116.7737194191777 + \\hst_1_count{status="200"} 1000 + \\hst_1_bucket{le="0.005",status="400"} 0 + \\hst_1_bucket{le="0.01",status="400"} 0 + \\hst_1_bucket{le="0.025",status="400"} 11 + \\hst_1_bucket{le="0.05",status="400"} 46 + \\hst_1_bucket{le="0.1",status="400"} 81 + \\hst_1_bucket{le="0.25",status="400"} 100 + \\hst_1_bucket{le="0.5",status="400"} 100 + \\hst_1_bucket{le="1",status="400"} 100 + \\hst_1_bucket{le="2.5",status="400"} 
100 + \\hst_1_bucket{le="5",status="400"} 100 + \\hst_1_bucket{le="10",status="400"} 100 + \\hst_1_bucket{le="+Inf",status="400"} 100 + \\hst_1_sum{status="400"} 6.369539040617386 + \\hst_1_count{status="400"} 100 + \\ + , buf); + } + + { + try h.observe(.{ .status = 200 }, 9); + writer.clearRetainingCapacity(); + try h.write(&writer.writer); + const buf = writer.writer.buffered(); + try t.expectString( + \\# TYPE hst_1 histogram + \\hst_1_bucket{le="0.005",status="200"} 0 + \\hst_1_bucket{le="0.01",status="200"} 0 + \\hst_1_bucket{le="0.025",status="200"} 0 + \\hst_1_bucket{le="0.05",status="200"} 0 + \\hst_1_bucket{le="0.1",status="200"} 0 + \\hst_1_bucket{le="0.25",status="200"} 0 + \\hst_1_bucket{le="0.5",status="200"} 0 + \\hst_1_bucket{le="1",status="200"} 0 + \\hst_1_bucket{le="2.5",status="200"} 0 + \\hst_1_bucket{le="5",status="200"} 0 + \\hst_1_bucket{le="10",status="200"} 1 + \\hst_1_bucket{le="+Inf",status="200"} 1 + \\hst_1_sum{status="200"} 9 + \\hst_1_count{status="200"} 1 + \\hst_1_bucket{le="0.005",status="400"} 0 + \\hst_1_bucket{le="0.01",status="400"} 0 + \\hst_1_bucket{le="0.025",status="400"} 0 + \\hst_1_bucket{le="0.05",status="400"} 0 + \\hst_1_bucket{le="0.1",status="400"} 0 + \\hst_1_bucket{le="0.25",status="400"} 0 + \\hst_1_bucket{le="0.5",status="400"} 0 + \\hst_1_bucket{le="1",status="400"} 0 + \\hst_1_bucket{le="2.5",status="400"} 0 + \\hst_1_bucket{le="5",status="400"} 0 + \\hst_1_bucket{le="10",status="400"} 0 + \\hst_1_bucket{le="+Inf",status="400"} 0 + \\hst_1_sum{status="400"} 0 + \\hst_1_count{status="400"} 0 + \\ + , buf); + } +} diff --git a/zig/pg-deps-metrics/src/metric.zig b/zig/pg-deps-metrics/src/metric.zig new file mode 100644 index 0000000..4c6ca63 --- /dev/null +++ b/zig/pg-deps-metrics/src/metric.zig @@ -0,0 +1,670 @@ +const std = @import("std"); + +const ascii = std.ascii; +const Wyhash = std.hash.Wyhash; +const Allocator = std.mem.Allocator; + +const MetricType = enum { + counter, + gauge, + histogram, +}; + +// 
Used by metrics that have labels (CounterVec, GaugeVec, HistogramVec). This +// does what Metric does (own the metric name, own the preamble), but also does a lot +// more with respect to labels (because, regardless of what the underlying metric +// is, label handling is the same). +pub fn MetricVec(comptime L: type) type { + const ti = @typeInfo(L); + if (std.meta.activeTag(ti) != .@"struct") { + @compileError("Vec type must be a struct, got: " ++ @typeName(L)); + } + + const fields = ti.@"struct".fields; + inline for (fields) |f| { + validateLabel(f.name, f.type); + } + + // When we serialize attributes, we'll store each serialized attribute into + // an array of this type. + const SerializedValues = [fields.len]SerializedValue; + + // The length of the serialized attributes without the values. + // If L is struct{status: int, path: []const u8}, then this would be the length + // of: {status="",path=""} + // We use this when we need to allocate the attribute string, taking this length + // and adding it to the length of the serialized values. + const static_attribute_len = comptime blk: { + // +2 for the '{' and '}' around the entire attribute string + // +1 for the trailing space + // +fields.len - 1 for the comma separator between attributes + var len: usize = 2 + 1 + fields.len - 1; + for (fields) |f| { + // +1 for the '=' separator between attribute name and value + // +2 for the '"' around the value + len += f.name.len + 3; + } + break :blk len; + }; + + return struct { + // The name of the metric. Unlike with a plain Metric, this doesn't include + // a trailing space (because attributes are glued to the metric name) + name: []const u8, + + // The label names (which are the names of L's fields) + labels: [fields.len][]const u8, + + // std.AutoHashMap doesn't handle structs with slices (i.e.
[]const u8) fields + // So we create our own context (hash and eql) which supports the type allowed + // by our validateLabel + pub fn HashMap(comptime V: type) type { + return std.HashMapUnmanaged(L, V, HashContext(L), 80); + } + + const Self = @This(); + + pub fn init(comptime name: []const u8) !Self { + comptime validateName(name); + + comptime var labels: [fields.len][]const u8 = undefined; + inline for (fields, 0..) |f, i| { + labels[i] = f.name; + } + + return .{ + .name = name, + .labels = labels, + }; + } + + // The key of labeled metrics is the label itself (L), or more specifically + // the values. These need to exist for the lifetime of the entry in the map + // so we dupe it. We only allow a small number of types in our labels and + // the only type that needs to be allocated is a []const u8. + pub fn dupe(allocator: Allocator, value: L) !L { + var owned: L = undefined; + inline for (fields) |f| { + switch (@typeInfo(f.type)) { + .pointer => @field(owned, f.name) = try allocator.dupe(u8, @field(value, f.name)), + else => @field(owned, f.name) = @field(value, f.name), // all other fields are primitives + } + } + return owned; + } + + // Frees memory allocated by the above dupe function. + pub fn free(allocator: Allocator, value: L) void { + inline for (fields) |f| { + switch (@typeInfo(f.type)) { + .pointer => allocator.free(@field(value, f.name)), + else => {}, // all other fields are primitives + } + } + } + + // Every labeled metric has a hashmap of label => VALUE + // Where VALUE is going to be a metric specific value (like a number for + // a counter) as well as the serialized attribute string. 
For example + // given a CounterVec(u64, struct{status: u16}) and the label: + // .{.status = 200} + // The counter's hashmap will have an extra with the key being the label + // itself, a u64 count and the serialized label value: + // 200 => .{ + // .count = 1, + // .attributes = "{status=\"200\"}\n" + // } + // + // Given: + // .{.status = 200} + // this function builds the attribute string: + // "{status=\"200\"}\n" + pub fn buildAttributes(allocator: Allocator, values: L) ![]const u8 { + + // We begin by serializing all the values in L. We only support a few label + // types and some don't require allocation. The serializeValue function + // returns a SerializedValue which contains the serialized (string) + // representation of the value and a boolean to indicate if an allocation + // took place. This is needed so that we can properly clean up. + var len: usize = 0; + var serialized: SerializedValues = undefined; + inline for (fields, 0..) |f, i| { + const s = try serializeValue(allocator, @field(values, f.name)); + serialized[i] = s; + len += s.str.len; + } + + // Any allocations done in serializeValue is short lived, because we'll + // copy everything into the final attribute string. + defer { + for (serialized) |s| { + if (s.allocated) allocator.free(s.str); + } + } + + var buf = try allocator.alloc(u8, static_attribute_len + len); + buf[0] = '{'; + var pos: usize = 1; + inline for (fields, 0..) 
|f, i| { + { + // write the key + const value = f.name; + const end = pos + value.len; + @memcpy(buf[pos..end], value); + pos = end; + } + + buf[pos] = '='; + buf[pos + 1] = '"'; + pos += 2; + + { + // write the value + const value = serialized[i]; + const end = pos + value.str.len; + @memcpy(buf[pos..end], value.str); + buf[end] = '"'; + pos = end + 1; + } + buf[pos] = ','; + pos += 1; + } + // -1 to overwrite the last trailing comma + buf[pos - 1] = '}'; + // space between our attribute string and the metric value + buf[pos] = ' '; + return buf; + } + }; +} + +// Writes value to writer. Value can either be an integer or float. +pub fn write(value: anytype, writer: *std.io.Writer) !void { + switch (@typeInfo(@TypeOf(value))) { + .int => return writer.printInt(value, 10, .lower, .{}), + .float => return writer.print("{d}", .{value}), + else => unreachable, // there are guards that prevent this from being possible + } +} + +// Validates that a metric name is valid, based on: +// https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels +fn validateName(comptime name: []const u8) void { + if (name.len == 0) { + @compileError("Empty metric name is not valid"); + } + + { + const c = name[0]; + if (!ascii.isAlphabetic(c) and c != '_' and c != ':') { + @compileError("Metric name must begin with a letter, underscore or colon ([a-zA-Z_:])"); + } + } + + for (name[1..]) |c| { + if (!ascii.isAlphanumeric(c) and c != '_' and c != ':') { + @compileError("Metric name can only contain ascii letters, numbers, underscores and colons ([a-zA-Z_:][a-zA-Z0-9_:]*)"); + } + } +} + +// Validates that a label is valid. Validates both the name and the type. 
+// The validity of the name is based on: +// https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels +// The validity of the type is based on what our HashContext supports +fn validateLabel(comptime name: []const u8, comptime T: type) void { + if (name.len == 0) { + @compileError("Empty label name is not valid"); + } + + { + const c = name[0]; + if (!ascii.isAlphabetic(c) and c != '_') { + @compileError("Label name must begin with a letter, underscore or colon ([a-zA-Z_])"); + } + + if (c == '_' and name.len > 1 and name[1] == '_') { + @compileError("Label names starting with double underscore are reserved"); + } + } + + for (name[1..]) |c| { + if (!ascii.isAlphanumeric(c) and c != '_') { + @compileError("Label name can only contain ascii letters, numbers and underscores ([a-zA-Z_:][a-zA-Z0-9_]*)"); + } + } + + switch (@typeInfo(T)) { + .error_set, .@"enum", .type, .bool, .int => return, + .pointer => |ptr| { + switch (ptr.size) { + .slice => { + if (ptr.child == u8) { + return; + } + }, + else => {}, + } + }, + else => {}, + } + @compileError("Label data types " ++ @typeName(T) ++ " is not supported"); +} + +// attribite value => text +fn serializeValue(allocator: Allocator, value: anytype) !SerializedValue { + switch (@typeInfo(@TypeOf(value))) { + .int => { + const digits = numberOfDigits(value); + const buf = try allocator.alloc(u8, digits); + var writer: std.io.Writer = .fixed(buf); + try writer.printInt(value, 10, .lower, .{}); + std.debug.assert(writer.end == digits); + return .{ .str = buf, .allocated = true }; + }, + .bool => return .{ .str = if (value) "true" else "false" }, + .pointer => { + // validateLabel would only allow a []u8, so if we're here, it has to be a []u8 + // but the value might need to be escaped, + var escape_count: usize = 0; + for (value) |c| { + if (c == '\\' or c == '\n' or c == '"') { + escape_count += 1; + } + } + if (escape_count == 0) { + return .{ .str = value }; + } + var pos: usize = 0; + var escaped = try 
allocator.alloc(u8, value.len + escape_count); + for (value) |c| { + switch (c) { + '\\' => { + escaped[pos] = '\\'; + pos += 1; + escaped[pos] = '\\'; + }, + '\n' => { + escaped[pos] = '\\'; + pos += 1; + escaped[pos] = 'n'; + }, + '"' => { + escaped[pos] = '\\'; + pos += 1; + escaped[pos] = '"'; + }, + else => escaped[pos] = c, + } + pos += 1; + } + return .{ .str = escaped, .allocated = true }; + }, + .type => return .{ .str = @typeName(value) }, + .@"enum" => return .{ .str = @tagName(value) }, + .error_set => return .{ .str = @errorName(value) }, + else => unreachable, + } +} + +// When allocate is true, then str was allocated (and needed to be freed) +const SerializedValue = struct { + str: []const u8, + allocated: bool = false, +}; + +// Return the number of digits in a number, including a negative sign. +fn numberOfDigits(value: anytype) usize { + const adj: usize = if (value < 0) 1 else 0; + var v = @abs(value); + var count: usize = 1; + while (true) { + if (v < 10) return count + adj; + if (v < 100) return count + adj + 1; + if (v < 1000) return count + adj + 2; + if (v < 10000) return count + adj + 3; + if (v < 100000) return count + adj + 4; + if (v < 1000000) return count + adj + 5; + v = v / 1000000; + count += 6; + } +} + +// See MetricVec.HashMap above +fn HashContext(comptime K: type) type { + return struct { + const Self = @This(); + + const fields = @typeInfo(K).@"struct".fields; + + pub fn hash(_: Self, key: K) u64 { + var hasher = Wyhash.init(0); + inline for (fields) |field| { + hashValue(&hasher, @field(key, field.name)); + } + return hasher.final(); + } + + // similar to std.mem.eql, but compares string values + pub fn eql(_: Self, a: K, b: K) bool { + inline for (fields) |field| { + const value_a = @field(a, field.name); + const value_b = @field(b, field.name); + switch (@typeInfo(@TypeOf(value_a))) { + .pointer => if (std.mem.eql(u8, value_a, value_b) == false) return false, + else => if (value_a != value_b) return false, + } + } + return 
true; + } + + // Similar to what you'd find in std/hash/auto_hash.zig + // but only accepts a subset of types (only those types we support as label + // values) and accepts a []u8 value. + fn hashValue(hasher: *Wyhash, value: anytype) void { + const V = @TypeOf(value); + switch (@typeInfo(V)) { + .int => |int| switch (int.signedness) { + .signed => hashValue(hasher, @as(@Int(int.bits, .unsigned), @bitCast(value))), + .unsigned => { + if (std.meta.hasUniqueRepresentation(V)) { + hasher.update(std.mem.asBytes(&value)); + } else { + const byte_size = comptime std.math.divCeil(comptime_int, @bitSizeOf(V), 8) catch unreachable; + hasher.update(std.mem.asBytes(&value)[0..byte_size]); + } + }, + }, + .@"enum" => hashValue(hasher, @intFromEnum(value)), + .error_set => hashValue(hasher, @intFromError(value)), + .bool => hasher.update(if (value) &.{1} else &.{0}), + .pointer => hasher.update(value), // validateLabelType ensures this was a []u8 + .type => hasher.update(@typeName(value)), + else => unreachable, // validateLabelType only allows the above types + } + } + }; +} + +// The "preamble" is the optional "# HELP $DESC\n" and "# TYPE $TYPE\n" string +// which is output before the metric value. Help is optional, when null the +// "# HELP ..." line is omitted. +pub fn preamble(comptime name: []const u8, comptime tpe: MetricType, comptime append_name: bool, comptime help_: ?[]const u8) []const u8 { + comptime { + const suffix = if (append_name) name ++ " " else ""; + const type_line = std.fmt.comptimePrint("# TYPE {s} {s}\n{s}", .{ name, @tagName(tpe), suffix }); + const help = help_ orelse return type_line; + + // Help text requires \\ and \n to be escaped. Let's count how many of those we have. + var escape_count: usize = 0; + for (help) |c| { + if (c == '\\' or c == '\n') { + escape_count += 1; + } + } + + var h = help; + if (escape_count > 0) { + // We need to escape at least one special character. 
We need to allocate + // a new string to hold the escaped value + + // Since we know the original length and the # of characters that need to + // be escaped, we know the final length) + var escaped: [help.len + escape_count]u8 = undefined; + + var pos: usize = 0; + for (help) |c| { + switch (c) { + '\\' => { + escaped[pos] = '\\'; + pos += 1; + escaped[pos] = '\\'; + }, + '\n' => { + escaped[pos] = '\\'; + pos += 1; + escaped[pos] = 'n'; + }, + else => escaped[pos] = c, + } + pos += 1; + } + + h = &escaped; + } + + return std.fmt.comptimePrint("# HELP {s} {s}\n{s}", .{ name, h, type_line }); + } +} + +const t = @import("t.zig"); +test "preamble: no help" { + const p = comptime preamble("metric_test_1", .counter, true, null); + try t.expectString("# TYPE metric_test_1 counter\nmetric_test_1 ", p); +} + +test "preamble: no help, histogram" { + // histogram doesn't include the metric name it he preamble + const p = comptime preamble("metric_test_1", .histogram, false, null); + try t.expectString("# TYPE metric_test_1 histogram\n", p); +} + +test "preamble: simple help" { + const p = comptime preamble("metric_test_2", .gauge, true, "this is a valid help line"); + try t.expectString("# HELP metric_test_2 this is a valid help line\n# TYPE metric_test_2 gauge\nmetric_test_2 ", p); +} + +test "preamble: escape help" { + const p = comptime preamble("metric_test_3", .histogram, false, "th\\is is a\nvalid help line"); + try t.expectString("# HELP metric_test_3 th\\\\is is a\\nvalid help line\n# TYPE metric_test_3 histogram\n", p); +} + +test "MetricVec: labels" { + const m = try MetricVec(struct { + active: bool, + name: []const u8, + }).init("metric_vec_test_1"); + try t.expectSlice([]const u8, &.{ "active", "name" }, &m.labels); +} + +test "MetricVec: dupe/free" { + const base = t.AllLabels{ + .id = -320, + .key = 4199, + .active = true, + .err = error.OutOfMemory, + .state = .start, + .tag = "teg", + }; + const d = try MetricVec(t.AllLabels).dupe(t.allocator, .{ + .id = 
-320, + .key = 4199, + .active = true, + .err = error.OutOfMemory, + .state = .start, + .tag = "teg", + }); + defer MetricVec(t.AllLabels).free(t.allocator, d); + + try t.expectEqual(-320, d.id); + try t.expectEqual(4199, d.key); + try t.expectEqual(true, d.active); + try t.expectEqual(error.OutOfMemory, d.err); + try t.expectEqual(.start, d.state); + try t.expectString("teg", d.tag); + try t.expectEqual(false, d.tag.ptr == base.tag.ptr); +} + +test "MetricVec: buildAttributes" { + { + const l = try MetricVec(t.AllLabels).buildAttributes(t.allocator, .{ + .id = -320, + .key = 4199, + .active = true, + .err = error.OutOfMemory, + .state = .start, + .tag = "teg", + }); + defer t.allocator.free(l); + try t.expectString("{id=\"-320\",key=\"4199\",active=\"true\",err=\"OutOfMemory\",state=\"start\",tag=\"teg\"} ", l); + } + + { + // Escape string + const l = try MetricVec(struct { n: []const u8 }).buildAttributes(t.allocator, .{ + .n = "hello\nworld, how's\\it \"going\"", + }); + defer t.allocator.free(l); + try t.expectString("{n=\"hello\\nworld, how's\\\\it \\\"going\\\"\"} ", l); + } +} + +test "HashContext" { + var h = MetricVec(t.AllLabels).HashMap(i32){}; + defer h.deinit(t.allocator); + + const k1a = t.AllLabels{ + .id = -320, + .key = 4199, + .active = true, + .err = error.OutOfMemory, + .state = .start, + .tag = "teg", + }; + + // same as k1b, but our string (tag) is a different ptr + const k1b = t.AllLabels{ + .id = -320, + .key = 4199, + .active = true, + .err = error.OutOfMemory, + .state = .start, + .tag = try t.allocator.dupe(u8, "teg"), + }; + defer t.allocator.free(k1b.tag); + + { + const gop = try h.getOrPut(t.allocator, k1a); + try t.expectEqual(false, gop.found_existing); + gop.value_ptr.* = 1; + } + + { + const gop = try h.getOrPut(t.allocator, k1a); + try t.expectEqual(true, gop.found_existing); + try t.expectEqual(1, gop.value_ptr.*); + gop.value_ptr.* = 2; + } + + { + const gop = try h.getOrPut(t.allocator, k1b); + try t.expectEqual(true, 
gop.found_existing); + try t.expectEqual(2, gop.value_ptr.*); + gop.value_ptr.* = 3; + } + + { + // different int + var k = k1a; + k.id = 320; + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(false, gop.found_existing); + gop.value_ptr.* = 3; + } + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(true, gop.found_existing); + try t.expectEqual(3, gop.value_ptr.*); + } + } + + { + // different u13 + var k = k1a; + k.key += 1; + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(false, gop.found_existing); + gop.value_ptr.* = 4; + } + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(true, gop.found_existing); + try t.expectEqual(4, gop.value_ptr.*); + } + } + + { + // different bool + var k = k1a; + k.active = !k.active; + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(false, gop.found_existing); + gop.value_ptr.* = 5; + } + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(true, gop.found_existing); + try t.expectEqual(5, gop.value_ptr.*); + } + } + + { + // different err + var k = k1a; + k.err = error.Other; + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(false, gop.found_existing); + gop.value_ptr.* = 6; + } + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(true, gop.found_existing); + try t.expectEqual(6, gop.value_ptr.*); + } + } + + { + // different enum + var k = k1a; + k.state = .end; + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(false, gop.found_existing); + gop.value_ptr.* = 7; + } + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(true, gop.found_existing); + try t.expectEqual(7, gop.value_ptr.*); + } + } + + { + // different string + var k = k1a; + k.tag = "tag"; + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(false, gop.found_existing); + gop.value_ptr.* = 8; + } + { + const gop = try h.getOrPut(t.allocator, k); + try t.expectEqual(true, 
gop.found_existing); + try t.expectEqual(8, gop.value_ptr.*); + } + } +} + +test "numberOfDigits" { + try t.expectEqual(1, numberOfDigits(@as(i32, 0))); + try t.expectEqual(1, numberOfDigits(@as(u9, 1))); + try t.expectEqual(2, numberOfDigits(@as(i16, -1))); + try t.expectEqual(10, numberOfDigits(@as(u33, 1234567890))); + try t.expectEqual(19, numberOfDigits(@as(usize, 9223372036854775807))); + try t.expectEqual(20, numberOfDigits(@as(i64, -9223372036854775808))); +} diff --git a/zig/pg-deps-metrics/src/metrics.zig b/zig/pg-deps-metrics/src/metrics.zig new file mode 100644 index 0000000..1a341dc --- /dev/null +++ b/zig/pg-deps-metrics/src/metrics.zig @@ -0,0 +1,156 @@ +const std = @import("std"); + +const counter = @import("counter.zig"); +pub const Counter = counter.Counter; +pub const CounterVec = counter.CounterVec; + +const gauge = @import("gauge.zig"); +pub const Gauge = gauge.Gauge; +pub const GaugeVec = gauge.GaugeVec; + +const histogram = @import("histogram.zig"); +pub const Histogram = histogram.Histogram; +pub const HistogramVec = histogram.HistogramVec; + +pub const RegistryOpts = @import("registry.zig").Opts; + +// This allows a library developer to safely use a library-wide metrics +// instance by defaulting all metrics to "noop" variants. Library developers +// can then expose a function to the main application, say: +// try thelib.initializeMetrics(allocator) +// which then initializes the metrics instance to real implementations. +// Whether or not the application calls in initializeMetrics, the library +// can safely use the metrics instance, sine this function initialized it to +// noop. 
+pub fn initializeNoop(comptime T: type) T { + switch (@typeInfo(T)) { + .@"struct" => |struct_info| { + var m: T = undefined; + inline for (struct_info.fields) |field| { + switch (@typeInfo(field.type)) { + .@"union" => @field(m, field.name) = .{ .noop = {} }, + else => { + if (field.default_value_ptr) |default_value_ptr| { + const default_value = @as(*align(1) const field.type, @ptrCast(default_value_ptr)).*; + @field(m, field.name) = default_value; + } + }, + } + } + return m; + }, + else => @compileError("initializeNoop expects a struct"), + } +} + +pub fn write(metrics: anytype, writer: *std.io.Writer) !void { + const S = @typeInfo(@TypeOf(metrics)).pointer.child; + const fields = @typeInfo(S).@"struct".fields; + + inline for (fields) |f| { + switch (@typeInfo(f.type)) { + .@"union" => try @constCast(&@field(metrics, f.name)).write(writer), + else => {}, + } + } +} + +test { + std.testing.refAllDecls(@This()); +} + +const t = @import("t.zig"); +test "initializeNoop + write" { + const x = initializeNoop(struct { + status: u16 = 33, + hits: CounterVec(u32, struct { status: u16 }), + active: Gauge(u64), + latency: Histogram(u32, &.{ 0, 2 }), + }); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + try write(&x, &writer.writer); + const buf = writer.writer.buffered(); + try t.expectEqual(0, buf.len); +} + +test "metrics: write" { + const M = struct { + hits: Hits, + active: Gauge(u64), + timing: Timing, + + const Hits = CounterVec(u32, struct { status: u16 }); + const Timing = HistogramVec(u32, struct { path: []const u8 }, &.{ 5, 10, 25, 50, 100, 250, 500, 1000 }); + }; + + var m = M{ .active = Gauge(u64).init("active", .{}, .{}), .hits = try M.Hits.init(t.allocator, "hits", .{}, .{}), .timing = try M.Timing.init(t.allocator, "timing", .{ .help = "the timing" }, .{ .prefix = "x_" }) }; + defer m.hits.deinit(); + defer m.timing.deinit(); + + var writer: std.io.Writer.Allocating = .init(t.allocator); + defer writer.deinit(); + 
+ m.active.set(919); + try m.hits.incr(.{ .status = 199 }); + + { + try write(&m, &writer.writer); + const buf = writer.writer.buffered(); + try t.expectString( + \\# TYPE hits counter + \\hits{status="199"} 1 + \\# TYPE active gauge + \\active 919 + \\# HELP x_timing the timing + \\# TYPE x_timing histogram + \\ + , buf); + } + + m.active.set(32); + try m.hits.incr(.{ .status = 199 }); + try m.hits.incr(.{ .status = 3 }); + try m.timing.observe(.{ .path = "/a" }, 2); + try m.timing.observe(.{ .path = "/a" }, 8); + try m.timing.observe(.{ .path = "/b" }, 7); + + { + writer.clearRetainingCapacity(); + try write(&m, &writer.writer); + const buf = writer.writer.buffered(); + try t.expectString( + \\# TYPE hits counter + \\hits{status="3"} 1 + \\hits{status="199"} 2 + \\# TYPE active gauge + \\active 32 + \\# HELP x_timing the timing + \\# TYPE x_timing histogram + \\x_timing_bucket{le="5",path="/b"} 0 + \\x_timing_bucket{le="10",path="/b"} 1 + \\x_timing_bucket{le="25",path="/b"} 1 + \\x_timing_bucket{le="50",path="/b"} 1 + \\x_timing_bucket{le="100",path="/b"} 1 + \\x_timing_bucket{le="250",path="/b"} 1 + \\x_timing_bucket{le="500",path="/b"} 1 + \\x_timing_bucket{le="1000",path="/b"} 1 + \\x_timing_bucket{le="+Inf",path="/b"} 1 + \\x_timing_sum{path="/b"} 7 + \\x_timing_count{path="/b"} 1 + \\x_timing_bucket{le="5",path="/a"} 1 + \\x_timing_bucket{le="10",path="/a"} 2 + \\x_timing_bucket{le="25",path="/a"} 2 + \\x_timing_bucket{le="50",path="/a"} 2 + \\x_timing_bucket{le="100",path="/a"} 2 + \\x_timing_bucket{le="250",path="/a"} 2 + \\x_timing_bucket{le="500",path="/a"} 2 + \\x_timing_bucket{le="1000",path="/a"} 2 + \\x_timing_bucket{le="+Inf",path="/a"} 2 + \\x_timing_sum{path="/a"} 10 + \\x_timing_count{path="/a"} 2 + \\ + , buf); + } +} diff --git a/zig/pg-deps-metrics/src/registry.zig b/zig/pg-deps-metrics/src/registry.zig new file mode 100644 index 0000000..023c58b --- /dev/null +++ b/zig/pg-deps-metrics/src/registry.zig @@ -0,0 +1,27 @@ +const std = 
@import("std"); + +// Not sure what I want to do about a "registry" +// But I think I want to wait until comptime allocation is available + +pub const Opts = struct { + prefix: []const u8 = "", + exclude: ?[]const []const u8 = null, + + pub fn shouldExclude(self: Opts, name: []const u8) bool { + const excludes = self.exclude orelse return false; + for (excludes) |exclude| { + if (std.mem.eql(u8, exclude, name)) { + return true; + } + } + return false; + } +}; + +const t = @import("t.zig"); +test "Registry.Opts: shouldExclude" { + try t.expectEqual(false, (Opts{}).shouldExclude("abc")); + try t.expectEqual(false, (Opts{ .exclude = &.{ "ABC", "other" } }).shouldExclude("abc")); + try t.expectEqual(true, (Opts{ .exclude = &.{ "abc", "other" } }).shouldExclude("abc")); + try t.expectEqual(true, (Opts{ .exclude = &.{ "a", "otaher", "abc" } }).shouldExclude("abc")); +} diff --git a/zig/pg-deps-metrics/src/t.zig b/zig/pg-deps-metrics/src/t.zig new file mode 100644 index 0000000..73adbf7 --- /dev/null +++ b/zig/pg-deps-metrics/src/t.zig @@ -0,0 +1,24 @@ +const std = @import("std"); + +pub const allocator = std.testing.allocator; + +pub const expectEqual = std.testing.expectEqual; +pub const expectFmt = std.testing.expectFmt; +pub const expectError = std.testing.expectError; +pub const expectSlice = std.testing.expectEqualSlices; +pub const expectString = std.testing.expectEqualStrings; + +// a structure with all the label types we support +pub const AllLabels = struct { + id: i32, + key: u13, + active: bool, + err: anyerror, + state: State, + tag: []const u8, + + const State = enum { + start, + end, + }; +}; diff --git a/zig/pg-deps-metrics/test_runner.zig b/zig/pg-deps-metrics/test_runner.zig new file mode 100644 index 0000000..00f457f --- /dev/null +++ b/zig/pg-deps-metrics/test_runner.zig @@ -0,0 +1,294 @@ +// in your build.zig, you can specify a custom test runner: +// const tests = b.addTest(.{ +// .root_module = $MODULE_BEING_TESTED, +// .test_runner = .{ .path = 
b.path("test_runner.zig"), .mode = .simple }, +// }); + +const std = @import("std"); +const builtin = @import("builtin"); + +const Allocator = std.mem.Allocator; + +const BORDER = "=" ** 80; + +// use in custom panic handler +var current_test: ?[]const u8 = null; + +pub fn main() !void { + var mem: [8192]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&mem); + + const allocator = fba.allocator(); + + const env = Env.init(allocator); + defer env.deinit(allocator); + + var slowest = SlowTracker.init(allocator, 5); + defer slowest.deinit(); + + var pass: usize = 0; + var fail: usize = 0; + var skip: usize = 0; + var leak: usize = 0; + + Printer.fmt("\r\x1b[0K", .{}); // beginning of line and clear to end of line + + for (builtin.test_functions) |t| { + if (isSetup(t)) { + t.func() catch |err| { + Printer.status(.fail, "\nsetup \"{s}\" failed: {}\n", .{ t.name, err }); + return err; + }; + } + } + + for (builtin.test_functions) |t| { + if (isSetup(t) or isTeardown(t)) { + continue; + } + + var status = Status.pass; + slowest.startTiming(); + + const is_unnamed_test = isUnnamed(t); + if (env.filter) |f| { + if (!is_unnamed_test and std.mem.indexOf(u8, t.name, f) == null) { + continue; + } + } + + const friendly_name = blk: { + const name = t.name; + var it = std.mem.splitScalar(u8, name, '.'); + while (it.next()) |value| { + if (std.mem.eql(u8, value, "test")) { + const rest = it.rest(); + break :blk if (rest.len > 0) rest else name; + } + } + break :blk name; + }; + + current_test = friendly_name; + std.testing.allocator_instance = .{}; + const result = t.func(); + current_test = null; + + const ns_taken = slowest.endTiming(friendly_name); + + if (std.testing.allocator_instance.deinit() == .leak) { + leak += 1; + Printer.status(.fail, "\n{s}\n\"{s}\" - Memory Leak\n{s}\n", .{ BORDER, friendly_name, BORDER }); + } + + if (result) |_| { + pass += 1; + } else |err| switch (err) { + error.SkipZigTest => { + skip += 1; + status = .skip; + }, + else => { + 
status = .fail; + fail += 1; + Printer.status(.fail, "\n{s}\n\"{s}\" - {s}\n{s}\n", .{ BORDER, friendly_name, @errorName(err), BORDER }); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + if (env.fail_first) { + break; + } + }, + } + + if (env.verbose) { + const ms = @as(f64, @floatFromInt(ns_taken)) / 1_000_000.0; + Printer.status(status, "{s} ({d:.2}ms)\n", .{ friendly_name, ms }); + } else { + Printer.status(status, ".", .{}); + } + } + + for (builtin.test_functions) |t| { + if (isTeardown(t)) { + t.func() catch |err| { + Printer.status(.fail, "\nteardown \"{s}\" failed: {}\n", .{ t.name, err }); + return err; + }; + } + } + + const total_tests = pass + fail; + const status = if (fail == 0) Status.pass else Status.fail; + Printer.status(status, "\n{d} of {d} test{s} passed\n", .{ pass, total_tests, if (total_tests != 1) "s" else "" }); + if (skip > 0) { + Printer.status(.skip, "{d} test{s} skipped\n", .{ skip, if (skip != 1) "s" else "" }); + } + if (leak > 0) { + Printer.status(.fail, "{d} test{s} leaked\n", .{ leak, if (leak != 1) "s" else "" }); + } + Printer.fmt("\n", .{}); + try slowest.display(); + Printer.fmt("\n", .{}); + std.posix.exit(if (fail == 0) 0 else 1); +} + +const Printer = struct { + fn fmt(comptime format: []const u8, args: anytype) void { + std.debug.print(format, args); + } + + fn status(s: Status, comptime format: []const u8, args: anytype) void { + switch (s) { + .pass => std.debug.print("\x1b[32m", .{}), + .fail => std.debug.print("\x1b[31m", .{}), + .skip => std.debug.print("\x1b[33m", .{}), + else => {}, + } + std.debug.print(format ++ "\x1b[0m", args); + } +}; + +const Status = enum { + pass, + fail, + skip, + text, +}; + +const SlowTracker = struct { + const SlowestQueue = std.PriorityDequeue(TestInfo, void, compareTiming); + max: usize, + slowest: SlowestQueue, + timer: std.time.Timer, + + fn init(allocator: Allocator, count: u32) SlowTracker { + const timer = std.time.Timer.start() catch @panic("failed 
to start timer"); + var slowest = SlowestQueue.init(allocator, {}); + slowest.ensureTotalCapacity(count) catch @panic("OOM"); + return .{ + .max = count, + .timer = timer, + .slowest = slowest, + }; + } + + const TestInfo = struct { + ns: u64, + name: []const u8, + }; + + fn deinit(self: SlowTracker) void { + self.slowest.deinit(); + } + + fn startTiming(self: *SlowTracker) void { + self.timer.reset(); + } + + fn endTiming(self: *SlowTracker, test_name: []const u8) u64 { + var timer = self.timer; + const ns = timer.lap(); + + var slowest = &self.slowest; + + if (slowest.count() < self.max) { + // Capacity is fixed to the # of slow tests we want to track + // If we've tracked fewer tests than this capacity, than always add + slowest.add(TestInfo{ .ns = ns, .name = test_name }) catch @panic("failed to track test timing"); + return ns; + } + + { + // Optimization to avoid shifting the dequeue for the common case + // where the test isn't one of our slowest. + const fastest_of_the_slow = slowest.peekMin() orelse unreachable; + if (fastest_of_the_slow.ns > ns) { + // the test was faster than our fastest slow test, don't add + return ns; + } + } + + // the previous fastest of our slow tests, has been pushed off. 
+ _ = slowest.removeMin(); + slowest.add(TestInfo{ .ns = ns, .name = test_name }) catch @panic("failed to track test timing"); + return ns; + } + + fn display(self: *SlowTracker) !void { + var slowest = self.slowest; + const count = slowest.count(); + Printer.fmt("Slowest {d} test{s}: \n", .{ count, if (count != 1) "s" else "" }); + while (slowest.removeMinOrNull()) |info| { + const ms = @as(f64, @floatFromInt(info.ns)) / 1_000_000.0; + Printer.fmt(" {d:.2}ms\t{s}\n", .{ ms, info.name }); + } + } + + fn compareTiming(context: void, a: TestInfo, b: TestInfo) std.math.Order { + _ = context; + return std.math.order(a.ns, b.ns); + } +}; + +const Env = struct { + verbose: bool, + fail_first: bool, + filter: ?[]const u8, + + fn init(allocator: Allocator) Env { + return .{ + .verbose = readEnvBool(allocator, "TEST_VERBOSE", true), + .fail_first = readEnvBool(allocator, "TEST_FAIL_FIRST", false), + .filter = readEnv(allocator, "TEST_FILTER"), + }; + } + + fn deinit(self: Env, allocator: Allocator) void { + if (self.filter) |f| { + allocator.free(f); + } + } + + fn readEnv(allocator: Allocator, key: []const u8) ?[]const u8 { + const v = std.process.getEnvVarOwned(allocator, key) catch |err| { + if (err == error.EnvironmentVariableNotFound) { + return null; + } + std.log.warn("failed to get env var {s} due to err {}", .{ key, err }); + return null; + }; + return v; + } + + fn readEnvBool(allocator: Allocator, key: []const u8, deflt: bool) bool { + const value = readEnv(allocator, key) orelse return deflt; + defer allocator.free(value); + return std.ascii.eqlIgnoreCase(value, "true"); + } +}; + +pub const panic = std.debug.FullPanic(struct { + pub fn panicFn(msg: []const u8, first_trace_addr: ?usize) noreturn { + if (current_test) |ct| { + std.debug.print("\x1b[31m{s}\npanic running \"{s}\"\n{s}\x1b[0m\n", .{ BORDER, ct, BORDER }); + } + std.debug.defaultPanic(msg, first_trace_addr); + } +}.panicFn); + +fn isUnnamed(t: std.builtin.TestFn) bool { + const marker = ".test_"; + 
const test_name = t.name; + const index = std.mem.indexOf(u8, test_name, marker) orelse return false; + _ = std.fmt.parseInt(u32, test_name[index + marker.len ..], 10) catch return false; + return true; +} + +fn isSetup(t: std.builtin.TestFn) bool { + return std.mem.endsWith(u8, t.name, "tests:beforeAll"); +} + +fn isTeardown(t: std.builtin.TestFn) bool { + return std.mem.endsWith(u8, t.name, "tests:afterAll"); +} diff --git a/zig/pg/.gitignore b/zig/pg/.gitignore new file mode 100644 index 0000000..9499047 --- /dev/null +++ b/zig/pg/.gitignore @@ -0,0 +1,4 @@ +zig-out/ +.zig-cache/ +tests/server.key +tests/server.crt diff --git a/zig/pg/LICENSE b/zig/pg/LICENSE new file mode 100644 index 0000000..011dfad --- /dev/null +++ b/zig/pg/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2024 Karl Seguin. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
diff --git a/zig/pg/Makefile b/zig/pg/Makefile new file mode 100644 index 0000000..ed60dd1 --- /dev/null +++ b/zig/pg/Makefile @@ -0,0 +1,20 @@ +F= + +.PHONY: t +t: + TEST_FILTER="${F}" zig build test --summary all -freference-trace + +.PHONY: d +d: + cd tests && docker compose up + +.PHONY: ssl +ssl: + openssl req -days 3650 -new -text -nodes -subj '/C=SG/ST=SG/L=SG/O=Personal/OU=Personal/CN=localhost' -keyout tests/server.key -out tests/server.csr + openssl req -days 3650 -x509 -text -in tests/server.csr -key tests/server.key -out tests/server.crt + rm tests/server.csr + cp tests/server.crt tests/root.crt + + openssl req -days 3650 -new -nodes -subj '/C=SG/ST=SG/L=SG/O=Personal/OU=Personal/CN=localhost/CN=testclient1' -keyout tests/client.key -out tests/client.csr + openssl x509 -days 3650 -req -CAcreateserial -in tests/client.csr -CA tests/root.crt -CAkey tests/server.key -out tests/client.crt + rm tests/client.csr diff --git a/zig/pg/build.zig b/zig/pg/build.zig new file mode 100644 index 0000000..3994aa3 --- /dev/null +++ b/zig/pg/build.zig @@ -0,0 +1,98 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) !void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + // setup our dependencies + const dep_opts = .{ .target = target, .optimize = optimize }; + + // Expose this as a module that others can import + const pg_module = b.addModule("pg", .{ + .target = target, + .optimize = optimize, + .root_source_file = b.path("src/pg.zig"), + .imports = &.{ + .{ .name = "buffer", .module = b.dependency("buffer", dep_opts).module("buffer") }, + .{ .name = "metrics", .module = b.dependency("metrics", dep_opts).module("metrics") }, + }, + }); + + var openssl = false; + const openssl_lib_name = b.option([]const u8, "openssl_lib_name", ""); + const openssl_lib_path = b.option(std.Build.LazyPath, "openssl_lib_path", ""); + const openssl_include_path = b.option(std.Build.LazyPath, "openssl_include_path", ""); + + if 
(openssl_include_path) |p| { + openssl = true; + pg_module.addIncludePath(p); + } + if (openssl_lib_path) |p| { + openssl = true; + pg_module.addLibraryPath(p); + } + if (openssl_lib_name != null) { + openssl = true; + } + + if (openssl) { + pg_module.linkSystemLibrary("crypto", .{}); + pg_module.linkSystemLibrary(openssl_lib_name orelse "ssl", .{}); + pg_module.link_libc = true; + } + + var column_names = false; + const column_names_opt = b.option(bool, "column_names", ""); + + if (column_names_opt) |val| { + column_names = val; + } + + // -Diouring=true switches the transport to a per-connection + // io_uring ring on Linux. No-op on other platforms. + const iouring = b.option(bool, "iouring", "Use io_uring transport on Linux") orelse false; + + { + const options = b.addOptions(); + options.addOption(bool, "openssl", openssl); + options.addOption(bool, "column_names", column_names); + options.addOption(bool, "iouring", iouring); + pg_module.addOptions("config", options); + } + + { + // test step + const lib_test = b.addTest(.{ + .root_module = b.createModule(.{ + .target = target, + .optimize = optimize, + .root_source_file = b.path("src/pg.zig"), + .imports = &.{ + .{ .name = "buffer", .module = b.dependency("buffer", dep_opts).module("buffer") }, + .{ .name = "metrics", .module = b.dependency("metrics", dep_opts).module("metrics") }, + }, + }), + .test_runner = .{ .path = b.path("test_runner.zig"), .mode = .simple }, + }); + if (openssl_lib_path) |p| + lib_test.root_module.addLibraryPath(p); + if (openssl_include_path) |p| + lib_test.root_module.addIncludePath(p); + lib_test.root_module.linkSystemLibrary("crypto", .{}); + lib_test.root_module.linkSystemLibrary("ssl", .{}); + + { + const options = b.addOptions(); + options.addOption(bool, "openssl", true); + options.addOption(bool, "column_names", false); + options.addOption(bool, "iouring", false); + lib_test.root_module.addOptions("config", options); + } + + const run_test = b.addRunArtifact(lib_test); + 
run_test.has_side_effects = true; + + const test_step = b.step("test", "Run unit tests"); + test_step.dependOn(&run_test.step); + } +} diff --git a/zig/pg/build.zig.zon b/zig/pg/build.zig.zon new file mode 100644 index 0000000..e035955 --- /dev/null +++ b/zig/pg/build.zig.zon @@ -0,0 +1,17 @@ +.{ + .name = .pg, + .paths = .{""}, + .version = "0.0.0", + .fingerprint = 0xbd309ff281fb9f5a, + .dependencies = .{ + .buffer = .{ + // Vendored copy (identical to karlseguin/buffer.zig @ 30f9512f). + .path = "../pg-deps-buffer", + }, + .metrics = .{ + // Vendored copy of karlseguin/metrics.zig @ 13d8706e with a small + // Zig 0.16 compat fix: @Type(.{.int=...}) -> @Int(bits, signedness). + .path = "../pg-deps-metrics", + }, + }, +} diff --git a/zig/pg/example/build.zig b/zig/pg/example/build.zig new file mode 100644 index 0000000..c7f234c --- /dev/null +++ b/zig/pg/example/build.zig @@ -0,0 +1,27 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + const pg_module = b.dependency("pg", .{}).module("pg"); + + const exe = b.addExecutable(.{ + .name = "example", + .root_module = b.createModule(.{ + .root_source_file = b.path("main.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "pg", .module = pg_module }, + }, + }), + }); + + b.installArtifact(exe); + + const run_cmd = b.addRunArtifact(exe); + run_cmd.step.dependOn(b.getInstallStep()); + const run_step = b.step("run", "Run the app"); + run_step.dependOn(&run_cmd.step); +} diff --git a/zig/pg/example/build.zig.zon b/zig/pg/example/build.zig.zon new file mode 100644 index 0000000..68a2241 --- /dev/null +++ b/zig/pg/example/build.zig.zon @@ -0,0 +1,11 @@ +.{ + .name = .example, + .paths = .{""}, + .version = "0.0.0", + .fingerprint = 0x6eec9b9fcafa54bf, + .dependencies = .{ + .pg = .{ + .path = ".." 
+ }, + }, +} diff --git a/zig/pg/example/main.zig b/zig/pg/example/main.zig new file mode 100644 index 0000000..7f0f501 --- /dev/null +++ b/zig/pg/example/main.zig @@ -0,0 +1,271 @@ +const std = @import("std"); +const builtin = @import("builtin"); + +const pg = @import("pg"); + +pub const log = std.log.scoped(.example); + +pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = if (builtin.mode == .Debug) gpa.allocator() else std.heap.c_allocator; + + // While a connection can be created directly, pools should be used in most + // cases. The pool's `acquire` method, to get a connection is thread-safe. + // The pool may start 1 background thread to reconnect disconnected + // connections (or connections in an invalid state). + var pool = pg.Pool.init(allocator, .{ + .size = 5, + .connect = .{ + .port = 5432, + .host = "127.0.0.1", + }, + .auth = .{ + .username = "postgres", + .database = "postgres", + .timeout = 10_000, + } + }) catch |err| { + log.err("Failed to connect: {}", .{err}); + std.posix.exit(1); + }; + defer pool.deinit(); + + + + // One-off commands can be executed directly using the pool using the + // exec, execOpts, query, queryOpts, row, rowOpts functions. But, due to + // Zig's lack of error payloads, if these fail, you won't be able to retrieve + // a more detailed error + _ = try pool.exec("drop table if exists pg_example_users", .{}); + + + // We're using a block to scope the defer conn.release(). In your own code + // the scope might naturally be a function or if block. Remember that Zig's + // defer thankfully executes at the end of the block (unlike Go where defer + // executes at the end of the function). + { + // You can acquire/release connections from the pool. Ideal if you want + // to execute multiple statements and also exposes a more detailed error. 
+ var conn = try pool.acquire(); + defer conn.release(); + + // exec returns the # of rows affected for insert/select/delete + _ = conn.exec("create table pg_example_users (id integer, name text)", .{}) catch |err| { + if (conn.err) |pg_err| { + // conn.err is an optional PostgreSQL error. It has many fields, + // many of which are nullable, but the `message`, `code` and + // `severity` are always present. + log.err("create failure: {s}", .{pg_err.message}); + } + return err; + }; + } + + + // Of course, exec can take parameters: + _ = try pool.exec( + "insert into pg_example_users (id, name) values ($1, $2), ($3, $4)", + .{1, "Leto", 2, "Ghanima"} + ); + + + { + log.info("Example 1", .{}); + // we can fetch a single row: + var conn = try pool.acquire(); + defer conn.release(); + + var row = (try conn.row("select name from pg_example_users where id = $1", .{1})) orelse unreachable; + // having to deal with row.deinit() is an unfortunate consequence of the + // conn.row and pool.row API. Sorry! + defer row.deinit() catch {}; + + // name will become invalid after row.deinit() is called. dupe it if you + // need it to live longer. + const name = try row.get([]const u8, 0); + log.info("User 1: {s}", .{name}); + } + + { + log.info("\n\nExample 2", .{}); + // or we can fetch multiple rows: + var conn = try pool.acquire(); + defer conn.release(); + + var result = try conn.query("select * from pg_example_users order by id", .{}); + defer result.deinit(); + + while (try result.next()) |row| { + const id = try row.get(i32, 0); + // string values are only valid until the next call to next() + // dupe the value if needed + const name = try row.get([]const u8, 1); + log.info("User {d}: {s}", .{id, name}); + } + } + + { + log.info("\n\nExample 3", .{}); + // pgz uses a configurable read and write buffer to communicate with Postgresql. + // Larger messages require dynamic allocation. By default this uses the allocator + // given to `Pool.init` or `Conn.open`. 
A different allocator can be specified + // on a per-query basis. For example, if you're using an HTTP framework that + // gives you a per-request arena, you could use that arena + var conn = try pool.acquire(); + defer conn.release(); + + // Because we pass this allocator to queryOpts, *IF* pg needs to allocate + // to read the response, it'll use this allocator. + var arena = std.heap.ArenaAllocator.init(allocator); + defer arena.deinit(); + + // use queryOpts, execOpts or rowOpts when specifying optional parameters + var result = try conn.queryOpts("select * from pg_example_users order by id", .{}, .{.allocator = arena.allocator()}); + defer result.deinit(); + + while (try result.next()) |row| { + const id = try row.get(i32, 0); + // string values are only valid until the next call to next() + // dupe the value if needed + const name = try row.get([]const u8, 1); + log.info("User {d}: {s}", .{id, name}); + } + } + + { + log.info("\n\nExample 4", .{}); + // We can bind and fetch arrays. This simple statement showcases both: + var row = (try pool.row("select $1::bool[]", .{[_]bool{true, false, false}})) orelse unreachable; + + // again, sorry that row.deinit() can error. 
+ defer row.deinit() catch {}; + + var it = try row.get(pg.Iterator(bool), 0); + while (it.next()) |value| { + log.info("{any}", .{value}); + } + + // There's a `alloc` helper on the iterator too, to turn it into a slice + // const values = try it.alloc(allocator); + // defer allocator.free(values); + } + + { + log.info("\n\nExample 5", .{}); + // Instead of `row.get`, you can use `row.getCol` to get a column by name + // But to work, you must tell pgz to load the column_names + // exec, query and row all have variants that take an option: + // execOpts, queryOpts and rowOpts + var row = (try pool.rowOpts("select $1 as name", .{"teg"}, .{.column_names = true})) orelse unreachable; + defer row.deinit() catch {}; + + log.info("{s}", .{try row.getCol([]const u8, "name")}); + } + + { + log.info("\n\nExample 6", .{}); + // There's a cost to looking up a value by name. If you're going to do + // it in a loop, consider storing the column_index in a variable: + + var conn = try pool.acquire(); + defer conn.release(); + + // again, we have to tell pg to load the column names + var result = try conn.queryOpts( + \\ select $1 as id, now() as time + \\ union all + \\ select $2, now() + interval '1 hour' + , .{"25ed0ed1-a35b-41a0-a6bd-89ddb3b8b716", "e2242aa2-db4e-4dd5-8677-76bc19b9e0f5"}, .{.column_names = true}, + ); + defer result.deinit(); + + const id_index = result.columnIndex("id").?; + const time_index = result.columnIndex("time").?; + while (try result.next()) |row| { + const id = try row.get([]const u8, id_index); + const unix_micro = try row.get(i64, time_index); + log.info("{s} {d}", .{id, unix_micro}); + } + } + + { + // There is basic functionality for turning a row into a struct. + const User = struct { + name: []const u8, + power: i32, + }; + + { + log.info("\n\nExample 7", .{}); + // There are many things to be aware of. By default, the field order + // and column order is used. 
By default, any non-primitive value is only + // valid until the next call to `next()` or `deinit()`. + var row = (try pool.row("select 'Goku', 9001", .{})) orelse unreachable; + defer row.deinit() catch {}; + + const user = try row.to(User, .{}); + log.info("{s} {d}", .{user.name, user.power}); + } + + { + log.info("\n\nExample 8", .{}); + // We can match by name instead, and provide an allocator to dupe + // any string/iterator. This allows values to live beyond the next + // call to next/deinit, but we must free the values. Consider using + // an arena to more easily manage allocated memory. + var arena = std.heap.ArenaAllocator.init(allocator); + defer arena.deinit(); + + var row = (try pool.rowOpts("select 4000 as power, 'Vegeta' as name", .{}, .{.column_names = true})) orelse unreachable; + defer row.deinit() catch {}; + + const user = try row.to(User, .{.map = .name, .allocator = arena.allocator()}); + log.info("{s} {d}", .{user.name, user.power}); + } + + { + log.info("\n\nExample 9", .{}); + // As always, there's overhead to doing name lookups. If you're going + // to map many rows into structures, consider creating a mapper. + // Also, here rather than providing our own allocator, we're telling + // pg to use its own result-scoped arena. This has pros and cons. + // The pro is that it's simple and values will exist beyond each + // call to next() and will be cleaned up by deinit(). + // The con is that a lot of memory might accumulate in the arena.
+ + var result = try pool.queryOpts( + \\ select 4000 as power, 'Vegeta' as name + \\ union all + \\ select 9001, 'Goku' + , .{}, .{.column_names = true}); + defer result.deinit(); + + // dupe = true tells the mapper to dupe values using the + // internal result arena + var mapper = result.mapper(User, .{.dupe = true}); + while (try mapper.next()) |user| { + log.info("{s} {d}", .{user.name, user.power}); + } + } + + { + log.info("\n\nExample 10", .{}); + // unsafe operations avoid some runtime checks, but can reach + // "unreachable" if the types are wrong. + var conn = try pool.acquire(); + defer conn.release(); + + var result = try conn.query("select * from pg_example_users order by id", .{}); + defer result.deinit(); + + // notice the `nextUnsafe` + while (try result.nextUnsafe()) |row| { + // unlike the "safe" variant, there's no "try" here. This assumes + // you're 100% sure column 0 is an i32 and column 1 is a string + const id = row.get(i32, 0); + const name = row.get([]const u8, 1); + log.info("User {d}: {s}", .{id, name}); + } + } + } +} diff --git a/zig/pg/readme.md b/zig/pg/readme.md new file mode 100644 index 0000000..1575e71 --- /dev/null +++ b/zig/pg/readme.md @@ -0,0 +1,695 @@ +# Native PostgreSQL driver for Zig + +A native PostgreSQL driver / client for Zig. Supports [LISTEN](#listen--notify). + +See or run [example/main.zig](https://github.com/karlseguin/pg.zig/blob/master/example/main.zig) for a number of examples. + +## Install +1) Add pg.zig as a dependency in your `build.zig.zon`: + +```bash +zig fetch --save git+https://github.com/karlseguin/pg.zig#master +``` + +2) In your `build.zig`, add the `pg` module as a dependency to your program: + +```zig +const pg_module = b.dependency("pg", .{}).module("pg"); + +// the executable from your executable/library +const exe = b.addExecutable(.{ + .name = "example", + ... 
+ .imports = &.{ + .{ .name = "pg", .module = pg_module }, + }, +}); +``` + +## Example +```zig +var pool = try pg.Pool.init(allocator, .{ + .size = 5, + .connect = .{ + .port = 5432, + .host = "127.0.0.1", + }, + .auth = .{ + .username = "postgres", + .database = "postgres", + .password = "postgres", + .timeout = 10_000, + } +}); +defer pool.deinit(); + +var result = try pool.query("select id, name from users where power > $1", .{9000}); +defer result.deinit(); + +while (try result.next()) |row| { + const id = try row.get(i32, 0); + // this is only valid until the next call to next(), deinit() or drain() + const name = try row.get([]u8, 1); +} +``` + +## Pool +The pool keeps a configured number of database connections open. The `acquire()` method is used to retrieve a connection from the pool. The pool may start one background thread to attempt to reconnect disconnected connections (or connections which are in an invalid state). + +### init(allocator: std.mem.Allocator, opts: Opts) !*Pool +Initializes a connection pool. Pool options are: + +* `size` - Number of connections to maintain. Defaults to `10`. +* `auth` - See [Conn.auth](#authopts-opts-void) +* `connect` - See [Conn.open](#openallocator-stdmemallocator-opts-opts-conn). +* `timeout` - The amount of time, in milliseconds, to wait for a connection to be available when `acquire()` is called. +* `connect_on_init_count` - The number of connections in the pool to eagerly connect during `init`. Defaults to `null`, which will initialize all connections (`size`). The background reconnector is used to set up the remaining (`size - connect_on_init_count`) connections. This can be set to `0` to prevent `init` from failing except in extreme cases (e.g. OOM), but that will hide any configuration/connection issue until the first query is executed. + +### initUri(allocator: std.mem.Allocator, uri: std.Uri, opts: Opts) !*Pool +Initializes a connection pool using a std.Uri.
When using this function, the `auth` and `connect` fields of `opts` should **not** be set, as these will be set automatically based on the provided `uri`. + +```zig +const uri = try std.Uri.parse("postgresql://username:password@localhost:5432/database_name"); +const pool = try pg.Pool.initUri(allocator, uri, .{ .size = 5, .timeout = 10_000 }); +defer pool.deinit(); +``` + +### acquire() !\*Conn +Returns a [\*Conn](#conn) from the connection pool. Returns `error.Timeout` if a connection cannot be acquired (i.e. if the pool remains empty) within the `timeout` passed to `init`. + + +```zig +const conn = try pool.acquire(); +defer pool.release(conn); +_ = try conn.exec("...", .{...}); +``` + +### release(conn: \*Conn) void +Releases the connection back into the pool. Calling `pool.release(conn)` is the same as calling `conn.release()`. + +### newListener() !Listener +Returns a new [Listener](#listen--notify). This function creates a new connection; it does not acquire one from the pool. It is a convenience function for code that has already set up a pool (with the connection and authentication configuration) and wants to create a listening connection using those settings. + +### exec / query / queryOpts / row / rowOpts +For single-query operations, the pool offers wrappers around the connection's `exec`, `query`, `queryOpts`, `row` and `rowOpts` methods. These are convenience methods. + +`pool.exec` acquires, executes and releases the connection. + +`pool.query` and `pool.queryOpts` acquire and execute the query. The connection is automatically returned to the pool when `result.deinit()` is called. Note that this is a special behavior of `pool.query`. When the result comes explicitly from a `conn.query`, `result.deinit()` does not automatically release the connection back into the pool. + +`pool.row` and `pool.rowOpts` acquire and execute the query. The connection is automatically returned to the pool when `row.deinit()` is called.
Note that this is a special behavior of `pool.row`. When the result comes explicitly from a `conn.row`, `row.deinit()` does not automatically release the connection back into the pool. + +## Conn + +### open(allocator: std.mem.Allocator, opts: Opts) !Conn +Opens a connection, or returns an error. Prefer creating connections through the pool. Connection options are: + +* `host` - Defaults to `"127.0.0.1"` +* `port` - Defaults to `5432` +* `write_buffer` - Size of the write buffer, used when sending messages to the server. Will temporarily allocate more space as needed. If you're writing large SQL or have large parameters (e.g. long text values), making this larger might improve performance a little. Defaults to `2048`, cannot be less than `128`. +* `read_buffer` - Size of the read buffer, used when reading data from the server. Will temporarily allocate more space as needed. Given most apps are going to be reading rows of data, this can have a large impact on performance. Defaults to `4096`. +* `result_state_size` - Each `Result` (retrieved via a call to `query`) carries metadata about the data (e.g. the type of each column). For results with at most `result_state_size` columns, a static `state` container is used. Queries with more columns require a dynamic allocation. Defaults to `32`. + +### deinit(conn: \*Conn) void +Closes the connection and releases its resources. This method should not be used when the connection comes from the pool. + +### auth(opts: Opts) !void +Authenticates the connection. Prefer creating connections through the pool. Auth options are: + +* `username`: Defaults to `"postgres"` +* `password`: Defaults to `null` +* `database`: Defaults to `null` +* `timeout`: Defaults to `10_000` (milliseconds) +* `application_name`: Defaults to `null` +* `params`: Defaults to `null`. An `std.StringHashMap([]const u8)` + +### release(conn: \*Conn) void +Releases the connection back to the pool.
The pool might decide to close the connection and open a new one. + +### exec(sql: []const u8, args: anytype) !?usize +Executes the query with arguments, returns the number of rows affected, or null. Should not be used with a query that returns rows. + +### query(sql: []const u8, args: anytype) !Result +Executes the query with arguments, returns [Result](#result). `deinit`, and possibly `drain`, must be called on the returned `result`. + +### queryOpts(sql: []const u8, args: anytype, opts: Conn.QueryOpts) !Result +Same as `query` but takes options: + +- `timeout: ?u32` - This is not reliable and should probably not be used. Currently it simply puts a recv socket timeout. On timeout, the connection will likely no longer be valid (which the pool will detect and handle when the connection is released) and the underlying query will likely still execute. Defaults to `null`. +- `column_names: bool` - Whether or not the `result.column_names` should be populated. When true, this requires memory allocation (duping the column names). Defaults to `false` unless the `column_names` build option was set to true. +- `allocator` - The allocator to use for any allocations needed when executing the query and reading the results. When `null` this will default to the connection's allocator. If you were executing a query in a web-request and each web-request had its own arena tied to the lifetime of the request, it might make sense to use that arena. Defaults to `null`. +- `release_conn: bool` - Whether or not to call `conn.release()` when `result.deinit()` is called. Useful for writing a function that acquires a connection from a `Pool` and returns a `Result`. When `query` or `row` are called from a `Pool` this is forced to `true`. Otherwise, defaults to `false`. + +### row(sql: []const u8, args: anytype) !?QueryRow +Executes the query with arguments, returns a single row. Returns an error if the query returns more than one row. Returns `null` if the query returns no row. 
`deinit` must be called on the returned `QueryRow`. + +### rowOpts(sql: []const u8, args: anytype, opts: Conn.QueryOpts) !?QueryRow +Same as `row` but takes the same options as `queryOpts`. + +### prepare(sql: []const u8) !Stmt +Creates a [Stmt](#stmt). It is generally better to use `query`, `row` or `exec`. + +### prepareOpts(sql: []const u8, opts: Conn.QueryOpts) !Stmt +Same as `prepare` but takes the same options as `queryOpts`. + +### begin() !void +Calls `_ = try execOpts("begin", .{}, .{})` + +### commit() !void +Calls `_ = try execOpts("commit", .{}, .{})` + +### rollback() !void +Calls `_ = try execOpts("rollback", .{}, .{})` + +## Result +The `conn.query` and `conn.queryOpts` methods return a `pg.Result` which is used to read rows and values. + +### Fields +* `number_of_columns: usize` - Number of columns in the result. +* `column_names: [][]const u8` - Names of the columns, empty unless the query was executed with the `column_names = true` option or the `column_names` build option was set to true. + +### deinit(result: \*Result) void +Releases resources associated with the result. + +### drain(result: \*Result) !void +If you do not iterate through the result until `next` returns `null`, you must call `drain`. + +Why can't `deinit` handle this? If `deinit` also drained, you'd have to handle a possible error in `deinit` and you can't `try` in a defer. Thus, this is done to provide better ergonomics for the normal case - the normal case being where `next` is called until it returns `null`. In these cases, just `defer result.deinit()`. + +### next(result: \*Result) !?Row +Iterates to the next row of the result, or returns null if there are no more rows. + +### columnIndex(result: \*Result, name: []const u8) ?usize +Returns the index of the column with the given name. This is only valid when the query is executed with the `column_names = true` option or the `column_names` build option was set to true.
+ +### mapper(result: \*Result, T: type, opts: MapperOpts) Mapper(T) +Returns a Mapper which can be used to create a T for each row. Mapping from column to field is done by name. This is an optimized version of [row.to](#tot-type-opts-toopts-t) when iterating through multiple rows with the `{.map = .name}` option. + +See [row.to](#tot-type-opts-toopts-t) and [Mapper](#mapper) for more information. + +## Row +The `row` represents a single row from a result. Any non-primitive value that you get from the `row` is valid only until the next call to `next`, `deinit` or `drain`. + +### Fields +Only advanced usage will need access to the row fields: + +* `oids: []i32` - The PG OID value for each column in the row. See `result.number_of_columns` for the length of this slice. Might be useful if you're trying to read a non-natively supported type. +* `values: []Value` - The underlying byte value for each column in the row. See `result.number_of_columns` for the length of this slice. Might be useful if you're trying to read a non-natively supported type. Each `Value` has two fields, `is_null: bool` and `data: []const u8`. + +### get(comptime T: type, col: usize) !T +Gets a value from the row at the specified column index (0-based). **Type mapping is strict.** For example, you **cannot** use `i32` to read a `smallint` column. + +For any supported type, you can use an optional instead. Therefore, if you use `row.get(i16, 0)` the return type is `!i16`. If you use `row.get(?i16, 0)` the return type is `!?i16`. + +* `u8` - `char` +* `i16` - `smallint` +* `i32` - `int` +* `i64` - Depends on the underlying column type. A `timestamp(tz)` will be converted to microseconds since unix epoch. Otherwise, a `bigint`. +* `f32` - `float4` +* `f64` - Depends on the underlying column type. A `numeric` will be converted to an `f64`. Otherwise, a `float`. +* `bool` - `bool` +* `[]const u8` - Returns the raw underlying data. Can be used for any column type to get the PG-encoded value.
For `text` and `bytea` columns, this will be the expected value. For `numeric`, this will be a text representation of the number. For `UUID` this will be a 16-byte slice (use `pg.uuidToHex [36]u8` if you want a hex-encoded UUID). For `JSON` and `JSONB` this will be the serialized JSON value. +* `[]u8` - Same as `[]const u8` but returns a mutable value. +* `pg.Numeric` - See numeric section. +* `pg.Cidr` - See CIDR/INET section. + +### getCol(comptime T: type, column_name: []const u8) !T +Same as `get` but uses the column name rather than its position. Only valid when the `column_names = true` option is passed to `queryOpts` or the `column_names` build option was set to true. + +This relies on calling `result.columnIndex`, which iterates through the `result.column_names` field. In some cases, this is more efficient than a `StringHashMap` lookup; in others, it is worse. For performance-sensitive code, prefer using `get`, or cache the column index in a local variable outside of the `next()` loop: + +```zig +const id_idx = result.columnIndex("id").?; +while (try result.next()) |row| { + // try row.get(i32, id_idx) +} +``` + +### Array Columns +Use `row.get(pg.Iterator(i32), col)` to return an [!Iterator](#iteratort) over an array column. Supported array types are: + +* `u8` and `?u8` - `char[]` +* `i16` and `?i16` - `smallint[]` +* `i32` and `?i32` - `int[]` +* `i64` and `?i64` - `bigint[]` or `timestamp(tz)[]` (see `get`) +* `f32` and `?f32` - `float4[]` +* `f64` and `?f64` - `float8[]` +* `bool` and `?bool` - `bool[]` +* `[]const u8` and `?[]const u8` - More strict than `get([]u8)`. Supports: `text[]`, `char(n)[]`, `bytea[]`, `uuid[]`, `json[]` and `jsonb[]`. +* `[]u8` - Same as `[]const u8` but returns a mutable value. +* `pg.Numeric` - See numeric section. +* `pg.Cidr` - See CIDR/INET section. + +### record(col: usize) Record +Gets a [Record](#record) by column position. + +### recordCol(column_name: []const u8) Record +Gets a [Record](#record) by column name.
See [getCol](#getcolcomptime-t-type-column_name-const-u8-t) for performance notes. + +### to(T: type, opts: ToOpts) !T +Populates and returns a `T`. + +`opts` values are: +* `dupe` - Duplicate string columns using the internal arena. When set to `true` non-scalar values are valid until `deinit` is called on the `row`/`result`. Defaults to `false`. +* `allocator` - Allocator to use to duplicate non-scalar values (e.g. strings). It is the caller's responsibility to free any non-scalar values from their structure. Defaults to `null`. +* `map` - `.ordinal` or `.name`, defaults to `.ordinal`. + +Setting `allocator` implies `dupe`, but uses the specified allocator rather than the internal arena. By default (when `dupe` is `false` and `allocator` is `null`), non-scalar values (e.g. strings) are only valid until the next call to `next()` or `drain()` or `deinit()`. + +When `.map = .ordinal`, the default, the order of the field names must match the order of the columns. + +When `.map = .name`, the query must be executed with the `{.column_names = true}` option or the `column_names` build option must be set. Columns with no field equivalent are ignored. Fields with no column equivalent are set to their default value; if they do not have a default value the function will return `error.FieldColumnMismatch`. If you're going to use this in a loop with a `result`, consider using a [Mapper](#mapper) to avoid the name->index lookup on each iteration. + +Slice fields can either be mapped to a `pg.Iterator(T)` or a slice. When mapped to a slice, an allocator MUST be provided. When mapping to a slice of strings (e.g. `[][]const u8`), the values are duped, and thus both the values and the slice itself must be freed. When mapping to a slice of primitives (e.g. `[]i32`) the slice must be freed. When mapping to a `pg.Iterator(T)` with a custom allocator (`.{.allocator = allocator}`), the iterator must be freed by calling `iterator.deinit(allocator)`.
Whether you're mapping to a `pg.Iterator(T)` or a slice, I strongly suggest you use an ArenaAllocator. + +## QueryRow +A `QueryRow` is returned from a call to `conn.row` or `conn.rowOpts` and wraps both a `Result` and a `Row`. It exposes the same methods as `Row` as well as `deinit`, which must be called once the `QueryRow` is no longer needed. This is a rare case where `deinit()` can fail. In most cases, you can simply throw away the error (because failure is extremely rare and, if the connection came from a pool, it should repair itself). + +## Iterator(T) +The iterator returned from `row.get(pg.Iterator(T), col)` can be iterated using the `next() ?T` call: + +```zig +var names = try row.get(pg.Iterator([]const u8), 0); +while (names.next()) |name| { + ... +} +``` + +### Fields +* `len` - the number of values in the iterator. +* `is_null` - Whether the array itself was null. + +### alloc(it: Iterator(T), allocator: std.mem.Allocator) ![]T +Allocates a slice and populates it with all values. + +If `T` is `[]u8` or `[]const u8`, each string is also duplicated. It is the responsibility of the caller to free the string values AND the slice. + +### fill(it: Iterator(T), into: []T) void +Fills `into` with values of the iterator. `into` can be smaller than `it.len`, in which case only `into.len` values will be filled. This can be a bit faster than calling `next()` multiple times. Values are not duplicated; they are only valid until the next iteration of the result. + +## Record +Returned by `row.record(col)` for fetching a PostgreSQL record-type, for example from this query: + +```sql +select row('over', 9000) +``` + +In many cases, PostgreSQL will mark the inner-types as "unknown", which is likely to cause assertion failures in this library. The solution is to type each value: + +```sql +select row('over'::text, 9000::int) +``` + +### Fields +* `number_of_columns` - the number of columns in the record + +### next(T) T +Gets the next column in the record.
This behaves similarly to [row.get](#getcomptime-t-type-col-usize-t) with the same supported types for `T`, including nullables. + +## Mapper +A mapper is used to iterate through a result and turn a row into an instance of `T`. When converting a single row, or using ordinal mapping, prefer using [row.to](#tot-type-opts-toopts-t). The mapper is an optimization over `row.to` with the `{.map = .name}` option which only has to do the name -> index lookup once. + +To use a mapper, the `{.column_names = true}` option must be passed to the query/row function or the `column_names` build option must be set. + +```zig +const User = struct { + id: i32, + name: []const u8, +}; + +// ... + +var result = try conn.queryOpts("select id, name from users", .{}, .{.column_names = true}); +defer result.deinit(); + +var mapper = result.mapper(User, .{}); +while (try mapper.next()) |user| { + // use: user.id and user.name +} +``` + +A column with no matching field is ignored. A field with no matching column is set to its default value. If no default value is defined, `mapper.next()` will return `error.FieldColumnMismatch`. + +The 2nd argument to `result.mapper` is an options struct: + +* `dupe` - Duplicate string columns using the internal arena. When set to `true` non-scalar values are valid until `deinit` is called on the `row`/`result`. Defaults to `false`. +* `allocator` - Allocator to use to duplicate non-scalar values (e.g. strings). It is the caller's responsibility to free any non-scalar values from their structure. Defaults to `null`. + +Setting `allocator` implies `dupe`, but uses the specified allocator rather than the internal arena. By default (when `dupe` is `false` and `allocator` is `null`), non-scalar values (e.g. strings) are only valid until the next call to `next()` or `drain()` or `deinit()`. + +## Stmt +For most queries, you should use the `conn.query(...)`, `conn.row(...)` or `conn.exec(...)` methods.
For queries with parameters, these methods look like: + +```zig +var stmt = try Stmt.init(conn, opts); +errdefer stmt.deinit(); + +try stmt.prepare(sql, null); +inline for (parameters) |param| { + try stmt.bind(param); +} + +return stmt.execute(); +``` + +You can create a statement directly using `conn.prepare(sql)` or `conn.prepareOpts(sql, ConnQueryOpts{...})` and call `stmt.bind(value: anytype)` and `execute()` directly. + +The main reason to do this is to have more flexibility in binding parameters (e.g. when creating dynamic SQL where all the parameters aren't fixed at compile-time). + +Note that `stmt.deinit()` should only be called if `stmt.execute()` is not called or returns an error. Once `stmt.execute()` returns a [Result](#result), `stmt` should be considered invalid. As we can see in the above example, `stmt.deinit()` is only called on `errdefer`. + +## Caching Prepared Statements +When you execute a statement with parameters, we first ask PostgreSQL to "parse" the statement (which creates an execution plan) and then describe it. We can then bind the parameters and execute the statement. + +If you plan on executing the same query(ies) repeatedly, it's possible to have PostgreSQL cache the execution plan and pg.zig cache the description. However, there are some caveats with this approach (which are not specific to pg.zig). First, if you're using a connection pooler (like pgpool or PgBouncer), make sure to read the documentation and configure it to work properly with cached prepared statements. Historically, cached prepared statements and connection poolers have not worked well together. + +Secondly, note that the cache is per-connection. If you have a pool of 50 connections, and you execute the query against connections from the pool, then you should expect the full parse -> describe -> bind -> execute flow for those 50 connections (plus whatever new connections the pool might open).
+ +Caching is enabled by passing the `cache_name` option: + +```zig +const result = try conn.queryOpts( + "select * from saiyans where power > $1", + .{9000}, + .{.cache_name = "super"} +); +``` + +Technically, once cached, the SQL statement is ignored. So, after executing the above, you could execute the following **on the same connection**: + +```zig +const result = try conn.queryOpts( + "this isn't valid SQL", + .{1000}, + .{.cache_name = "super"} +); +``` + +And it **will** work. But you're playing with fire, and you should just include the same SQL and the same cache name for each execution. If you want to use the `.column_names = true` option, then it _must_ be included in the first query which generated the cache entry (again, in short, just _always_ use the same SQL and the same options). + +You can call `try conn.deallocate("super")` to remove a cache entry. But this is only done for the connection on which it is called. This would make sense, for example, if you get a connection from the pool, execute the same query multiple times, deallocate the cached entry, and return the connection back to the pool. Note that the name to deallocate, `super`, is not sanitized and is open to SQL injection - don't pass a user-supplied value to `deallocate`. + +## Important Notice 1 - Bind vs Read +When you read a value, e.g. using `row.get`, the library is strict and won't help you with type conversion. If your column is a smallint, you have to `get` +an `i16`. + +Conversely, when binding a value to an SQL parameter, the library is a little more generous. For example, an `u64` will bind to an `i32` provided the value is within range. + +This is particularly relevant for types which are expressed as `[]u8`. For example a UUID can be a raw binary `[16]u8` or a hex-encoded `[36]u8`. Where possible (e.g. UUID, MacAddr, MacAddr8), the library will support binding either the raw binary data or text-representation. When reading, the raw binary value is always returned.
+ +## Important Notice 2 - Invalid Connections +Strongly consider using `pg.Pool` rather than using `pg.Conn` directly. The pool will attempt to reconnect disconnected connections or connections which are in an invalid state. Until more real world testing is done, you should assume that connections will get into invalid states. + +## Important Notice 3 - Errors +Zig errorsets do not support arbitrary payloads. This is problematic in a database driver where most applications probably care about the details of an error. The library takes a simple approach. If `error.PG` is returned, `conn.err` should be set and will contain a PG error object: + +```zig +_ = conn.exec("drop table x", .{}) catch |err| { + if (err == error.PG) { + if (conn.err) |pge| { + std.log.err("PG {s}\n", .{pge.message}); + } + } + return err; +}; +``` + +In the above snippet, it's possible to skip the `if (err == error.PG)` check, but in that case `conn.err` could be set from some previous command (`conn.err` is always reset when acquired from the pool). + +If `error.PG` is returned from a non-connection object, like a query result, the associated connection will have its `conn.err` set. In other words, `conn.err` is the only thing you ever have to check.
+ +A PG error always exposes the following fields: +* `code: []const u8` - https://www.postgresql.org/docs/current/errcodes-appendix.html +* `message: []const u8` +* `severity: []const u8` + +And optionally (depending on the error and the version of the server): +* `column: ?[]const u8 = null` +* `constraint: ?[]const u8 = null` +* `data_type_name: ?[]const u8 = null` +* `detail: ?[]const u8 = null` +* `file: ?[]const u8 = null` +* `hint: ?[]const u8 = null` +* `internal_position: ?[]const u8 = null` +* `internal_query: ?[]const u8 = null` +* `line: ?[]const u8 = null` +* `position: ?[]const u8 = null` +* `routine: ?[]const u8 = null` +* `schema: ?[]const u8 = null` +* `severity2: ?[]const u8 = null` +* `table: ?[]const u8 = null` +* `where: ?[]const u8 = null` + +The `isUnique() bool` method can be called on the error to determine whether or not the error was a unique violation (i.e. error code `23505`). + +## Unsafe Fast Mode +For raw performance and danger, you can use the `unsafe` variants of many functions, e.g. `pool.rowUnsafe()` and `result.nextUnsafe()`. These versions return an unsafe row which skips type checking: + +Safe: +```zig +while (try result.next()) |row| { + // notice the try + const id = try row.get(i32, 0); +} +``` + +Unsafe: + +Safe: +```zig +while (try result.nextUnsafe()) |row| { + // no try + const id = row.get(i32, 0); +} +``` + +If the type (or nullability) is wrong, in `Debug` and `ReleaseSafe` you'll get a failed assertion. In `ReleaseSmall` and `ReleaseFast` it's undefined behavior. This is dangerous, but can be useful when you're reading a very large number of rows in a loop. + +## Type Support +All implementations have to deal with things like: how to support unsigned integers, given that PostgreSQL only has signed integers. Or, how to support UUIDs when the language has no UUID type. This section documents the exact behavior. + +### Arrays +Multi-dimensional arrays aren't supported.
The array lower bound is always 0 (or 1 in PG). + +### text, bool, bytea, char, char(n), custom enums +No surprises, arrays supported. + +When reading a `char[]`, it's tempting to use `row.get([]u8, 0)`, but this is incorrect. A `char[]` is an array, and thus `row.get(pg.Iterator(u8), 0)` must be used. + +### smallint, int, bigint +When binding an integer, the library will coerce the Zig value to the parameter type, as long as it fits. Thus, a `u64` can be bound to a `smallint`, if the value fits, else an error will be returned. + +Array binding is strict. For example, an `[]i16` must be used for a `smallint[]` parameter. The only exception is that the unsigned variant, e.g. `[]u16` can be used provided all values fit. + +When reading a column, you must use the correct type. + +### Floats +When binding, `@floatCast` is used based on the SQL parameter type. Array binding is strict. When reading a value, you must use the correct type. + +### Numeric +Until standard support comes to Zig (either in the stdlib or a de facto standard library), numeric support is half-baked. When binding a value to a parameter, you can use an f32, f64, comptime_float or string. The same applies to binding to a numeric array. + +You can `get(pg.Numeric, $COL)` to return a `pg.Numeric`. The `pg.Numeric` type only has 2 useful methods: `toFloat` and `toString`. You can also use `num.estimatedStringLen` to get the max size of the string representation: + +```zig +const numeric = row.get(pg.Numeric, 0); +var buf = try allocator.alloc(u8, numeric.estimatedStringLen()); +defer allocator.free(buf); +const str = numeric.toString(&buf); +``` + +Using `row.get(f64, 0)` on a numeric is the same as `row.get(pg.Numeric, 0).toFloat()`. + +You should consider simply casting the numeric to `::double` or `::text` within SQL in order to rely on PostgreSQL's own robust numeric to float/text conversion. + +However, `pg.Numeric` has fields for the underlying wire-format of the numeric value.
So if you require precision and the text representation isn't sufficient, you can parse the fields directly. `types/numeric.zig` is relatively well documented and tries to explain the fields. Note that any non-primitive fields, e.g. the `digits: []u8`, are only valid until the next call to `result.next`, `result.deinit`, `result.drain` or `row.deinit`. + +### UUID +When a `[]u8` is bound to a UUID column, it must either be a 16-byte slice, or a valid 36-byte hex-encoded UUID. Arrays behave the same. + +When reading a `uuid` column with `[]u8` a 16-byte slice will be returned. Use the `pg.uuidToHex() ![36]u8` helper if you need it hex-encoded. + +### INET/CIDR +You can bind a string value to a `cidr`, `inet`, `cidr[]` or `inet[]` parameter. + +When reading a value via `row.get` or `row.iterator` you should use `pg.Cidr`. It exposes 3 fields: + +* `address: []u8` - Will be a 4 or 16 byte slice depending on the family. +* `family: Family` - An enum, either `Family.v4` or `Family.v6`. +* `netmask: u8` - The network mask. + +### MacAddr/MacAddr8 +You can bind a `[]u8` to either a `macaddr` or a `macaddr8`. These can be either binary representation (6-bytes for `macaddr` or 8 bytes for `macaddr8`) or a text-representation supported by PostgreSQL. This works, like UUID, because there's no ambiguity in the length. The same applies to array variants - it's even possible to mix and match formats within the array. + +When reading a value via `row.get` or `row.iterator` using `[]u8`, the binary representation is always returned. + +### Timestamp(tz) +When you bind an `i64` to a timestamp(tz) parameter, the value is assumed to be the number of microseconds since unix epoch (e.g. `std.time.microTimestamp()`). Array binding works the same. You can also bind a string, which will pass the string as-is and depend on PostgreSQL to do the conversion. This is true for arrays as well.
+ +When reading a `timestamp` column with `i64`, the number of microseconds since unix epoch will be returned. + +### JSON and JSONB +When binding a value to a JSON or JSONB parameter, you can either supply a serialized value (i.e. `[]u8`) or a struct which will be serialized using `std.json.stringify`. + +When binding to an array of JSON or JSONB, automatic serialization is not supported and thus an array of serialized values must be provided. + +When reading a `JSON` or `JSONB` column with `[]u8`, the serialized JSON will be returned. + +### PgLSN, xid8, xid +PgLSN and xid8 can be bound and read as i64. + +xid can be bound and read as i32. + +### Arbitrary Binary Encoding +For other types, either open an issue (ideally, with a sample query/data), or you can use binary encoding directly. + +For reading, you can use `[]u8` to get the raw binary encoded data and parse it yourself. + +For writing, wrap your raw encoded data in `pg.Binary{.data = ....}`. + +## Listen / Notify +You can create a `pg.Listener` either from an existing `Pool` or directly. + +Creating a new Listener directly is a lot like creating a new connection. See [Conn.open](#openallocator-stdmemallocator-opts-opts-conn) and [Conn.auth](#authopts-opts-void). + +```zig +// see the Conn.ConnectOpts +var listener = try pg.Listener.open(allocator, .{ + .host = "127.0.0.1", + .port = 5432, +}); +defer listener.deinit(); + +try listener.auth(.{ + .username = "leto", + .password = "ghanima", + .database = "caladan", +}); + +// add 1 or more channels to listen to +try listener.listen("chan_1", .{}); +try listener.listen("chan_2", .{}); + +// .next() blocks until there's a notification or an error +while (listener.next()) |notification| { + std.debug.print("Channel: {s}\nPayload: {s}", .{notification.channel, notification.payload}); +} + +// The error handling is explained below, sorry about this API. Zig error payloads plz +switch (listener.err.?)
{ + .pg => |pg| std.debug.print("{s}\n", .{pg.message}), + .err => |err| std.debug.print("{s}\n", .{@errorName(err)}), +} +``` + +When using the pool, a new connection/session is created. It *does not* use a connection from the pool. This is merely a convenience function if you're also using normal connections through a pool. + +```zig +var listener = try pool.newListener(); +defer listener.deinit(); + +// listen to 1 or more channels +try listener.listen("chan_1"); + +// same as above +``` + +### Listen Timeout +When calling `listen`, a timeout in milliseconds can be specified: + +```zig +try listener.listen("chan_1", .{}); +``` + +If multiple calls to `listen` are made, the last timeout will be used. If no message is received in `timeout` milliseconds, `next()` will return `null` and `listener.err.?.err == error.WouldBlock`. + +### Reconnects +A listener will not automatically reconnect on error/disconnect. The pub/sub nature of LISTEN/NOTIFY means that delivery is at-most-once and auto-reconnecting can hide that fact. Put the above code in a `while (true) {...}` loop. + +### Stop +It is safe to call `listener.stop()` from a different thread. When called, `next()` will return `null` and `listener.stopped` will be `true`. + +## Errors +The handling of errors isn't great. Blame Zig's lack of error payloads and the awkwardness of using `try` within a `while` condition. + +`listener.next()` can only return `null` on error. When `null` is returned, `listener.err` will be non-null. Unlike the `Conn` this is a tagged union that can either be `err` for a normal Zig error (e.g. error.ConnectionResetByPeer) or `pg` for a detailed PostgreSQL error. + +## Metrics +A few basic metrics are collected using [metrics.zig](https://github.com/karlseguin/metrics.zig), a prometheus-compatible library. These can be written to an `std.io.Writer` using `try pg.writeMetrics(writer)`.
As an example using [httpz](https://github.com/karlseguin/http.zig): + +```zig +pub fn metrics(_: *httpz.Request, res: *httpz.Response) !void { + const writer = res.writer(); + try pg.writeMetrics(writer); + + // also write out the httpz metrics + try httpz.writeMetrics(writer); +} +``` + +The metrics are: + +* `pg_queries` - counts the number of queries. +* `pg_pool_empty` - counts how often the pool is empty. +* `pg_pool_dirty` - counts how often a connection is released back into the pool in an unclean state (thus requiring the connection to be closed and the pool to re-open another connection). This could indicate that results aren't being fully drained (either by calling `next()` until `null` is returned or explicitly calling the `drain()` method). +* `pg_alloc_params` - counts the number of parameter states that were allocated. This indicates that your queries have more parameters than `result_state_size`. If this happens often, consider increasing `result_state_size`. +* `pg_alloc_columns` - counts the number of column states that were allocated. This indicates that your queries are returning more columns than `result_state_size`. If this happens often, consider increasing `result_state_size`. +* `pg_alloc_reader` - counts the number of bytes allocated while reading messages from PostgreSQL. This generally happens as a result of large results (e.g. selecting large text fields). Controlled by the `read_buffer` configuration option. + +## TLS (Experimental) +TLS is supported via openssl.
When loading the module, you must enable openssl by including at least 1 openssl setting: + +```zig +const pg_module = b.dependency("pg", .{ + .target = target, + .optimize = optimize, + .openssl_lib_name = @as([]const u8, "ssl"), + .openssl_lib_path = std.Build.LazyPath{.cwd_relative = "/path/to/openssl/lib"}, + .openssl_include_path = std.Build.LazyPath{.cwd_relative = "/path/to/openssl/include"}, +}).module("pg"); +``` + +When not specified, the system defaults are used for the library and include paths. These should only be set if openssl is installed in a non-default location. In most cases specifying `.openssl_lib_name = "ssl"` or, for some systems `.openssl_lib_name = "openssl"` should be enough. + +Set the connection's `tls` option to either `.require` or `.{.verify_full = null}`. When using a custom root certificate, specify the path: `.{.verify_full = "/path/to/root.crt"}`. + +```zig +var pool = try pg.Pool.init(allocator, .{ + .connect = .{ .port = 5432, .host = "ip_or_hostname", .tls = .{.verify_full = null}}, + .auth = .{ .... }, + .size = 5, +}); + +// OR +const uri = try std.Uri.parse("postgresql://user:password@hostname/DBNAME?sslmode=require"); +var pool = try pg.Pool.initUri(allocator, uri, 10, 5_000); +``` + +In your main file, you can define a global `pub const pg_stderr_tls = true;` to have pg.zig print possible TLS-related errors to stderr. Alternatively, if you get an error, you can call `pg.printSSLError();` to hopefully print an error message to stderr which can be included in a ticket. This can safely be called in a `catch` clause, and will display nothing if the error is NOT SSL-related. Note that using the global `pg_stderr_tls` is more likely to print useful information in the case of certificate verification problems.
+ +## Enabling Column Names by Default + +To execute all queries as if the `column_names` option was set to true, you can provide the `column_names` argument to `b.dependency()`: + +```zig +const pg_module = b.dependency("pg", .{ + .target = target, + .optimize = optimize, + .column_names = true, +}).module("pg"); +``` + +## Tests + +Launch the Postgres database with the provided Docker Compose configuration: + +```console +cd tests/ +docker compose up +``` + +Run tests: + +```console +zig build test +``` diff --git a/zig/pg/src/auth.zig b/zig/pg/src/auth.zig new file mode 100644 index 0000000..6ebcdad --- /dev/null +++ b/zig/pg/src/auth.zig @@ -0,0 +1,387 @@ +const std = @import("std"); +extern "c" fn arc4random_buf(buf: *anyopaque, nbytes: usize) void; +const lib = @import("lib.zig"); +const Buffer = @import("buffer").Buffer; + +const proto = lib.proto; +const Reader = lib.Reader; +const Stream = lib.Stream; + +const Allocator = std.mem.Allocator; + +const Opts = lib.Conn.AuthOpts; + +// Weird return (but Zig has no error payloads, so..) +// null on success +// a []const on a PG error +// - can be be passed to proto.Error.parse(owned) +// - is only valid until the next call to reader.read() +// (we expect our caller to clone the value) +// a normal zig error on any other error +pub fn auth(stream: *Stream, buf: *Buffer, reader: *Reader, opts: Opts) !?[]const u8 { + try reader.startFlow(null, opts.timeout); + + // ignore errors on endFlow, because it's troublesome to handle, and only + // something really bad (like OOM) can happen, and that'll surface again + // as soon as the app tries to use the connection. 
+ defer reader.endFlow() catch {}; + + { + // write our startup message + const startup_message = proto.StartupMessage{ + .username = opts.username, + .application_name = opts.application_name, + .database = opts.database orelse opts.username, + }; + + buf.resetRetainingCapacity(); + try startup_message.write(buf); + try stream.writeAll(buf.string()); + } + + // read the server's response + { + const msg = try reader.next(); + switch (msg.type) { + 'R' => {}, + 'E' => return msg.data, + else => return error.UnexpectedDBMessage, + } + + switch (try proto.AuthenticationRequest.parse(msg.data)) { + .ok => return null, + .sasl => |sasl| if (try saslAuth(sasl, stream, buf, reader, opts)) |raw_pg_err| { + return raw_pg_err; + }, + .md5 => |salt| try md5PasswordAuth(salt, stream, buf, opts), + .password => try passwordAuth(opts.password orelse "", stream, buf), + } + } + + { + // if we're here, it's because we sent more data to the server (e.g. a password) + // and we're now waiting for a reply, server should send a final auth ok message + const msg = try reader.next(); + switch (msg.type) { + 'R' => {}, + 'E' => return msg.data, + else => return error.UnexpectedDBMessage, + } + + switch (try proto.AuthenticationRequest.parse(msg.data)) { + .ok => return null, + else => return error.UnexpectedDBMessage, + } + } +} + +fn saslAuth(req: proto.AuthenticationRequest.SASL, stream: *Stream, buf: *Buffer, reader: *Reader, opts: Opts) !?[]const u8 { + if (!req.scram_sha_256) { + return error.UnexpectedDBMessage; + } + var sasl_buf: [1024]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&sasl_buf); + var sasl = try SASL.init(fba.allocator()); + + { + // send the client initial response + const msg = proto.SASLInitialResponse{ + .response = sasl.client_first_message, + .mechanism = "SCRAM-SHA-256", + }; + buf.resetRetainingCapacity(); + try msg.write(buf); + try stream.writeAll(buf.string()); + } + + { + // read the server continue response + const msg = try 
reader.next(); + switch (msg.type) { + 'R' => {}, + 'E' => return msg.data, + else => return error.InvalidSASLFlow, + } + const c = try proto.AuthenticationSASLContinue.parse(msg.data); + try sasl.serverResponse(c.data); + } + + { + // send the client final response + const msg = proto.SASLResponse{ + .data = try sasl.clientFinalMessage(opts.password orelse ""), + }; + buf.resetRetainingCapacity(); + try msg.write(buf); + try stream.writeAll(buf.string()); + } + + { + // read the server final response + const msg = try reader.next(); + switch (msg.type) { + 'R' => {}, + 'E' => return msg.data, + else => return error.InvalidSASLFlow, + } + const final = try proto.AuthenticationSASLFinal.parse(msg.data); + try sasl.verifyServerFinal(final.data); + } + return null; +} + +fn md5PasswordAuth(salt: []const u8, stream: *Stream, buf: *Buffer, opts: Opts) !void { + var hash: [16]u8 = undefined; + { + var hasher = std.crypto.hash.Md5.init(.{}); + hasher.update(opts.password orelse ""); + hasher.update(opts.username); + hasher.final(&hash); + } + + { + const hex_hash = std.fmt.bytesToHex(&hash, .lower); + var hasher = std.crypto.hash.Md5.init(.{}); + hasher.update(&hex_hash); + hasher.update(salt); + hasher.final(&hash); + } + var hashed_password: [35]u8 = undefined; + const password = try std.fmt.bufPrint(&hashed_password, "md5{s}", .{&std.fmt.bytesToHex(&hash, .lower)}); + try passwordAuth(password, stream, buf); +} + +fn passwordAuth(password: []const u8, stream: *Stream, buf: *Buffer) !void { + buf.resetRetainingCapacity(); + const pw = proto.PasswordMessage{ .password = password }; + try pw.write(buf); + try stream.writeAll(buf.string()); +} + +const SASL = struct { + allocator: Allocator, + client_first_message: []u8, + auth_message: ?[]const u8 = null, + salted_password: ?[32]u8 = null, + server_response: ?ServerResponse = null, + + const Base64Encoder = std.base64.standard.Encoder; + const Base64Decoder = std.base64.standard.Decoder; + + pub fn init(allocator: 
Allocator) !SASL { + var nonce: [18]u8 = undefined; + arc4random_buf(&nonce, nonce.len); + + var client_first_message = try allocator.alloc(u8, 32); + client_first_message[0] = 'n'; + client_first_message[1] = ','; + client_first_message[2] = ','; + client_first_message[3] = 'n'; + client_first_message[4] = '='; + client_first_message[5] = ','; + client_first_message[6] = 'r'; + client_first_message[7] = '='; + _ = Base64Encoder.encode(client_first_message[8..], &nonce); + + return .{ + .allocator = allocator, + .client_first_message = client_first_message, + }; + } + + pub fn serverResponse(self: *SASL, data: []const u8) !void { + if (data.len < 8) { + return error.InvalidLength; + } + + // Specification states the attribute positions are fixed, so we expect r=X,s=Y,i=Z + if (data[0] != 'r' or data[1] != '=') { + return error.InvalidNoncePrefix; + } + + const owned = try self.allocator.dupe(u8, data); + + var res = ServerResponse{ + .raw = owned, + .nonce = undefined, + .base64_salt = undefined, + .iterations = undefined, + }; + + var pos: usize = 2; + { + const sep = std.mem.indexOfScalarPos(u8, owned, pos, ',') orelse return error.MissingSalt; + res.nonce = owned[2..sep]; + pos = sep + 1; + } + + { + const value_start = pos + 2; + if (owned.len < value_start or owned[pos] != 's' or owned[pos + 1] != '=') { + return error.InvalidSaltPrefix; + } + pos = value_start; + + const sep = std.mem.indexOfScalarPos(u8, owned, pos, ',') orelse return error.MissingIterations; + res.base64_salt = owned[pos..sep]; + pos = sep + 1; + } + + { + const value_start = pos + 2; + if (owned.len < value_start or owned[pos] != 'i' or owned[pos + 1] != '=') { + return error.InvalidIterationPrefix; + } + pos = value_start; + const sep = std.mem.indexOfScalarPos(u8, owned, pos, ',') orelse owned.len; + res.iterations = std.fmt.parseInt(u32, owned[pos..sep], 10) catch return error.InvalidIteration; + } + + self.server_response = res; + } + + pub fn clientFinalMessage(self: *SASL, password: 
[]const u8) ![]const u8 { + const sr = self.server_response orelse return error.MissingServerResponse; + const allocator = self.allocator; + + const salt = blk: { + const s = try allocator.alloc(u8, try Base64Decoder.calcSizeForSlice(sr.base64_salt)); + try Base64Decoder.decode(s, sr.base64_salt); + break :blk s; + }; + + const unproved = try std.fmt.allocPrint(allocator, "c=biws,r={s}", .{sr.nonce}); + const auth_message = try std.fmt.allocPrint(allocator, "{s},{s},{s}", .{ self.client_first_message[3..], sr.raw, unproved }); + const salted_password = blk: { + var buf: [32]u8 = undefined; + try std.crypto.pwhash.pbkdf2(&buf, password, salt, sr.iterations, std.crypto.auth.hmac.sha2.HmacSha256); + break :blk buf; + }; + + const proof = blk: { + var client_key: [32]u8 = undefined; + std.crypto.auth.hmac.sha2.HmacSha256.create(&client_key, "Client Key", &salted_password); + + var stored_key: [32]u8 = undefined; + std.crypto.hash.sha2.Sha256.hash(&client_key, &stored_key, .{}); + + var client_signature: [32]u8 = undefined; + std.crypto.auth.hmac.sha2.HmacSha256.create(&client_signature, auth_message, &stored_key); + + var proof: [32]u8 = undefined; + for (client_key, client_signature, 0..) 
|ck, cs, i| { + proof[i] = ck ^ cs; + } + + var encoded_proof: [44]u8 = undefined; + _ = Base64Encoder.encode(&encoded_proof, &proof); + break :blk encoded_proof; + }; + + self.auth_message = auth_message; + self.salted_password = salted_password; + return std.fmt.allocPrint(allocator, "{s},p={s}", .{ unproved, proof }); + } + + pub fn verifyServerFinal(self: *SASL, data: []const u8) !void { + if (data.len < 46) { + return error.InvalidLength; + } + const auth_message = self.auth_message orelse return error.MissingAutMessage; + const salted_password = if (self.salted_password) |*sp| sp else return error.MissingSaltedPassword; + + const computed_signature = blk: { + var server_key: [32]u8 = undefined; + std.crypto.auth.hmac.sha2.HmacSha256.create(&server_key, "Server Key", salted_password); + + var server_signature: [32]u8 = undefined; + std.crypto.auth.hmac.sha2.HmacSha256.create(&server_signature, auth_message, &server_key); + + var encoded_signature: [44]u8 = undefined; + _ = Base64Encoder.encode(&encoded_signature, &server_signature); + break :blk encoded_signature; + }; + + // don't tell me about timing leaks unless there's also something in std to deal with it + if (std.mem.eql(u8, &computed_signature, data[2..]) == false) { + return error.InvalidServerSignature; + } + } +}; + +pub const ServerResponse = struct { + raw: []const u8, + nonce: []const u8, + base64_salt: []const u8, + iterations: u32, +}; + +const t = @import("lib.zig").testing; +test "SASL: init" { + defer t.reset(); + var sasl1 = try SASL.init(t.arena.allocator()); + + try t.expectString("n,,n=,r=", sasl1.client_first_message[0..8]); + + var sasl2 = try SASL.init(t.arena.allocator()); + try t.expectString("n,,n=,r=", sasl2.client_first_message[0..8]); + + var sasl3 = try SASL.init(t.arena.allocator()); + try t.expectString("n,,n=,r=", sasl3.client_first_message[0..8]); + + var sasl4 = try SASL.init(t.arena.allocator()); + try t.expectString("n,,n=,r=", sasl4.client_first_message[0..8]); + + // 
The nonce should be random. It's unlikely that if we generate 4, we'd get + // the same value at a given byte. + const nonce1 = sasl1.client_first_message[8..]; + const nonce2 = sasl2.client_first_message[8..]; + const nonce3 = sasl3.client_first_message[8..]; + const nonce4 = sasl4.client_first_message[8..]; + for (0..18) |i| { + try t.expectEqual(true, nonce1[i] != nonce2[i] or + nonce2[i] != nonce3[i] or + nonce1[i] != nonce3[i] or + nonce3[i] != nonce4[i] or + nonce1[i] != nonce4[i] or + nonce2[i] != nonce4[i]); + } +} + +test "SASL: serverResponse invalid" { + //invalid response + const InvalidTest = struct { + input: []const u8, + expected: anyerror, + }; + + const test_cases = [_]InvalidTest{ + .{ .input = "", .expected = error.InvalidLength }, + .{ .input = "r", .expected = error.InvalidLength }, + .{ .input = "r=", .expected = error.InvalidLength }, + .{ .input = "s=abc,r=123,i=32", .expected = error.InvalidNoncePrefix }, + .{ .input = "r=abc123,i=32,s=aaa", .expected = error.InvalidSaltPrefix }, + .{ .input = "r=abc123,s=aaa,x=32", .expected = error.InvalidIterationPrefix }, + .{ .input = "r=abc123", .expected = error.MissingSalt }, + .{ .input = "r=abc123,s=aaaa", .expected = error.MissingIterations }, + .{ .input = "r=abc123,s=aaaa,i=123a", .expected = error.InvalidIteration }, + }; + + defer t.reset(); + var sasl = try SASL.init(t.arena.allocator()); + + for (test_cases) |tc| { + try t.expectError(tc.expected, sasl.serverResponse(tc.input)); + try t.expectEqual(null, sasl.server_response); + } +} + +test "SASL: serverResponse" { + defer t.reset(); + var sasl = try SASL.init(t.arena.allocator()); + + try sasl.serverResponse("r=abc123,s=aaaaxa,i=4096"); + try t.expectString("abc123", sasl.server_response.?.nonce); + try t.expectString("aaaaxa", sasl.server_response.?.base64_salt); + try t.expectEqual(4096, sasl.server_response.?.iterations); +} diff --git a/zig/pg/src/conn.zig b/zig/pg/src/conn.zig new file mode 100644 index 0000000..dd38f3d --- 
/dev/null +++ b/zig/pg/src/conn.zig @@ -0,0 +1,2370 @@ +const std = @import("std"); + +fn posixTimestamp() i64 { + var ts: std.c.timespec = undefined; + _ = std.c.clock_gettime(.REALTIME, &ts); + return ts.sec; +} +const lib = @import("lib.zig"); +const Buffer = @import("buffer").Buffer; + +const proto = lib.proto; +const types = lib.types; +const Pool = lib.Pool; +const Stmt = lib.Stmt; +const SSLCtx = lib.SSLCtx; +const Reader = lib.Reader; +const Result = lib.Result; +const Stream = lib.Stream; +const Timeout = lib.Timeout; +const QueryRow = lib.QueryRow; +const QueryRowUnsafe = lib.QueryRowUnsafe; +const DynamicValue = lib.types.DynamicValue; +const has_openssl = lib.has_openssl; + +const os = std.os; +const Allocator = std.mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + +pub const Conn = struct { + // If we own the ssl context (which only happens if the connection is + // created directly and NOT through a pool), then we have to free it + _ssl_ctx: ?*SSLCtx, + + // If we get a postgreSQL error, this will be set. + err: ?proto.Error, + + // The underlying data for err + _err_data: ?[]const u8, + + _stream: Stream, + + _pool: ?*Pool = null, + + // Track connection age for pool rotation + _query_count: u64 = 0, + _created_at: i64 = 0, + // The current transation state, this is whatever the last ReadyForQuery + // message told us + _state: State, + + // A buffer used for writing to PG. This can grow dynamically as needed. + _buf: Buffer, + + // Used to read data from PG. Has its own buffer which can grow dynamically + _reader: Reader, + + _allocator: Allocator, + + // Holds information describing the query that we're executing. If the query + // returns more columns than an appropriately sized ResultState is created as + // needed. + _result_state: Result.State, + + // Holds information describing the parameters that PG is expecting. If the + // query has more parameters, than an appropriately sized one is created. 
+ // This is separate from _result_state because: + // (a) they are populated separately + // (b) have distinct lifetimes + // (c) they likely have different lengths; + _param_oids: []i32, + + // cache_name => data necessary to re-execute previously prepared statement. + // Bounded: when cache exceeds MAX_CACHED_STMTS, the least-recently-used + // entry is evicted (its arena is freed). + _prepared_statements: std.StringHashMapUnmanaged(Stmt.Describe), + _stmt_cache_order: std.ArrayListUnmanaged([]const u8), + _stmt_cache_max: usize, + + const State = enum { + idle, + + // something bad happened + fail, + + // we're doing a query + query, + + // we're in a transaction + transaction, + }; + + pub const Opts = struct { + host: ?[]const u8 = null, + port: ?u16 = null, + write_buffer: ?u16 = null, + read_buffer: ?u16 = null, + result_state_size: u16 = 32, + tls: TLS = .off, + _hostz: ?[:0]const u8 = null, + + pub const TLS = union(enum) { + off: void, + require: void, + verify_full: ?[]const u8, + }; + }; + + pub const AuthOpts = struct { + username: []const u8 = "postgres", + password: ?[]const u8 = null, + database: ?[]const u8 = null, + timeout: u32 = 10_000, + application_name: ?[]const u8 = null, + startup_parameters: ?std.StringHashMap([]const u8) = null, + }; + + pub const QueryOpts = struct { + timeout: ?u32 = null, + column_names: bool = lib.default_column_names, + + allocator: ?Allocator = null, + // Whether a call to result.deinit() should automatically release the + // connection back to the pool. Meant to be used internally by pool.query() + // and the other pool utility wrappers, but applications might find it useful + // to use in their own helpers + release_conn: bool = false, + + // When not null, the prepared statement will be cached and re-used + // by subsequent queries using the same name. 
+ cache_name: ?[]const u8 = null, + }; + + pub fn openAndAuthUri(allocator: Allocator, uri: std.Uri) !Conn { + var po = try lib.parseOpts(uri, allocator); + defer po.deinit(); + return try openAndAuth(allocator, po.opts.connect, po.opts.auth); + } + + pub fn openAndAuth(allocator: Allocator, opts: Opts, ao: AuthOpts) !Conn { + var conn = try open(allocator, opts); + errdefer conn.deinit(); + + try conn.auth(ao); + return conn; + } + + pub fn open(allocator: Allocator, opts: Opts) !Conn { + var ssl_ctx: ?*SSLCtx = null; + switch (opts.tls) { + .off => {}, + else => |tls_config| { + if (comptime lib.has_openssl == false) { + return error.OpenSSLNotConfigured; + } + ssl_ctx = try lib.initializeSSLContext(tls_config); + }, + } + errdefer lib.freeSSLContext(ssl_ctx); + var conn = try openWithContext(allocator, opts, ssl_ctx); + conn._ssl_ctx = ssl_ctx; + return conn; + } + + pub fn openWithContext(allocator: Allocator, opts: Opts, ssl_ctx: ?*SSLCtx) !Conn { + var stream = try Stream.connect(allocator, opts, ssl_ctx); + errdefer stream.close(); + + const buf = try Buffer.init(allocator, @max(opts.write_buffer orelse 2048, 128)); + errdefer buf.deinit(); + + const reader = try Reader.init(allocator, opts.read_buffer orelse 4096, stream); + errdefer reader.deinit(); + + const result_state = try Result.State.init(allocator, opts.result_state_size); + errdefer result_state.deinit(allocator); + + const param_oids = try allocator.alloc(i32, opts.result_state_size); + errdefer allocator.free(param_oids); + + return .{ + .err = null, + ._buf = buf, + ._ssl_ctx = null, + ._reader = reader, + ._stream = stream, + ._err_data = null, + ._state = .idle, + ._allocator = allocator, + ._param_oids = param_oids, + ._result_state = result_state, + ._prepared_statements = .{}, + ._stmt_cache_order = .{ .items = &.{}, .capacity = 0 }, + ._stmt_cache_max = 256, + ._created_at = posixTimestamp(), + }; + } + + pub fn deinit(self: *Conn) void { + const allocator = self._allocator; + if
(self._err_data) |err_data| { + allocator.free(err_data); + } + self._buf.deinit(); + self._reader.deinit(); + allocator.free(self._param_oids); + self._result_state.deinit(allocator); + + // try to send a Terminate to the DB + self.write(&.{ 'X', 0, 0, 0, 4 }) catch {}; + lib.freeSSLContext(self._ssl_ctx); + self._stream.close(); + + var it = self._prepared_statements.valueIterator(); + while (it.next()) |value_ptr| { + value_ptr.arena.deinit(); + } + self._prepared_statements.deinit(self._allocator); + self._stmt_cache_order.deinit(self._allocator); + } + + pub fn release(self: *Conn) void { + var pool = self._pool orelse { + self.deinit(); + return; + }; + self.err = null; + pool.release(self); + } + + pub fn queryCount(self: *const Conn) u64 { + return self._query_count; + } + + pub fn age(self: *const Conn) i64 { + return posixTimestamp() - self._created_at; + } + + // ── LRU stmt cache helpers ────────────────────────────────────────── + + fn stmtCacheMoveToEnd(self: *Conn, name: []const u8) void { + const items = self._stmt_cache_order.items; + for (items, 0..) 
|item, i| { + if (std.mem.eql(u8, item, name)) { + // Remove from current position, append to end + _ = self._stmt_cache_order.orderedRemove(i); + self._stmt_cache_order.append(self._allocator, name) catch {}; + return; + } + } + } + + fn stmtCacheEvictOldest(self: *Conn) void { + if (self._stmt_cache_order.items.len == 0) return; + + // Remove oldest (index 0) + const oldest_name = self._stmt_cache_order.orderedRemove(0); + + // Free the statement's arena and remove from map + if (self._prepared_statements.fetchRemove(oldest_name)) |kv| { + var describe = kv.value; + describe.arena.deinit(); + } + } + + pub fn auth(self: *Conn, opts: AuthOpts) !void { + if (try lib.auth.auth(&self._stream, &self._buf, &self._reader, opts)) |raw_pg_err| { + return self.setErr(raw_pg_err); + } + + while (true) { + const msg = try self.read(); + switch (msg.type) { + 'Z' => return, + 'K' => {}, // TODO: BackendKeyData + else => return self.unexpectedDBMessage(), + } + } + } + + pub fn prepare(self: *Conn, sql: []const u8) !Stmt { + return self.prepareOpts(sql, .{}); + } + + pub fn prepareOpts(self: *Conn, sql: []const u8, opts: QueryOpts) !Stmt { + var stmt = try Stmt.init(self, opts); + errdefer stmt.deinit(); + try stmt.prepare(sql, null); + return stmt; + } + + pub fn query(self: *Conn, sql: []const u8, values: anytype) !*Result { + return self.queryOpts(sql, values, .{}); + } + + pub fn queryOpts(self: *Conn, sql: []const u8, values: anytype, opts: QueryOpts) !*Result { + if (self.canQuery() == false) { + self.maybeRelease(opts.release_conn); + return error.ConnectionBusy; + } + self._query_count += 1; + var cached = false; + var stmt: Stmt = undefined; + const name = opts.cache_name; + + if (name) |n| { + if (self._prepared_statements.getPtr(n)) |describe| { + cached = true; + stmt = try Stmt.fromDescribe(self, describe, opts); + errdefer stmt.deinit(); + + // Move to end of LRU order (most recently used) + self.stmtCacheMoveToEnd(n); + + try 
self._reader.startFlow(stmt.arena.allocator(), opts.timeout); + // Send a "SYNC" command + try self.write(&.{ 'S', 0, 0, 0, 4 }); + stmt.buf.reset(); + try stmt.prepareForBind(@intCast(describe.param_oids.len)); + } + } + + if (cached == false) { + // either this isn't supposed to be cached, or it is, but we don't + // have it in our cache + stmt = Stmt.init(self, opts) catch |err| { + self.maybeRelease(opts.release_conn); + return err; + }; + + errdefer stmt.deinit(); + if (name) |n| { + var describe_arena = ArenaAllocator.init(self._allocator); + errdefer describe_arena.deinit(); + try stmt.prepare(sql, describe_arena.allocator()); + + // Evict LRU if cache is full + if (self._stmt_cache_order.items.len >= self._stmt_cache_max) { + self.stmtCacheEvictOldest(); + } + + const owned_name = try describe_arena.allocator().dupe(u8, n); + try self._prepared_statements.put(self._allocator, owned_name, .{ + .arena = describe_arena, + .param_oids = stmt.param_oids, + .result_state = stmt.result_state, + }); + } else { + try stmt.prepare(sql, null); + } + } + + { + errdefer stmt.deinit(); + if (values.len != stmt.param_count) { + return error.WrongNumberOfParameters; + } + + inline for (values) |value| { + try stmt.bind(value); + } + } + + return stmt.execute() catch |err| { + stmt.deinit(); + self.maybeRelease(opts.release_conn); + return err; + }; + } + + pub fn row(self: *Conn, sql: []const u8, values: anytype) !?QueryRow { + return self._row(.safe, sql, values, .{}); + } + + pub fn rowUnsafe(self: *Conn, sql: []const u8, values: anytype) !?QueryRowUnsafe { + return self._row(.unsafe, sql, values, .{}); + } + + pub fn rowOpts(self: *Conn, sql: []const u8, values: anytype, opts: QueryOpts) !?QueryRow { + return self._row(.safe, sql, values, opts); + } + + pub fn rowUnsafeOpts(self: *Conn, sql: []const u8, values: anytype, opts: QueryOpts) !?QueryRowUnsafe { + return self._row(.unsafe, sql, values, opts); + } + + fn _row(self: *Conn, comptime fail_mode: lib.FailMode, sql: 
[]const u8, values: anytype, opts: QueryOpts) !(if (fail_mode == .safe) ?QueryRow else ?QueryRowUnsafe) { + var result = try self.queryOpts(sql, values, opts); + errdefer result.deinit(); + + if (comptime fail_mode == .safe) { + return .{ + .result = result, + .row = try result.next() orelse { + result.deinit(); + return null; + }, + }; + } + + return .{ + .result = result, + .row = try result.nextUnsafe() orelse { + result.deinit(); + return null; + }, + }; + } + + // Execute a query that does not return rows + pub fn exec(self: *Conn, sql: []const u8, values: anytype) !?i64 { + return self.execOpts(sql, values, .{}); + } + + pub fn execManyDynamic(self: *Conn, sql: []const u8, rows: []const []const DynamicValue, opts: QueryOpts) !i64 { + if (rows.len == 0) return 0; + if (self.canQuery() == false) return error.ConnectionBusy; + + const auto_transaction = self._state == .idle; + if (auto_transaction) try self.begin(); + errdefer if (auto_transaction) self.rollback() catch {}; + + var cached = false; + var stmt: Stmt = undefined; + const name = opts.cache_name; + + if (name) |n| { + if (self._prepared_statements.getPtr(n)) |describe| { + cached = true; + stmt = try Stmt.fromDescribe(self, describe, opts); + errdefer stmt.deinit(); + + self.stmtCacheMoveToEnd(n); + + try self._reader.startFlow(stmt.arena.allocator(), opts.timeout); + stmt.buf.resetRetainingCapacity(); + try stmt.startBindMessage(@intCast(describe.param_oids.len)); + } + } + + if (!cached) { + stmt = try Stmt.init(self, opts); + errdefer stmt.deinit(); + + if (name) |n| { + var describe_arena = ArenaAllocator.init(self._allocator); + errdefer describe_arena.deinit(); + try stmt.prepare(sql, describe_arena.allocator()); + + if (self._stmt_cache_order.items.len >= self._stmt_cache_max) { + self.stmtCacheEvictOldest(); + } + + const owned_name = try describe_arena.allocator().dupe(u8, n); + try self._prepared_statements.put(self._allocator, owned_name, .{ + .arena = describe_arena, + .param_oids = 
stmt.param_oids, + .result_state = stmt.result_state, + }); + } else { + try stmt.prepare(sql, null); + } + } + + var total_rows: i64 = 0; + var batch_rows: usize = 0; + const batch_limit_bytes: usize = 256 * 1024; + + for (rows, 0..) |param_row, row_idx| { + if (param_row.len != stmt.param_count) return error.WrongNumberOfParameters; + if (row_idx != 0 and batch_rows == 0) { + stmt.buf.resetRetainingCapacity(); + try stmt.startBindMessage(stmt.param_count); + } else if (row_idx != 0) { + try stmt.startBindMessage(stmt.param_count); + } + + for (param_row) |value| { + try stmt.bindDynamic(value); + } + + const is_last = row_idx + 1 == rows.len; + const should_flush = is_last or stmt.buf.len() >= batch_limit_bytes; + try stmt.finishExecuteMessage(should_flush); + batch_rows += 1; + + if (should_flush) { + try self.write(stmt.buf.string()); + total_rows += try self.consumeExecManyResponses(batch_rows); + batch_rows = 0; + if (!is_last) { + try self._reader.startFlow(stmt.arena.allocator(), opts.timeout); + } + } + } + + stmt.deinit(); + if (auto_transaction) try self.commit(); + return total_rows; + } + + pub fn execOpts(self: *Conn, sql: []const u8, values: anytype, opts: QueryOpts) !?i64 { + if (self.canQuery() == false) { + return error.ConnectionBusy; + } + var buf = &self._buf; + // Large repeated execs often reuse similarly sized SQL payloads. Keep + // the grown write buffer instead of dropping back to the tiny static + // buffer and forcing realloc/copy on the next call. 
+ buf.resetRetainingCapacity(); + + if (values.len == 0) { + try self._reader.startFlow(opts.allocator, opts.timeout); + defer self._reader.endFlow() catch { + // this can only fail in extreme conditions (OOM) and it will only impact + // the next query (and if the app is using the pool, the pool will try to + // recover from this anyway) + self._state = .fail; + }; + const simple_query = proto.Query{ .sql = sql }; + try simple_query.write(buf); + // no longer idle, we're now in a query + lib.metrics.query(); + self._state = .query; + try self.write(buf.string()); + } else { + // TODO: there are some optimization opportunities here, since we know + // we aren't expecting any result. We don't have to ask PG to DESCRIBE + // the returned columns (there should be none). This is very significant + // as it would remove 1 back-and-forth. We could just: + // Parse + Bind + Exec + Sync + // Instead of having to do: + // Parse + Describe + Sync ... read response ... Bind + Exec + Sync + const result = try self.queryOpts(sql, values, opts); + result.deinit(); + } + + // affected stays null unless the server reports a row count (or the + // statement returns rows, which we count ourselves).
+ var affected: ?i64 = null; + while (true) { + const msg = self.read() catch |err| { + if (err == error.PG) { + self.readyForQuery() catch {}; + } + return err; + }; + switch (msg.type) { + 'C' => { + const cc = try proto.CommandComplete.parse(msg.data); + affected = cc.rowsAffected(); + }, + 'Z' => return affected, + 'T' => affected = 0, + 'D' => affected = (affected orelse 0) + 1, + else => return self.unexpectedDBMessage(), + } + } + } + + pub fn begin(self: *Conn) !void { + self._state = .transaction; + _ = try self.execOpts("begin", .{}, .{}); + } + + pub fn commit(self: *Conn) !void { + _ = try self.execOpts("commit", .{}, .{}); + } + + fn consumeExecManyResponses(self: *Conn, expected: usize) !i64 { + var completed: usize = 0; + var total_rows: i64 = 0; + + while (true) { + const msg = self.read() catch |err| { + if (err == error.PG) { + self.readyForQuery() catch {}; + } + return err; + }; + switch (msg.type) { + '2' => {}, + 'C' => { + const cc = try proto.CommandComplete.parse(msg.data); + total_rows += cc.rowsAffected() orelse 0; + completed += 1; + }, + 'I' => completed += 1, + 'T', 'D' => {}, + 'Z' => { + if (completed != expected) return error.UnexpectedDBMessage; + return total_rows; + }, + else => return self.unexpectedDBMessage(), + } + } + } + + /// COPY FROM STDIN: bulk insert rows as tab-separated text. + /// Sends COPY command, then CopyData messages for each row, then CopyDone. + /// Returns number of rows inserted. + pub fn copyFrom(self: *Conn, table: []const u8, columns: []const []const u8, rows: []const []const []const u8) !i64 { + if (self.canQuery() == false) { + return error.ConnectionBusy; + } + + var buf = &self._buf; + buf.reset(); + + // Build COPY command: COPY table(col1,col2,...) 
FROM STDIN + var sql_buf: [1024]u8 = undefined; + var sql_pos: usize = 0; + const prefix = "COPY "; + @memcpy(sql_buf[sql_pos..][0..prefix.len], prefix); + sql_pos += prefix.len; + @memcpy(sql_buf[sql_pos..][0..table.len], table); + sql_pos += table.len; + sql_buf[sql_pos] = '('; + sql_pos += 1; + for (columns, 0..) |col, i| { + if (i > 0) { + sql_buf[sql_pos] = ','; + sql_pos += 1; + } + @memcpy(sql_buf[sql_pos..][0..col.len], col); + sql_pos += col.len; + } + const suffix = ") FROM STDIN"; + @memcpy(sql_buf[sql_pos..][0..suffix.len], suffix); + sql_pos += suffix.len; + + // Send as simple query + try self._reader.startFlow(null, null); + defer self._reader.endFlow() catch { + self._state = .fail; + }; + const copy_query = proto.Query{ .sql = sql_buf[0..sql_pos] }; + try copy_query.write(buf); + self._state = .query; + try self.write(buf.string()); + + // Expect CopyInResponse (type 'G') + const msg = try self.read(); + if (msg.type != 'G') { + return self.unexpectedDBMessage(); + } + + // Send CopyData messages (type 'd'), each row as tab-separated text + newline + for (rows) |copy_row| { + buf.reset(); + // CopyData message: 'd' + length + data + try buf.writeByte('d'); + const len_pos = buf.len(); + try buf.write(&.{ 0, 0, 0, 0 }); // placeholder for length + + for (copy_row, 0..) 
|val, i| { + if (i > 0) try buf.writeByte('\t'); + try buf.write(val); + } + try buf.writeByte('\n'); + + // Write length (includes itself but not the 'd') + const data_len: u32 = @intCast(buf.len() - len_pos); + var len_bytes: [4]u8 = undefined; + std.mem.writeInt(u32, &len_bytes, data_len, .big); + buf.writeAt(&len_bytes, len_pos); + try self.write(buf.string()); + } + + // Send CopyDone + buf.reset(); + try buf.write(&.{ 'c', 0, 0, 0, 4 }); + try self.write(buf.string()); + + // Read CommandComplete + ReadyForQuery + var row_count: i64 = 0; + while (true) { + const resp = try self.read(); + switch (resp.type) { + 'C' => { + // CommandComplete: "COPY N" + const data = resp.data; + if (data.len > 5 and std.mem.startsWith(u8, data, "COPY ")) { + row_count = std.fmt.parseInt(i64, data[5 .. data.len - 1], 10) catch 0; + } + }, + 'Z' => break, // ReadyForQuery + 'E' => return self.unexpectedDBMessage(), + else => {}, + } + } + + self._state = .idle; + return row_count; + } + + // ── Query Pipelining ──────────────────────────────────────────────── + // Send multiple queries over a single connection with one round trip. + // All Bind+Execute messages are batched, with a single Sync at the end. + // Returns an array of JSON-serialized results (one per query). + // + // Usage: + // const results = try conn.pipelineExec(&.{ + // .{ .sql = "SELECT * FROM users WHERE id = $1", .values = .{1} }, + // .{ .sql = "SELECT * FROM users WHERE id = $1", .values = .{2} }, + // .{ .sql = "SELECT * FROM users WHERE id = $1", .values = .{3} }, + // }); + + pub const PipelineQuery = struct { + sql: []const u8, + cache_name: ?[]const u8 = null, + }; + + /// Execute multiple queries in a single pipeline (one round trip). + /// Each query must be pre-prepared. Sends Bind+Execute for each query + /// with a single Sync at the end. + /// Returns the number of successfully executed queries. 
+ pub fn pipelineExecSimple(self: *Conn, queries: []const []const u8) !usize { + if (self.canQuery() == false) { + return error.ConnectionBusy; + } + self._query_count += queries.len; + + var buf = &self._buf; + buf.reset(); + + // Send all queries as simple Query messages, no Sync between them + for (queries) |sql| { + const query_msg = proto.Query{ .sql = sql }; + try query_msg.write(buf); + } + try self.write(buf.string()); + + self._state = .query; + + // Read all responses + var completed: usize = 0; + var expecting_ready = queries.len; + while (expecting_ready > 0) { + const msg = try self.read(); + switch (msg.type) { + 'C' => completed += 1, // CommandComplete + 'T' => {}, // RowDescription (skip) + 'D' => {}, // DataRow (skip for exec) + 'I' => {}, // EmptyQueryResponse + 'E' => {}, // Error (skip, count will be less) + 'Z' => { + expecting_ready -= 1; + if (expecting_ready == 0) { + self._state = .idle; + } + }, + else => {}, + } + } + + return completed; + } + + // We don't use `execOpts` here because rollback can be called at any point + // and we want to send this command even if the conn is in a fail state. + // So we issue the rollback, no matter what state we're in. + // It's also possible rollback was called while we were reading results, + // so we need to keep reading replies until we get a ready to query state, + // just skipping over any data rows or any other in-flight messages there + // might be. 
+ pub fn rollback(self: *Conn) !void { + var buf = &self._buf; + buf.reset(); + + const state = self._state; + + const simple_query = proto.Query{ .sql = "rollback" }; + try simple_query.write(buf); + try self.write(buf.string()); + while (true) { + const msg = self.read() catch |err| { + if (state != .fail and err == error.PG) { + self.readyForQuery() catch {}; + } + return err; + }; + switch (msg.type) { + 'Z' => return, + 'C', 'T', 'D', 'n' => {}, + else => return self.unexpectedDBMessage(), + } + } + } + + pub fn deallocate(self: *Conn, cache_name: []const u8) !void { + if (self._prepared_statements.fetchRemove(cache_name)) |kv| { + kv.value.arena.deinit(); + } + const allocator = self._allocator; + const sql = try std.fmt.allocPrint(allocator, "deallocate {s}", .{cache_name}); + defer allocator.free(sql); + _ = try self.execOpts(sql, .{}, .{}); + } + + // Should not be called directly + pub fn peekForError(self: *Conn) !void { + const data = (try self._reader.peekForError()) orelse return; + try self.readyForQuery(); + return self.setErr(data); + } + + // Should not be called directly + pub fn read(self: *Conn) !lib.Message { + var reader = &self._reader; + while (true) { + const msg = reader.next() catch |err| { + self._state = .fail; + return err; + }; + switch (msg.type) { + 'Z' => { + self._state = switch (msg.data[0]) { + 'I' => .idle, + 'T' => .transaction, + 'E' => .fail, + else => unreachable, + }; + return msg; + }, + 'S' => {}, // TODO: ParameterStatus, + 'N' => {}, // TODO: NoticeResponse + 'E' => return self.setErr(msg.data), + else => return msg, + } + } + } + + pub fn write(self: *Conn, data: []const u8) !void { + self._stream.writeAll(data) catch |err| { + self._state = .fail; + return err; + }; + } + + fn setErr(self: *Conn, data: []const u8) error{ PG, OutOfMemory } { + const allocator = self._allocator; + + // The proto.Error that we're about to create is going to reference data. 
+ // But data is owned by our Reader and its lifetime doesn't necessarily match + // what we want here. So we're going to dupe it and make the connection own + // the data so it can tie its lifecycle to the error. + + // That means clearing out any previous duped error data we had + if (self._err_data) |err_data| { + allocator.free(err_data); + } + + const owned = try allocator.dupe(u8, data); + self._err_data = owned; + self.err = proto.Error.parse(owned); + return error.PG; + } + + pub fn unexpectedDBMessage(self: *Conn) error{UnexpectedDBMessage} { + self._state = .fail; + return error.UnexpectedDBMessage; + } + + fn canQuery(self: *const Conn) bool { + const state = self._state; + if (state == .idle or state == .transaction) { + return true; + } + return false; + } + + inline fn maybeRelease(self: *Conn, rel: bool) void { + if (rel) { + self.release(); + } + } + + // should not be called directly + pub fn readyForQuery(self: *Conn) !void { + const msg = try self.read(); + if (msg.type != 'Z') { + return self.unexpectedDBMessage(); + } + } +}; + +const t = lib.testing; +test "Conn: auth trust (no pass)" { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try conn.auth(.{ .username = "pgz_user_nopass", .database = "postgres" }); +} + +test "Conn: auth unknown user" { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try t.expectError(error.PG, conn.auth(.{ .username = "does_not_exist" })); + try t.expectEqual(true, std.mem.indexOf(u8, conn.err.?.message, "user \"does_not_exist\"") != null); +} + +test "Conn: auth cleartext password" { + { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try t.expectError(error.PG, conn.auth(.{ .username = "pgz_user_clear" })); + try t.expectString("empty password returned by client", conn.err.?.message); + } + + { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try t.expectError(error.PG, conn.auth(.{ .username = "pgz_user_clear", .password = 
"wrong" })); + try t.expectString("password authentication failed for user \"pgz_user_clear\"", conn.err.?.message); + } + + { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try conn.auth(.{ .username = "pgz_user_clear", .password = "pgz_user_clear_pw", .database = "postgres" }); + } +} + +test "Conn: auth scram-sha-256 password" { + { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try t.expectError(error.PG, conn.auth(.{ .username = "pgz_user_scram_sha256" })); + try t.expectString("password authentication failed for user \"pgz_user_scram_sha256\"", conn.err.?.message); + } + + { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try t.expectError(error.PG, conn.auth(.{ .username = "pgz_user_scram_sha256", .password = "wrong" })); + try t.expectString("password authentication failed for user \"pgz_user_scram_sha256\"", conn.err.?.message); + } + + { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try conn.auth(.{ .username = "pgz_user_scram_sha256", .password = "pgz_user_scram_sha256_pw", .database = "postgres" }); + } +} + +test "Conn: exec rowsAffected" { + var c = t.connect(.{}); + defer c.deinit(); + + { + const n = try c.exec("insert into simple_table values ('exec_insert_a'), ('exec_insert_b')", .{}); + try t.expectEqual(2, n.?); + } + + { + const n = try c.exec("update simple_table set value = 'exec_insert_a' where value = 'exec_insert_a'", .{}); + try t.expectEqual(1, n.?); + } + + { + const n = try c.exec("delete from simple_table where value like 'exec_insert%'", .{}); + try t.expectEqual(2, n.?); + } + + { + try t.expectEqual(null, try c.exec("begin", .{})); + try t.expectEqual(null, try c.exec("end", .{})); + } +} + +test "Conn: exec with values rowsAffected" { + var c = t.connect(.{}); + defer c.deinit(); + + { + const n = try c.exec("insert into simple_table values ($1), ($2)", .{ "exec_insert_args_a", "exec_insert_args_b" }); + try t.expectEqual(2, n.?); + } 
+} + +test "Conn: exec query that returns rows" { + var c = t.connect(.{}); + defer c.deinit(); + _ = try c.exec("insert into simple_table values ('exec_sel_1'), ('exec_sel_2')", .{}); + try t.expectEqual(0, c.exec("select * from simple_table where value = 'none'", .{})); + try t.expectEqual(2, c.exec("select * from simple_table where value like $1", .{"exec_sel_%"})); +} + +test "Conn: parse error" { + var c = t.connect(.{}); + defer c.deinit(); + try t.expectError(error.PG, c.query("selct 1", .{})); + + const err = c.err.?; + try t.expectString("42601", err.code); + try t.expectString("ERROR", err.severity); + try t.expectString("syntax error at or near \"selct\"", err.message); + + // connection is still usable + try t.expectEqual(2, t.scalar(&c, "select 2")); +} + +test "Conn: Query within Query error" { + var c = t.connect(.{}); + defer c.deinit(); + var rows = try c.query("select 1", .{}); + defer rows.deinit(); + + try t.expectError(error.ConnectionBusy, c.row("select 2", .{})); + try t.expectEqual(1, (try rows.nextUnsafe()).?.get(i32, 0)); +} + +test "PG: type support" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + var bytea1 = [_]u8{ 0, 1 }; + var bytea2 = [_]u8{ 255, 253, 253 }; + + { + const result = c.exec( + \\ + \\ insert into all_types ( + \\ id, + \\ col_int2, col_int2_arr, + \\ col_int4, col_int4_arr, + \\ col_int8, col_int8_arr, + \\ col_float4, col_float4_arr, + \\ col_float8, col_float8_arr, + \\ col_bool, col_bool_arr, + \\ col_text, col_text_arr, + \\ col_bytea, col_bytea_arr, + \\ col_enum, col_enum_arr, + \\ col_uuid, col_uuid_arr, + \\ col_numeric, col_numeric_arr, + \\ col_timestamp, col_timestamp_arr, + \\ col_timestamptz, col_timestamptz_arr, + \\ col_json, col_json_arr, + \\ col_jsonb, col_jsonb_arr, + \\ col_char, col_char_arr, + \\ col_charn, col_charn_arr, + \\ col_cidr, col_cidr_arr, + \\ col_inet, col_inet_arr, + \\ col_macaddr, col_macaddr_arr, + \\ col_macaddr8, col_macaddr8_arr + \\ ) values ( + \\ $1, + 
\\ $2, $3, + \\ $4, $5, + \\ $6, $7, + \\ $8, $9, + \\ $10, $11, + \\ $12, $13, + \\ $14, $15, + \\ $16, $17, + \\ $18, $19, + \\ $20, $21, + \\ $22, $23, + \\ $24, $25, + \\ $26, $27, + \\ $28, $29, + \\ $30, $31, + \\ $32, $33, + \\ $34, $35, + \\ $36, $37, + \\ $38, $39, + \\ $40, $41, + \\ $42, $43 + \\ ) + , .{ + 1, + @as(i16, 382), + [_]i16{ -9000, 9001 }, + @as(i32, -96534), + [_]i32{-4929123}, + @as(i64, 8983919283), + [_]i64{ 8888848483, 0, -1 }, + @as(f32, 1.2345), + [_]f32{ 4.492, -0.000021 }, + @as(f64, -48832.3233231), + [_]f64{ 393.291133, 3.1144 }, + true, + [_]bool{ false, true }, + "a text column", + [_][]const u8{ "it's", "over", "9000" }, + [_]u8{ 0, 0, 2, 255, 255, 255 }, + &[_][]u8{ &bytea1, &bytea2 }, + "val1", + [_][]const u8{ "val1", "val2" }, + "b7cc282f-ec43-49be-8e09-aafab0104915", + [_][]const u8{ "166B4751-D702-4FB9-9A2A-CD6B69ED18D6", "ae2f475f-8070-41b7-ba33-86bba8897bde" }, + 1234.567, + [_]f64{ 0, -1.1, std.math.nan(f64), std.math.inf(f32), 12345.000101 }, + "2023-10-23T15:33:13Z", + [_][]const u8{ "2010-02-10T08:22:07Z", "0003-04-05T06:07:08.123456" }, + "2024-11-23T16:34:14Z", + [_][]const u8{ "2011-03-11T09:23:05Z", "0002-03-04T05:06:02.0000991" }, + "{\"count\":1.3}", + [_][]const u8{ "[1,2,3]", "{\"rows\":[{\"a\": true}]}" }, + "{\"over\":9000}", + [_][]const u8{ "[true,false]", "{\"cols\":[{\"z\": 0.003}]}" }, + 79, + [_]u8{ '1', 'z', '!' 
}, + "Teg", + [_][]const u8{ &.{ 78, 82 }, "hi" }, + "192.168.100.128/25", + [_][]const u8{ "10.1.2", "2001:4f8:3:ba::/64" }, + "::ffff:1.2.3.0/120", + [_][]const u8{ "127.0.0.1/32", "2001:4f8:3:ba:2e0:81ff:fe22:d1f1/128" }, + "08:00:2b:01:02:03", + [_][]const u8{ "08002b:010203", "0800-2b01-0204" }, + "09:01:3b:21:21:03:04:05", + [_][]const u8{ "ffeeddccbbaa9988", "01-02-03-04-05-06-07-09" }, + }); + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query( + \\ select + \\ id, + \\ col_int2, col_int2_arr, + \\ col_int4, col_int4_arr, + \\ col_int8, col_int8_arr, + \\ col_float4, col_float4_arr, + \\ col_float8, col_float8_arr, + \\ col_bool, col_bool_arr, + \\ col_text, col_text_arr, + \\ col_bytea, col_bytea_arr, + \\ col_enum, col_enum_arr, + \\ col_uuid, col_uuid_arr, + \\ col_numeric, col_numeric_arr, + \\ col_timestamp, col_timestamp_arr, + \\ col_timestamptz, col_timestamptz_arr, + \\ col_json, col_json_arr, + \\ col_jsonb, col_jsonb_arr, + \\ col_char, col_char_arr, + \\ col_charn, col_charn_arr, + \\ col_cidr, col_cidr_arr, + \\ col_inet, col_inet_arr, + \\ col_macaddr, col_macaddr_arr, + \\ col_macaddr8, col_macaddr8_arr + \\ from all_types where id = $1 + , .{1}); + defer result.deinit(); + + // used for our arrays + const aa = t.arena.allocator(); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(1, row.get(i32, 0)); + + { + // smallint & smallint[] + try t.expectEqual(382, row.get(i16, 1)); + try t.expectSlice(i16, &.{ -9000, 9001 }, try row.iterator(i16, 2).alloc(aa)); + } + + { + // int & int[] + try t.expectEqual(-96534, row.get(i32, 3)); + try t.expectSlice(i32, &.{-4929123}, try row.iterator(i32, 4).alloc(aa)); + } + + { + // bigint & bigint[] + try t.expectEqual(8983919283, row.get(i64, 5)); + try t.expectSlice(i64, &.{ 8888848483, 0, -1 }, try row.iterator(i64, 6).alloc(aa)); + } + + { + // float4, float4[] + try t.expectEqual(1.2345, 
row.get(f32, 7)); + try t.expectSlice(f32, &.{ 4.492, -0.000021 }, try row.iterator(f32, 8).alloc(aa)); + } + + { + // float8, float8[] + try t.expectEqual(-48832.3233231, row.get(f64, 9)); + try t.expectSlice(f64, &.{ 393.291133, 3.1144 }, try row.iterator(f64, 10).alloc(aa)); + } + + { + // bool, bool[] + try t.expectEqual(true, row.get(bool, 11)); + try t.expectSlice(bool, &.{ false, true }, try row.iterator(bool, 12).alloc(aa)); + } + + { + // text, text[] + try t.expectString("a text column", row.get([]u8, 13)); + const arr = try row.iterator([]const u8, 14).alloc(aa); + try t.expectEqual(3, arr.len); + try t.expectString("it's", arr[0]); + try t.expectString("over", arr[1]); + try t.expectString("9000", arr[2]); + } + + { + // bytea, bytea[] + try t.expectSlice(u8, &.{ 0, 0, 2, 255, 255, 255 }, row.get([]const u8, 15)); + const arr = try row.iterator([]u8, 16).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectSlice(u8, &bytea1, arr[0]); + try t.expectSlice(u8, &bytea2, arr[1]); + } + + { + // enum, enum[] + try t.expectString("val1", row.get([]u8, 17)); + const arr = try row.iterator([]const u8, 18).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectString("val1", arr[0]); + try t.expectString("val2", arr[1]); + } + + { + //uuid, uuid[] + try t.expectSlice(u8, &.{ 183, 204, 40, 47, 236, 67, 73, 190, 142, 9, 170, 250, 176, 16, 73, 21 }, row.get([]u8, 19)); + const arr = try row.iterator([]const u8, 20).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectSlice(u8, &.{ 22, 107, 71, 81, 215, 2, 79, 185, 154, 42, 205, 107, 105, 237, 24, 214 }, arr[0]); + try t.expectSlice(u8, &.{ 174, 47, 71, 95, 128, 112, 65, 183, 186, 51, 134, 187, 168, 137, 123, 222 }, arr[1]); + } + + { + // numeric, numeric[] + try t.expectEqual(1234.567, row.get(f64, 21)); + const arr = try row.iterator(types.Numeric, 22).alloc(aa); + try t.expectEqual(5, arr.len); + try expectNumeric(arr[0], "0.0"); + try expectNumeric(arr[1], "-1.1"); + try expectNumeric(arr[2],
"nan"); + try expectNumeric(arr[3], "inf"); + try expectNumeric(arr[4], "12345.000101"); + } + + { + //timestamp, timestamp[] + try t.expectEqual(1698075193000000, row.get(i64, 23)); + try t.expectSlice(i64, &.{ 1265790127000000, -62064381171876544 }, try row.iterator(i64, 24).alloc(aa)); + } + + { + //timestamptz, timestamptz[] + try t.expectEqual(1732379654000000, row.get(i64, 25)); + try t.expectSlice(i64, &.{ 1299835385000000, -62098685637999901 }, try row.iterator(i64, 26).alloc(aa)); + } + + { + // json, json[] + try t.expectString("{\"count\":1.3}", row.get([]u8, 27)); + const arr = try row.iterator([]const u8, 28).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectString("[1,2,3]", arr[0]); + try t.expectString("{\"rows\":[{\"a\": true}]}", arr[1]); + } + + { + // jsonb, jsonb[] + try t.expectString("{\"over\": 9000}", row.get([]u8, 29)); + const arr = try row.iterator([]const u8, 30).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectString("[true, false]", arr[0]); + try t.expectString("{\"cols\": [{\"z\": 0.003}]}", arr[1]); + } + + { + // char, char[] + try t.expectEqual(79, row.get(u8, 31)); + const arr = try row.iterator(u8, 32).alloc(aa); + try t.expectEqual(3, arr.len); + try t.expectEqual('1', arr[0]); + try t.expectEqual('z', arr[1]); + try t.expectEqual('!', arr[2]); + } + + { + // charn, charn[] + try t.expectString("Teg", row.get([]u8, 33)); + const arr = try row.iterator([]u8, 34).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectString("NR", arr[0]); + try t.expectString("hi", arr[1]); + } + + { + // cidr, cidr[] + const cidr = row.get(types.Cidr, 35); + try t.expectEqual(25, cidr.netmask); + try t.expectEqual(.v4, cidr.family); + try t.expectString(&.{ 192, 168, 100, 128 }, cidr.address); + + const arr = try row.iterator(types.Cidr, 36).alloc(aa); + try t.expectEqual(2, arr.len); + + try t.expectEqual(24, arr[0].netmask); + try t.expectEqual(.v4, arr[0].family); + try t.expectString(&.{ 10, 1, 2, 0 }, arr[0].address); + 
+ try t.expectEqual(64, arr[1].netmask); + try t.expectEqual(.v6, arr[1].family); + try t.expectSlice(u8, &.{ 32, 1, 4, 248, 0, 3, 0, 186, 0, 0, 0, 0, 0, 0, 0, 0 }, arr[1].address); + } + + { + // inet, inet[] + const inet = row.get(types.Cidr, 37); + try t.expectEqual(120, inet.netmask); + try t.expectEqual(.v6, inet.family); + try t.expectString(&.{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 1, 2, 3, 0 }, inet.address); + + const arr = try row.iterator(types.Cidr, 38).alloc(aa); + try t.expectEqual(2, arr.len); + + try t.expectEqual(32, arr[0].netmask); + try t.expectEqual(.v4, arr[0].family); + try t.expectString(&.{ 127, 0, 0, 1 }, arr[0].address); + + try t.expectEqual(128, arr[1].netmask); + try t.expectEqual(.v6, arr[1].family); + try t.expectSlice(u8, &.{ 32, 1, 4, 248, 0, 3, 0, 186, 2, 224, 129, 255, 254, 34, 209, 241 }, arr[1].address); + } + + { + // macaddr, macaddr[] + try t.expectSlice(u8, &.{ 8, 0, 43, 1, 2, 3 }, row.get([]u8, 39)); + + const arr = try row.iterator([]u8, 40).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectSlice(u8, &.{ 8, 0, 43, 1, 2, 3 }, arr[0]); + try t.expectSlice(u8, &.{ 8, 0, 43, 1, 2, 4 }, arr[1]); + } + + { + // macaddr8, macaddr8[] + try t.expectSlice(u8, &.{ 9, 1, 59, 33, 33, 3, 4, 5 }, row.get([]u8, 41)); + + const arr = try row.iterator([]u8, 42).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectSlice(u8, &.{ 255, 238, 221, 204, 187, 170, 153, 136 }, arr[0]); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6, 7, 9 }, arr[1]); + } + + try t.expectEqual(null, try result.next()); +} + +// For ambiguous types, the above "type support" test is using the text-representation +// This test will use the binary representation of each of the ambiguous types +test "PG: binary support" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + + { + const result = c.exec( + \\ + \\ insert into all_types ( + \\ id, + \\ col_uuid, col_uuid_arr, + \\ col_timestamp, col_timestamp_arr, + \\ col_timestamptz, 
col_timestamptz_arr, + \\ col_numeric, col_numeric_arr, + \\ col_macaddr, col_macaddr_arr, + \\ col_macaddr8, col_macaddr8_arr + \\ ) values ( + \\ $1, + \\ $2, $3, + \\ $4, $5, + \\ $6, $7, + \\ $8, $9, + \\ $10, $11, + \\ $12, $13 + \\ ) + , .{ + 2, + &[_]u8{ 142, 243, 93, 100, 249, 159, 77, 126, 167, 54, 150, 204, 170, 222, 98, 124 }, + [_][]const u8{ &.{ 53, 140, 59, 37, 1, 148, 72, 139, 130, 197, 181, 40, 44, 109, 127, 165 }, &.{ 57, 203, 218, 97, 37, 38, 70, 107, 182, 116, 24, 125, 236, 123, 117, 247 } }, + 169804639500713, + [_]i64{ 169804639500713, -94668480000000 }, + 169804639500714, + [_]i64{ 169804639500714, -94668480000001 }, + "-394956.2221", + [_][]const u8{ "1.0008", "-987.110", "-inf" }, + &[_]u8{ 1, 2, 3, 4, 5, 6 }, + [_][]const u8{ &.{ 0, 1, 0, 2, 0, 3 }, &.{ 255, 0, 254, 1, 253, 2 } }, + &[_]u8{ 1, 2, 3, 4, 5, 6, 7, 8 }, + [_][]const u8{ &.{ 0, 1, 0, 2, 0, 3, 4, 0 }, &.{ 255, 0, 254, 1, 253, 2, 3, 252 } }, + }); + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query( + \\ select + \\ col_uuid, col_uuid_arr, + \\ col_timestamp, col_timestamp_arr, + \\ col_timestamptz, col_timestamptz_arr, + \\ col_numeric, col_numeric_arr, + \\ col_macaddr, col_macaddr_arr, + \\ col_macaddr8, col_macaddr8_arr + \\ from all_types where id = $1 + , .{2}); + defer result.deinit(); + + // used for our arrays + const aa = t.arena.allocator(); + + const row = (try result.nextUnsafe()) orelse unreachable; + + { + //uuid, uuid[] + try t.expectString("8ef35d64-f99f-4d7e-a736-96ccaade627c", &(try types.UUID.toString(row.get([]u8, 0)))); + + const arr = try row.iterator([]const u8, 1).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectString("358c3b25-0194-488b-82c5-b5282c6d7fa5", &(try types.UUID.toString(arr[0]))); + try t.expectString("39cbda61-2526-466b-b674-187dec7b75f7", &(try types.UUID.toString(arr[1]))); + } + + { + //timestamp, timestamp[] + try t.expectEqual(169804639500713, 
row.get(i64, 2)); + try t.expectSlice(i64, &.{ 169804639500713, -94668480000000 }, try row.iterator(i64, 3).alloc(aa)); + } + + { + //timestamptz, timestamptz[] + try t.expectEqual(169804639500714, row.get(i64, 4)); + try t.expectSlice(i64, &.{ 169804639500714, -94668480000001 }, try row.iterator(i64, 5).alloc(aa)); + } + + { + //numeric, numeric[] + try t.expectEqual(-394956.2221, row.get(f64, 6)); + try t.expectSlice(f64, &.{ 1.0008, -987.110, -std.math.inf(f64) }, try row.iterator(f64, 7).alloc(aa)); + } + + { + //macaddr, macaddr[] + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6 }, row.get([]u8, 8)); + const arr = try row.iterator([]u8, 9).alloc(aa); + try t.expectSlice(u8, &.{ 0, 1, 0, 2, 0, 3 }, arr[0]); + try t.expectSlice(u8, &.{ 255, 0, 254, 1, 253, 2 }, arr[1]); + } + + { + //macaddr8, macaddr8[] + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, row.get([]u8, 10)); + const arr = try row.iterator([]u8, 11).alloc(aa); + try t.expectSlice(u8, &.{ 0, 1, 0, 2, 0, 3, 4, 0 }, arr[0]); + try t.expectSlice(u8, &.{ 255, 0, 254, 1, 253, 2, 3, 252 }, arr[1]); + } + + try t.expectEqual(null, try result.next()); +} + +test "PG: null support" { + var c = t.connect(.{}); + defer c.deinit(); + { + const result = c.exec( + \\ + \\ insert into all_types ( + \\ id, + \\ col_int2, col_int2_arr, + \\ col_int4, col_int4_arr, + \\ col_int8, col_int8_arr, + \\ col_float4, col_float4_arr, + \\ col_float8, col_float8_arr, + \\ col_bool, col_bool_arr, + \\ col_text, col_text_arr, + \\ col_bytea, col_bytea_arr, + \\ col_enum, col_enum_arr, + \\ col_uuid, col_uuid_arr, + \\ col_numeric, col_numeric_arr, + \\ col_timestamp, col_timestamp_arr, + \\ col_json, col_json_arr, + \\ col_jsonb, col_jsonb_arr, + \\ col_char, col_char_arr, + \\ col_charn, col_charn_arr + \\ ) values ( + \\ $1, + \\ $2, $3, + \\ $4, $5, + \\ $6, $7, + \\ $8, $9, + \\ $10, $11, + \\ $12, $13, + \\ $14, $15, + \\ $16, $17, + \\ $18, $19, + \\ $20, $21, + \\ $22, $23, + \\ $24, $25, + \\ $26, $27, + \\ $28, $29, + 
\\ $30, $31, + \\ $32, $33 + \\ ) + , .{ + 3, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + }); + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query( + \\ select + \\ id, + \\ col_int2, col_int2_arr, + \\ col_int4, col_int4_arr, + \\ col_int8, col_int8_arr, + \\ col_float4, col_float4_arr, + \\ col_float8, col_float8_arr, + \\ col_bool, col_bool_arr, + \\ col_text, col_text_arr, + \\ col_bytea, col_bytea_arr, + \\ col_enum, col_enum_arr, + \\ col_uuid, col_uuid_arr, + \\ col_numeric, 'numeric[] placeholder', + \\ col_timestamp, col_timestamp_arr, + \\ col_json, col_json_arr, + \\ col_jsonb, col_jsonb_arr, + \\ col_char, col_char_arr, + \\ col_charn, col_charn_arr + \\ from all_types where id = $1 + , .{3}); + defer result.deinit(); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(null, row.get(?i16, 1)); + try t.expectEqual(true, row.iterator(i16, 2).is_null); + + try t.expectEqual(null, row.get(?i32, 3)); + try t.expectEqual(true, row.iterator(i32, 4).is_null); + + try t.expectEqual(null, row.get(?i64, 5)); + try t.expectEqual(true, row.iterator(i64, 6).is_null); + + try t.expectEqual(null, row.get(?f32, 7)); + try t.expectEqual(true, row.iterator(f32, 8).is_null); + + try t.expectEqual(null, row.get(?f64, 9)); + try t.expectEqual(true, row.iterator(f64, 10).is_null); + + try t.expectEqual(null, row.get(?bool, 11)); + try t.expectEqual(true, row.iterator(bool, 12).is_null); + + try t.expectEqual(null, row.get(?[]u8, 13)); + try t.expectEqual(true, row.iterator([]u8, 14).is_null); + + try t.expectEqual(null, row.get(?[]const u8, 15)); + try t.expectEqual(true, row.iterator([]const u8, 16).is_null); + + try t.expectEqual(null, row.get(?[]const 
u8, 17)); + try t.expectEqual(true, row.iterator([]const u8, 18).is_null); + + try t.expectEqual(null, row.get(?[]u8, 19)); + try t.expectEqual(true, row.iterator([]const u8, 20).is_null); + + try t.expectEqual(null, row.get(?[]u8, 21)); + try t.expectEqual(null, row.get(?f64, 21)); + + try t.expectEqual(null, row.get(?i64, 23)); + try t.expectEqual(true, row.iterator(i64, 24).is_null); + + try t.expectEqual(null, row.get(?[]u8, 25)); + try t.expectEqual(true, row.iterator([]const u8, 26).is_null); + + try t.expectEqual(null, row.get(?[]u8, 27)); + try t.expectEqual(true, row.iterator([]const u8, 28).is_null); + + try t.expectEqual(null, row.get(?u8, 29)); + try t.expectEqual(true, row.iterator(u8, 30).is_null); + + try t.expectEqual(null, row.get(?u8, 31)); + try t.expectEqual(true, row.iterator(u8, 32).is_null); + + try t.expectEqual(null, try result.next()); +} + +test "PG: query column names" { + var c = t.connect(.{}); + defer c.deinit(); + { + var result = try c.query("select 1 as id, 'leto' as name", .{}); + try t.expectEqual(0, result.column_names.len); + try result.drain(); + result.deinit(); + } + + { + var result = try c.queryOpts("select 1 as id, 'leto' as name", .{}, .{ .column_names = true }); + defer result.deinit(); + try t.expectEqual(2, result.column_names.len); + try t.expectString("id", result.column_names[0]); + try t.expectString("name", result.column_names[1]); + } +} + +test "PG: JSON struct" { + var c = t.connect(.{}); + defer c.deinit(); + + { + const result = c.exec( + \\ + \\ insert into all_types (id, col_json, col_jsonb) + \\ values ($1, $2, $3) + , .{ 4, DummyStruct{ .id = 1, .name = "Leto" }, &DummyStruct{ .id = 2, .name = "Ghanima" } }); + + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query("select col_json, col_jsonb from all_types where id = $1", .{4}); + defer result.deinit(); + + const row = (try result.nextUnsafe()) orelse unreachable; + try 
t.expectString("{\"id\":1,\"name\":\"Leto\"}", row.get([]u8, 0)); + try t.expectString("{\"id\": 2, \"name\": \"Ghanima\"}", row.get(?[]const u8, 1).?); +} + +test "Conn: prepare" { + var c = t.connect(.{}); + defer c.deinit(); + + var stmt = try c.prepare("select $1::int where $2"); + try stmt.bind(938); + try stmt.bind(true); + + var result = try stmt.execute(); + defer result.deinit(); + + var row = (try result.next()) orelse unreachable; + try t.expectEqual(938, row.get(i32, 0)); + + try t.expectEqual(null, try result.next()); +} + +test "PG: row" { + var c = t.connect(.{}); + defer c.deinit(); + + const r1 = try c.row("select 1 where $1", .{false}); + try t.expectEqual(null, r1); + + var r2 = (try c.rowUnsafe("select 2 where $1", .{true})) orelse unreachable; + try t.expectEqual(2, r2.get(i32, 0)); + try r2.deinit(); + + // make sure the conn is still valid after a successful row + var r3 = (try c.rowUnsafe("select $1::int where $2", .{ 3, true })) orelse unreachable; + try t.expectEqual(3, r3.get(i32, 0)); + try r3.deinit(); + + // make sure the conn is still valid after MoreThanOneRow error + var r4 = (try c.rowUnsafe("select $1::text where $2", .{ "hi", true })) orelse unreachable; + try t.expectString("hi", r4.get([]u8, 0)); + try r4.deinit(); +} + +test "PG: begin/commit" { + var c = t.connect(.{}); + defer c.deinit(); + + try c.begin(); + _ = try c.exec("delete from simple_table", .{}); + _ = try c.exec("insert into simple_table values ($1)", .{"begin_commit"}); + try c.commit(); + + var row = (try c.rowUnsafe("select value from simple_table", .{})).?; + defer row.deinit() catch {}; + + try t.expectString("begin_commit", row.get([]u8, 0)); +} + +test "PG: begin/rollback" { + var c = t.connect(.{}); + defer c.deinit(); + + _ = try c.exec("delete from simple_table", .{}); + try c.begin(); + _ = try c.exec("insert into simple_table values ($1)", .{"begin_commit"}); + try c.rollback(); + + const row = try c.row("select value from simple_table", .{}); + try 
t.expectEqual(null, row); +} + +test "PG: bind enums" { + var c = t.connect(.{}); + defer c.deinit(); + + _ = try c.exec( + \\ insert into all_types (id, col_enum, col_enum_arr, col_text, col_text_arr) + \\ values (5, $1, $2, $3, $4) + , .{ DummyEnum.val1, &[_]DummyEnum{ DummyEnum.val1, DummyEnum.val2 }, DummyEnum.val2, [_]DummyEnum{ DummyEnum.val2, DummyEnum.val1 } }); + + var row = (try c.rowUnsafe( + \\ select col_enum, col_text, col_enum_arr, col_text_arr + \\ from all_types + \\ where id = 5 + , .{})) orelse unreachable; + defer row.deinit() catch {}; + + try t.expectString("val1", row.get([]u8, 0)); + try t.expectString("val2", row.get([]u8, 1)); + + var arena = std.heap.ArenaAllocator.init(t.allocator); + defer arena.deinit(); + const aa = arena.allocator(); + + { + const arr = try row.iterator([]const u8, 2).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectString("val1", arr[0]); + try t.expectString("val2", arr[1]); + } + + { + const arr = try row.iterator([]const u8, 3).alloc(aa); + try t.expectEqual(2, arr.len); + try t.expectString("val2", arr[0]); + try t.expectString("val1", arr[1]); + } +} + +test "PG: numeric" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + + { + // read + var row = (try c.rowUnsafe( + \\ select 'nan'::numeric, '+Inf'::numeric, '-Inf'::numeric, + \\ 0::numeric, 0.0::numeric, -0.00009::numeric, -999999.888880::numeric, + \\ 0.000008, 999999.888807::numeric, 123456.78901234::numeric(14, 8) + , .{})).?; + defer row.deinit() catch {}; + + try t.expectEqual(true, std.math.isNan(row.get(f64, 0))); + try t.expectEqual(true, std.math.isInf(row.get(f64, 1))); + try t.expectEqual(true, std.math.isNegativeInf(row.get(f64, 2))); + try t.expectEqual(0, row.get(f64, 3)); + try t.expectEqual(0, row.get(f64, 4)); + try t.expectEqual(-0.00009, row.get(f64, 5)); + try t.expectEqual(-999999.888880, row.get(f64, 6)); + try t.expectEqual(0.000008, row.get(f64, 7)); + try t.expectEqual(999999.888807, row.get(f64, 8)); + 
try t.expectEqual(123456.78901234, row.get(f64, 9)); + } + + { + // write + write + var row = (try c.rowUnsafe( + \\ select + \\ $1::numeric, $2::numeric, $3::numeric, + \\ $4::numeric, $5::numeric, $6::numeric, + \\ $7::numeric, $8::numeric, $9::numeric, + \\ $10::numeric, $11::numeric, $12::numeric, + \\ $13::numeric, $14::numeric, $15::numeric, + \\ $16::numeric, $17::numeric, $18::numeric, + \\ $19::numeric, $20::numeric, $21::numeric, + \\ $22::numeric, $23::numeric, $24::numeric, + \\ $25::numeric, $26::numeric, $27::numeric[] + , .{ -0.00089891, 939293122.0001101, "-123.4560991", std.math.nan(f64), std.math.inf(f64), -std.math.inf(f64), std.math.nan(f32), std.math.inf(f32), -std.math.inf(f32), 1.1, 12.98, 123.987, 1234.9876, 12345.98765, 123456.987654, 1234567.9876543, 12345678.98765432, 123456789.987654321, @as(f64, 0), @as(f64, 1), 0, 1, 999999999.9999999, @as(f64, 999999999.9999999), -999999999.9999999, @as(f64, -999999999.9999999), &[_][]const u8{ "1.1", "-0.0034" } })).?; + defer row.deinit() catch {}; + + { + // test the pg.Numeric fields + const numeric = row.get(types.Numeric, 1); + try t.expectEqual(939293122.0001101, numeric.toFloat()); + try t.expectEqual(2, numeric.weight); + try t.expectEqual(.positive, numeric.sign); + try t.expectEqual(7, numeric.scale); + try t.expectSlice(u8, &.{ 0, 9, 15, 89, 12, 50, 0, 1, 3, 242 }, numeric.digits); + } + + try expectNumeric(row.get(types.Numeric, 0), "-0.00089891"); + try expectNumeric(row.get(types.Numeric, 1), "939293122.0001101"); + try expectNumeric(row.get(types.Numeric, 2), "-123.4560991"); + + try expectNumeric(row.get(types.Numeric, 3), "nan"); + try expectNumeric(row.get(types.Numeric, 4), "inf"); + try expectNumeric(row.get(types.Numeric, 5), "-inf"); + + try expectNumeric(row.get(types.Numeric, 6), "nan"); + try expectNumeric(row.get(types.Numeric, 7), "inf"); + try expectNumeric(row.get(types.Numeric, 8), "-inf"); + + try expectNumeric(row.get(types.Numeric, 9), "1.1"); + try 
expectNumeric(row.get(types.Numeric, 10), "12.98"); + try expectNumeric(row.get(types.Numeric, 11), "123.987"); + try expectNumeric(row.get(types.Numeric, 12), "1234.9876"); + try expectNumeric(row.get(types.Numeric, 13), "12345.98765"); + try expectNumeric(row.get(types.Numeric, 14), "123456.987654"); + try expectNumeric(row.get(types.Numeric, 15), "1234567.9876543"); + try expectNumeric(row.get(types.Numeric, 16), "12345678.98765432"); + try expectNumeric(row.get(types.Numeric, 17), "123456789.987654321"); + try expectNumeric(row.get(types.Numeric, 18), "0.0"); + try expectNumeric(row.get(types.Numeric, 19), "1.0"); + try expectNumeric(row.get(types.Numeric, 20), "0.0"); + try expectNumeric(row.get(types.Numeric, 21), "1.0"); + try expectNumeric(row.get(types.Numeric, 22), "999999999.9999999"); + try expectNumeric(row.get(types.Numeric, 23), "999999999.9999999"); + try expectNumeric(row.get(types.Numeric, 24), "-999999999.9999999"); + try expectNumeric(row.get(types.Numeric, 25), "-999999999.9999999"); + + const arr = try row.iterator(types.Numeric, 26).alloc(t.arena.allocator()); + try t.expectEqual(2, arr.len); + try t.expectEqual(1.1, arr[0].toFloat()); + try t.expectDelta(-0.0034, arr[1].toFloat(), 0.00000001); + } +} + +// char array encoding is a little special, so let's test variants +test "PG: char" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + + // read + var row = (try c.rowUnsafe( + \\ select $1::char[], $2::char[], $3::char[], $4::char[] + , .{ &[_]u8{','}, &[_]u8{ ',', '"' }, &[_]u8{ '\\', 'a', ' ' }, &[_]u8{ 'z', '@' } })).?; + defer row.deinit() catch {}; + + // used for our arrays + const aa = t.arena.allocator(); + + try t.expectSlice(u8, &.{','}, try row.iterator(u8, 0).alloc(aa)); + try t.expectSlice(u8, &.{ ',', '"' }, try row.iterator(u8, 1).alloc(aa)); + try t.expectSlice(u8, &.{ '\\', 'a', ' ' }, try row.iterator(u8, 2).alloc(aa)); + try t.expectSlice(u8, &.{ 'z', '@' }, try row.iterator(u8, 3).alloc(aa)); +} + 
+test "PG: bind []const u8" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + const value: []const u8 = "hello"; + + { + const result = c.exec("insert into all_types (id, col_text) values ($1, $2)", .{ 6, value }); + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query("select id, col_text from all_types where id = $1", .{6}); + defer result.deinit(); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(6, row.get(i32, 0)); + try t.expectString("hello", row.get([]u8, 1)); +} + +test "PG: bind []?i64" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + const values = [_]?i64{ 1, null, 3 }; + + { + const result = c.exec("insert into all_types (id, col_int8_arr) values ($1, $2)", .{ 7, values }); + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query("select id, col_int8_arr from all_types where id = $1", .{7}); + defer result.deinit(); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(7, row.get(i32, 0)); + { + const arr = try row.iterator(?i64, 1).alloc(t.arena.allocator()); + try t.expectEqual(3, arr.len); + try t.expectEqual(1, arr[0]); + try t.expectEqual(null, arr[1]); + try t.expectEqual(3, arr[2]); + } +} + +test "PG: bind []?f64" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + const values = [_]?f64{ null, null, 0.2, null }; + + { + const result = c.exec("insert into all_types (id, col_float8_arr) values ($1, $2)", .{ 8, values }); + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query("select id, col_float8_arr from all_types where id = $1", .{8}); + defer result.deinit(); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(8, row.get(i32, 0)); + { + const arr = try 
row.iterator(?f64, 1).alloc(t.arena.allocator()); + try t.expectEqual(4, arr.len); + try t.expectEqual(null, arr[0]); + try t.expectEqual(null, arr[1]); + try t.expectEqual(0.2, arr[2]); + try t.expectEqual(null, arr[3]); + } +} + +test "PG: bind []?bool" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + const values = [_]?bool{ null, true, false, null }; + + { + const result = c.exec("insert into all_types (id, col_bool_arr) values ($1, $2)", .{ 9, values }); + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query("select id, col_bool_arr from all_types where id = $1", .{9}); + defer result.deinit(); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(9, row.get(i32, 0)); + { + const arr = try row.iterator(?bool, 1).alloc(t.arena.allocator()); + try t.expectEqual(4, arr.len); + try t.expectEqual(null, arr[0]); + try t.expectEqual(true, arr[1]); + try t.expectEqual(false, arr[2]); + try t.expectEqual(null, arr[3]); + } +} + +test "PG: bind []?[]const u8" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + const values = [_]?[]const u8{ "hello", null, null }; + + { + const result = c.exec("insert into all_types (id, col_text_arr) values ($1, $2)", .{ 10, values }); + if (result) |affected| { + try t.expectEqual(1, affected); + } else |err| { + try t.fail(c, err); + } + } + + var result = try c.query("select id, col_text_arr from all_types where id = $1", .{10}); + defer result.deinit(); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(10, row.get(i32, 0)); + { + const arr = try row.iterator(?[]const u8, 1).alloc(t.arena.allocator()); + try t.expectEqual(3, arr.len); + try t.expectString("hello", arr[0].?); + try t.expectEqual(null, arr[1]); + try t.expectEqual(null, arr[2]); + } +} + +test "PG: binary wrapper" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + + _ = 
try c.exec( + \\ create extension if not exists postgis; + \\ create table if not exists places ( + \\ id int not null, + \\ location geography not null + \\ ); + , .{}); + + const data = lib.Binary{ + .data = &.{ 1, 1, 0, 0, 32, 230, 16, 0, 0, 43, 107, 238, 243, 22, 122, 82, 192, 60, 20, 204, 226, 238, 89, 68, 64 }, + }; + var row = (try c.rowUnsafe("select $1::geography", .{data})).?; + defer row.deinit() catch {}; + try t.expectString(data.data, row.get([]const u8, 0)); +} + +test "PG: isUnique" { + defer t.reset(); + + var c = t.connect(.{}); + defer c.deinit(); + + { + try t.expectError(error.PG, c.exec("insert into all_types (id, id) values ($1)", .{ 999, null })); + try t.expectEqual(false, c.err.?.isUnique()); + } + + { + _ = try c.exec("insert into all_types (id) values ($1)", .{999}); + _ = try t.expectError(error.PG, c.exec("insert into all_types (id) values ($1)", .{999})); + try t.expectEqual(true, c.err.?.isUnique()); + } + + { + // can still use the connection after the error + _ = try t.expectError(error.PG, c.row("insert into all_types (id) values ($1) returning id", .{999})); + try t.expectEqual(true, c.err.?.isUnique()); + } +} + +test "PG: large read" { + var c = t.connect(.{ .read_buffer = 500 }); + defer c.deinit(); + + { + // want this to be larger than our read_buffer + var rows = try c.query("select $1::text", .{"!" ** 1000}); + defer rows.deinit(); + + const row = (try rows.nextUnsafe()).?; + try t.expectString("!" ** 1000, row.get([]u8, 0)); + try t.expectEqual(null, try rows.next()); + } + + { + // with a row + var row = (try c.rowUnsafe("select $1::text", .{"z" ** 1000})).?; + defer row.deinit() catch {}; + try t.expectString("z" ** 1000, row.get([]u8, 0)); + } +} + +test "Conn: dynamic buffer freed on error" { + var c = t.connect(.{ .read_buffer = 100 }); + defer c.deinit(); + + var rows = try c.query("select $1::text", .{"!" ** 200}); + defer rows.deinit(); + + const row = (try rows.nextUnsafe()).?; + try t.expectString("!" 
** 200, row.get([]u8, 0)); + + // we end here, simulating the app returning an error. This causes + // rows.deinit() and c.deinit() to be called prematurely (from + // the point of view of our internal state). Specifically, conn.reader.endFlow + // isn't called. +} + +test "PG: Record" { + var c = t.connect(.{}); + defer c.deinit(); + + { + var row = (try c.rowUnsafe("select row(9001, 'hello'::text)", .{})).?; + defer row.deinit() catch {}; + + var record = row.record(0); + try t.expectEqual(2, record.number_of_columns); + try t.expectEqual(9001, record.next(i32)); + try t.expectString("hello", record.next([]const u8)); + } + + { + var row = (try c.rowUnsafe("select row(null)", .{})).?; + defer row.deinit() catch {}; + + var record = row.record(0); + try t.expectEqual(1, record.number_of_columns); + try t.expectEqual(null, record.next(?i32)); + } +} + +test "Conn: application_name" { + var conn = try Conn.open(t.allocator, .{}); + defer conn.deinit(); + try conn.auth(.{ + .username = "pgz_user_clear", + .password = "pgz_user_clear_pw", + .database = "postgres", + .application_name = "pg_zig_test", + }); + + var row = (try conn.rowUnsafe("show application_name", .{})) orelse unreachable; + defer row.deinit() catch {}; + + try t.expectString("pg_zig_test", row.get([]const u8, 0)); +} + +test "PG: bind strictness" { + var c = t.connect(.{}); + defer c.deinit(); + try t.expectError(error.BindWrongType, c.row("select $1", .{100})); + try t.expectError(error.BindWrongType, c.row("select $1", .{10.2})); + try t.expectError(error.BindWrongType, c.row("select $1", .{true})); + + try t.expectError(error.BindWrongType, c.row("select $1", .{@as(i32, 100)})); + try t.expectError(error.BindWrongType, c.row("select $1", .{@as(f32, 10.2)})); + + // conn is still usable + try t.expectEqual(4, t.scalar(&c, "select 4")); +} + +test "PG: eager error" { + var c = t.connect(.{}); + defer c.deinit(); + + { + // Some errors happen when the prepared statement is executed + try 
t.expectError(error.PG, c.query("select * from invalid", .{})); + try t.expectString("relation \"invalid\" does not exist", c.err.?.message); + } + + { + // some errors only happen when the result is read + try c.begin(); + defer c.rollback() catch {}; + const sql = "create temp table test1 (id int) on commit drop"; + _ = try c.exec(sql, .{}); + try t.expectError(error.PG, c.query(sql, .{})); + } +} + +// https://github.com/karlseguin/pg.zig/issues/44 +test "PG: eager error conn state" { + var pool = try lib.Pool.init(t.allocator, .{ .size = 1, .auth = t.authOpts(.{}) }); + defer pool.deinit(); + + { + var c = try pool.acquire(); + defer c.release(); + + // duplicate it + _ = try c.exec("insert into all_types (id) values ($1)", .{2000}); + try t.expectError(error.PG, c.exec("insert into all_types (id) values ($1)", .{2000})); + } + + { + // only 1 connection in our pool, so the fact that the above fails and + // this one succeeds, means we're properly handling the failure + var c = try pool.acquire(); + defer c.release(); + _ = try c.exec("insert into all_types (id) values ($1)", .{2001}); + } +} + +// https://github.com/karlseguin/pg.zig/issues/45 +test "PG: rollback during error" { + var pool = try lib.Pool.init(t.allocator, .{ .size = 1, .auth = t.authOpts(.{}) }); + defer pool.deinit(); + + _ = try pool.exec("truncate table all_types", .{}); + + { + var c = try pool.acquire(); + defer c.release(); + + try c.begin(); + // duplicate it + _ = try c.exec("insert into all_types (id) values ($1)", .{3000}); + try t.expectError(error.PG, c.exec("insert into all_types (id) values ($1)", .{3000})); + try c.rollback(); + } + + { + // only 1 connection in our pool, so the fact that the above fails and + // this one succeeds, means we're properly handling the failure + var c = try pool.acquire(); + defer c.release(); + _ = try c.exec("insert into all_types (id) values ($1)", .{3001}); + } + + var result = try pool.query("select id from all_types order by id", .{}); + defer 
result.deinit(); + + try t.expectEqual(3001, (try result.next()).?.get(i32, 0)); + try t.expectEqual(null, (try result.next())); +} + +test "open URI" { + const uri = try std.Uri.parse("postgresql://postgres:postgres@localhost:5432/postgres?tcp_user_timeout=5000"); + var conn = try Conn.openAndAuthUri(t.allocator, uri); + conn.deinit(); +} + +test "Conn: TLS required" { + { + var conn = try Conn.open(t.allocator, .{ .tls = .off }); + defer conn.deinit(); + try t.expectError(error.PG, conn.auth(.{ .username = "pgz_user_ssl" })); + try t.expectEqual(true, std.mem.indexOf(u8, conn.err.?.message, "no encryption") != null); + } + + { + var conn = t.connect(.{ .tls = Conn.Opts.TLS.require, .username = "pgz_user_ssl", .password = "pgz_user_ssl_pw" }); + defer conn.deinit(); + } +} + +test "Conn: TLS verify-full" { + try t.expectError(error.SSLCertificationVerificationError, Conn.open(t.allocator, .{ .tls = .{ .verify_full = null } })); + + { + var conn = t.connect(.{ .tls = Conn.Opts.TLS{ .verify_full = "tests/root.crt" }, .username = "pgz_user_ssl", .password = "pgz_user_ssl_pw" }); + defer conn.deinit(); + } +} + +test "PG: cached query" { + var c = t.connect(.{}); + defer c.deinit(); + + { + var result = try c.queryOpts("select $1::int as id, $2::text as name", .{ 1, "leto" }, .{ .cache_name = "c1" }); + try t.expectEqual(0, result.column_names.len); + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(1, row.get(i32, 0)); + try t.expectString("leto", row.get([]u8, 1)); + + try t.expectEqual(null, try result.next()); + result.deinit(); + } + + { + var result = try c.queryOpts("slc", .{ 2, "ghanima" }, .{ .cache_name = "c1" }); + try t.expectEqual(0, result.column_names.len); + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(2, row.get(i32, 0)); + try t.expectString("ghanima", row.get([]u8, 1)); + + try t.expectEqual(null, try result.next()); + result.deinit(); + } + + try c.deallocate("c1"); + + { + try 
t.expectError(error.PG, c.queryOpts("slc", .{ 2, "ghanima" }, .{ .cache_name = "c1" })); + try t.expectEqual(true, std.mem.indexOf(u8, c.err.?.message, "syntax error at or near \"slc\"") != null); + } +} + +test "PG: cached query with column names" { + var c = t.connect(.{}); + defer c.deinit(); + + { + var result = try c.queryOpts("select $1::int as id, $2::text as name", .{ 1, "leto" }, .{ .cache_name = "c2", .column_names = true }); + try t.expectEqual(2, result.column_names.len); + try t.expectString("id", result.column_names[0]); + try t.expectString("name", result.column_names[1]); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(1, row.get(i32, 0)); + try t.expectString("leto", row.get([]u8, 1)); + + try t.expectEqual(null, try result.next()); + result.deinit(); + } + + { + var result = try c.queryOpts("", .{ 2, "ghanima" }, .{ .cache_name = "c2", .column_names = true }); + try t.expectEqual(2, result.column_names.len); + try t.expectString("id", result.column_names[0]); + try t.expectString("name", result.column_names[1]); + + const row = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(2, row.get(i32, 0)); + try t.expectString("ghanima", row.get([]u8, 1)); + + try t.expectEqual(null, try result.next()); + result.deinit(); + } +} + +fn expectNumeric(numeric: types.Numeric, expected: []const u8) !void { + var str_buf: [50]u8 = undefined; + try t.expectString(expected, try numeric.toString(&str_buf)); + + const a = try t.allocator.alloc(u8, numeric.estimatedStringLen()); + defer t.allocator.free(a); + try t.expectString(expected, try numeric.toString(a)); + + if (std.mem.eql(u8, expected, "nan")) { + try t.expectEqual(true, std.math.isNan(numeric.toFloat())); + } else if (std.mem.eql(u8, expected, "inf")) { + try t.expectEqual(true, std.math.isInf(numeric.toFloat())); + } else if (std.mem.eql(u8, expected, "-inf")) { + try t.expectEqual(true, std.math.isNegativeInf(numeric.toFloat())); + } else { + try 
t.expectDelta(try std.fmt.parseFloat(f64, expected), numeric.toFloat(), 0.000001); + } +} + +const DummyStruct = struct { + id: i32, + name: []const u8, +}; + +const DummyEnum = enum { + val1, + val2, +}; diff --git a/zig/pg/src/lib.zig b/zig/pg/src/lib.zig new file mode 100644 index 0000000..d2b37d0 --- /dev/null +++ b/zig/pg/src/lib.zig @@ -0,0 +1,335 @@ +// Exposed within this library +const std = @import("std"); + +pub const openssl = @cImport({ + @cInclude("openssl/ssl.h"); + @cInclude("openssl/err.h"); +}); + +const build_config = @import("config"); + +pub const log = std.log.scoped(.pg); + +pub const types = @import("types.zig"); +pub const proto = @import("proto.zig"); +pub const auth = @import("auth.zig"); +pub const Conn = @import("conn.zig").Conn; +pub const Stmt = @import("stmt.zig").Stmt; +pub const Pool = @import("pool.zig").Pool; +pub const Stream = @import("stream.zig").Stream; +pub const metrics = @import("metrics.zig"); +pub const has_openssl = build_config.openssl; +pub const SSLCtx = if (has_openssl) openssl.SSL_CTX else void; +pub const default_column_names = build_config.column_names; +/// True when the build was configured with `-Diouring=true` AND the +/// target is Linux. The Stream type in stream.zig switches to a +/// per-connection io_uring transport when this is true. 
+pub const has_iouring = build_config.iouring and @import("builtin").os.tag == .linux; + +const result = @import("result.zig"); +pub const Row = result.Row; +pub const RowUnsafe = result.RowUnsafe; +pub const Result = result.Result; +pub const Iterator = result.Iterator; +pub const IteratorUnsafe = result.IteratorUnsafe; +pub const QueryRow = result.QueryRow; +pub const QueryRowUnsafe = result.QueryRowUnsafe; +pub const Mapper = result.Mapper; + +const reader = @import("reader.zig"); +pub const Reader = reader.Reader; +pub const Message = reader.Message; + +pub const testing = @import("t.zig"); + +const root = @import("root"); +const _assert = blk: { + if (@hasDecl(root, "pg_assert")) { + break :blk root.pg_assert; + } + switch (@import("builtin").mode) { + .ReleaseFast, .ReleaseSmall => break :blk false, + else => break :blk true, + } +}; + +pub const _stderr_tls = blk: { + if (@hasDecl(root, "pg_stderr_tls")) { + break :blk root.pg_stderr_tls; + } + break :blk false; +}; + +pub fn assert(ok: bool) void { + if (comptime _assert) { + std.debug.assert(ok); + } +} + +pub fn verifyDecodeType(comptime fail_mode: FailMode, comptime T: type, comptime expected_oids: []const i32, actual: i32) !void { + if (comptime fail_mode == .safe) { + if (isExpectedId(expected_oids, actual)) { + return; + } + return error.InvalidType; + } + + if (comptime _assert == false) { + return; + } + + if (isExpectedId(expected_oids, actual)) { + return; + } + + log.warn("PostgreSQL value of type {s} cannot be read into a " ++ @typeName(T) ++ ". 
" ++ + "pg.zig has strict type checking when reading value.", .{types.oidToString(actual)}); + unreachable; +} + +fn isExpectedId(comptime expected_oids: []const i32, actual: i32) bool { + inline for (expected_oids) |expected_oid| { + if (expected_oid == actual) { + return true; + } + } + return false; +} + +pub fn verifyNotNull(comptime fail_mode: FailMode, comptime T: type, is_null: bool) !void { + if (comptime fail_mode == .safe) { + if (is_null == false) { + return; + } + return error.UnexpectedNull; + } + + if (comptime _assert == false) { + return; + } + + if (is_null == false) { + return; + } + + log.warn("PostgreSQL null column cannot be read into non-optional type (" ++ @typeName(T) ++ "). " ++ + "pg.zig has strict type checking when reading value.", .{}); + unreachable; +} + +pub fn verifyColumnName(comptime fail_mode: FailMode, name: []const u8, valid: bool) !void { + if (comptime fail_mode == .safe) { + if (valid) { + return; + } + return error.UnknownColumnName; + } + + if (comptime _assert == false) { + return; + } + + if (valid) { + return; + } + + log.warn("Unknown column name '{s}'", .{name}); + unreachable; +} + +pub const ParsedOpts = struct { + opts: Pool.Opts, + arena: std.heap.ArenaAllocator, + + pub fn deinit(self: *ParsedOpts) void { + self.arena.deinit(); + } +}; + +pub fn parseOpts(uri: std.Uri, allocator: std.mem.Allocator) !ParsedOpts { + if (!std.mem.eql(u8, uri.scheme, "postgresql") and !std.mem.eql(u8, uri.scheme, "postgres")) { + return error.InvalidUriScheme; + } + + var arena = std.heap.ArenaAllocator.init(allocator); + errdefer arena.deinit(); + const aa = arena.allocator(); + + var tls: Conn.Opts.TLS = .off; + var tcp_user_timeout: ?u32 = null; + if (uri.query) |qry| { + const query_string = try qry.toRawMaybeAlloc(aa); + var it = std.mem.splitScalar(u8, query_string, '&'); + while (it.next()) |param| { + var it2 = std.mem.splitScalar(u8, param, '='); + const key = it2.first(); + const val = it2.rest(); + if (std.mem.eql(u8, key, 
"tcp_user_timeout")) { + tcp_user_timeout = try std.fmt.parseInt(u32, val, 10); + } else if (std.mem.eql(u8, key, "sslmode")) { + if (std.mem.eql(u8, val, "require")) { + tls = .require; + } else if (std.mem.eql(u8, val, "verify-full")) { + tls = .{ .verify_full = null }; + } else if (std.mem.eql(u8, val, "disable") == false) { + return error.UnsupportedSSLModeValue; + } + } else { + return error.UnsupportedConnectionParam; + } + } + } + + const path = std.mem.trimLeft(u8, try uri.path.toRawMaybeAlloc(aa), "/"); + const host = if (uri.host) |host| try host.toRawMaybeAlloc(aa) else null; + const username = if (uri.user) |user| try user.toRawMaybeAlloc(aa) else "postgres"; + const password = if (uri.password) |password| try password.toRawMaybeAlloc(aa) else null; + + // don't use `aa` after this point, we're about to copy `arena` and any usage + // of `aa` will leak + + return .{ .arena = arena, .opts = .{ + .size = 0, + .timeout = 0, + .auth = .{ + .username = username, + .password = password, + .database = if (path.len == 0) null else path, + .timeout = tcp_user_timeout orelse 10_000, + }, + .connect = .{ + .tls = tls, + .port = uri.port orelse null, + .host = host, + }, + } }; +} + +pub fn initializeSSLContext(config: Conn.Opts.TLS) !*SSLCtx { + // OpenSSL documentation says these are implicitly called, and only need to + // be called if you're doing something special + + // if (openssl.OPENSSL_init_ssl(openssl.OPENSSL_INIT_LOAD_SSL_STRINGS | openssl.OPENSSL_INIT_LOAD_CRYPTO_STRINGS, null) != 1) { + // return error.OpenSSLInitSslFailed; + // } + + // if (openssl.OPENSSL_init_crypto(openssl.OPENSSL_INIT_ADD_ALL_CIPHERS | openssl.OPENSSL_INIT_ADD_ALL_DIGESTS | openssl.OPENSSL_INIT_LOAD_CRYPTO_STRINGS, null) != 1) { + // return error.OpenSSLInitCryptoFailed; + // } + + const ctx = openssl.SSL_CTX_new(openssl.TLS_client_method()) orelse { + return error.SSLContextNew; + }; + errdefer openssl.SSL_CTX_free(ctx); + + if (openssl.SSL_CTX_set_min_proto_version(ctx, 
openssl.TLS1_2_VERSION) != 1) { + return error.SSLMinVersion; + } + + _ = openssl.SSL_CTX_set_mode(ctx, openssl.SSL_MODE_AUTO_RETRY); + + switch (config) { + .off, .require => {}, + .verify_full => |path_to_root| { + if (path_to_root) |p| { + var pathz: [std.fs.max_path_bytes + 1]u8 = undefined; + @memcpy(pathz[0..p.len], p); + pathz[p.len] = 0; + if (openssl.SSL_CTX_load_verify_locations(ctx, pathz[0 .. p.len + 1].ptr, null) != 1) { + if (comptime _stderr_tls) { + printSSLError(); + } + return error.SSLVerifyPaths; + } + } else { + if (openssl.SSL_CTX_set_default_verify_paths(ctx) != 1) { + if (comptime _stderr_tls) { + printSSLError(); + } + return error.SSLDefaultVerifyPaths; + } + } + openssl.SSL_CTX_set_verify(ctx, openssl.SSL_VERIFY_PEER, null); + }, + } + + return ctx; +} + +pub fn freeSSLContext(ctx: ?*SSLCtx) void { + if (comptime has_openssl == false) { + return; + } + + if (ctx) |c| { + openssl.SSL_CTX_free(c); + } +} + +pub fn printSSLError() void { + if (comptime has_openssl == false) { + return; + } + + const bio = openssl.BIO_new(openssl.BIO_s_mem()); + defer _ = openssl.BIO_free(bio); + openssl.ERR_print_errors(bio); + var buf: [*]u8 = undefined; + const len: usize = @intCast(openssl.BIO_get_mem_data(bio, &buf)); + if (len > 0) { + std.debug.print("{s}\n", .{buf[0..len]}); + } +} + +pub const Binary = struct { + data: []const u8, +}; + +const TestCase = struct { + uri: []const u8, + expected_opts: Pool.Opts, +}; + +pub const FailMode = enum { + safe, + unsafe, +}; + +pub const TypeError = error{ + InvalidType, + UnexpectedNull, + UnknownColumnName, +}; + +const valid_tcs: [2]TestCase = .{ + .{ .uri = "postgresql:///", .expected_opts = .{ .size = 0, .auth = .{ .username = "postgres" }, .connect = .{}, .timeout = 0 } }, + .{ .uri = "postgresql://user:pass@somehost:1234/somedb?tcp_user_timeout=5678", .expected_opts = .{ .size = 0, .auth = .{ + .username = "user", + .password = "pass", + .database = "somedb", + .timeout = 5678, + }, .connect = .{ + 
.host = "somehost", + .port = 1234, + }, .timeout = 0 } }, +}; + +test "URI: parse valid" { + const a = std.testing.allocator; + for (valid_tcs) |tc| { + var po = parseOpts(try std.Uri.parse(tc.uri), a) catch |e| { + std.log.err("failed to parse URI {s}", .{tc.uri}); + return e; + }; + defer po.deinit(); + try std.testing.expectEqualDeep(tc.expected_opts, po.opts); + } +} + +test "URI: invalid scheme" { + try std.testing.expectError(error.InvalidUriScheme, parseOpts(try std.Uri.parse("foobar:///"), std.testing.allocator)); +} + +test "URI: invalid params" { + try std.testing.expectError(error.UnsupportedConnectionParam, parseOpts(try std.Uri.parse("postgresql:///?bar=baz"), std.testing.allocator)); +} diff --git a/zig/pg/src/listener.zig b/zig/pg/src/listener.zig new file mode 100644 index 0000000..9efc505 --- /dev/null +++ b/zig/pg/src/listener.zig @@ -0,0 +1,269 @@ +const std = @import("std"); +const lib = @import("lib.zig"); +const Buffer = @import("buffer").Buffer; + +const proto = lib.proto; +const Conn = lib.Conn; +const Reader = lib.Reader; +const NotificationResponse = lib.proto.NotificationResponse; + +const Stream = lib.Stream; +const Allocator = std.mem.Allocator; + +const ListenError = union(enum) { + err: anyerror, + pg: lib.proto.Error, +}; + +pub const Listener = struct { + err: ?ListenError = null, + closed: bool = false, + + _stream: Stream, + + // A buffer used for writing to PG. This can grow dynamically as needed. + _buf: Buffer, + + // Used to read data from PG. Has its own buffer which can grow dynamically + _reader: Reader, + + // If we get a PG error, we'll return a ListenError.pg, and we'll own its + // memory.
+ _err_data: ?[]const u8 = null, + + _allocator: Allocator, + + pub fn open(allocator: Allocator, opts: Conn.Opts) !Listener { + var stream = try Stream.connect(allocator, opts, null); + errdefer stream.close(); + + const buf = try Buffer.init(allocator, opts.write_buffer orelse 2048); + errdefer buf.deinit(); + + const reader = try Reader.init(allocator, opts.read_buffer orelse 4096, stream); + errdefer reader.deinit(); + + return .{ + ._buf = buf, + ._stream = stream, + ._reader = reader, + ._allocator = allocator, + }; + } + + pub fn deinit(self: *Listener) void { + if (self._err_data) |err_data| { + self._allocator.free(err_data); + } + self._buf.deinit(); + self._reader.deinit(); + + self.stop(); + } + + pub fn stop(self: *Listener) void { + if (@atomicRmw(bool, &self.closed, .Xchg, true, .monotonic) == true) { + return; + } + + // try to send a Terminate to the DB + self._stream.writeAll(&.{ 'X', 0, 0, 0, 4 }) catch {}; + self._stream.close(); + } + + pub fn auth(self: *Listener, opts: Conn.AuthOpts) !void { + if (try lib.auth.auth(&self._stream, &self._buf, &self._reader, opts)) |raw_pg_err| { + return self.setErr(raw_pg_err); + } + + while (true) { + const msg = try self.read(); + switch (msg.type) { + 'Z' => return, + 'K' => {}, // TODO: BackendKeyData + 'S' => {}, // TODO: ParameterStatus, + else => return error.UnexpectedDBMessage, + } + } + } + + const ListenOpts = struct { + timeout: u32 = 0, + }; + pub fn listen(self: *Listener, channel: []const u8, opts: ListenOpts) !void { + // LISTEN doesn't support parameterized queries. It has to be a simple query. + // We don't use proto.Query because we want to quote the identifier. + + const buf = &self._buf; + buf.reset(); + + // "LISTEN " = 7 + // "IDENTIFIER" = 128 + // max identifier size is 63, but if we need to quote every character, that's + // 126. 
+ 2 for the opening and closing quote + // + 1 for null terminator + try buf.ensureTotalCapacity(136); + buf.writeByteAssumeCapacity('Q'); + + var len_view = try buf.skip(4); + + buf.writeAssumeCapacity("LISTEN \""); + + // + 4 for the length itself + // + 7 for the LISTEN + // + 2 for the quotes + // + 1 for the null terminator + var len = 11 + channel.len + 3; + for (channel) |c| { + if (c == '"') { + len += 1; + buf.writeAssumeCapacity("\"\""); + } else { + buf.writeByteAssumeCapacity(c); + } + } + buf.writeByteAssumeCapacity('"'); + buf.writeByteAssumeCapacity(0); + + // fill in the length + len_view.writeIntBig(u32, @intCast(len)); + + try self._stream.writeAll(buf.string()); + + { + // we expect a command complete ('C') + const msg = try self.read(); + switch (msg.type) { + 'C' => {}, + else => return error.UnexpectedDBMessage, + } + } + + { + // followed by a ReadyForQuery ('Z') + const msg = try self.read(); + switch (msg.type) { + 'Z' => {}, + else => return error.UnexpectedDBMessage, + } + } + + try self._reader.startFlow(null, opts.timeout); + } + + pub fn next(self: *Listener) ?NotificationResponse { + const msg = self.read() catch |err| { + self.err = .{ .err = err }; + return null; + }; + + switch (msg.type) { + 'A' => return NotificationResponse.parse(msg.data) catch |err| { + self.err = .{ .err = err }; + return null; + }, + else => { + self.err = .{ .err = error.UnexpectedDBMessage }; + return null; + }, + } + } + + fn read(self: *Listener) !lib.Message { + var reader = &self._reader; + while (true) { + const msg = try reader.next(); + switch (msg.type) { + 'N' => {}, // TODO: NoticeResponse + 'E' => return self.setErr(msg.data), + else => return msg, + } + } + } + + fn setErr(self: *Listener, data: []const u8) error{ PG, OutOfMemory } { + const allocator = self._allocator; + + // The proto.Error that we're about to create is going to reference data. 
+ // But data is owned by our Reader and its lifetime doesn't necessarily match + // what we want here. So we're going to dupe it and make the connection own + // the data so it can tie its lifecycle to the error. + + // That means clearing out any previous duped error data we had + if (self._err_data) |err_data| { + allocator.free(err_data); + } + + const owned = try allocator.dupe(u8, data); + self._err_data = owned; + self.err = .{ .pg = proto.Error.parse(owned) }; + return error.PG; + } +}; + +const t = lib.testing; +test "Listener" { + var l = try Listener.open(t.allocator, .{ .host = "localhost" }); + defer l.deinit(); + try l.auth(t.authOpts(.{})); + try testListener(&l); +} + +test "Listener: from Pool" { + var pool = try lib.Pool.init(t.allocator, .{ + .size = 1, + .auth = t.authOpts(.{}), + }); + defer pool.deinit(); + + var l = try pool.newListener(); + defer l.deinit(); + + try testListener(&l); +} + +fn testListener(l: *Listener) !void { + var reset: std.Thread.ResetEvent = .{}; + var tt = try std.Thread.spawn(.{}, struct { + fn shutdown(ll: *Listener, r: *std.Thread.ResetEvent) void { + r.wait(); + ll.stop(); + } + }.shutdown, .{ l, &reset }); + tt.detach(); + + try l.listen("chan-1", .{}); + try l.listen("chan_2", .{}); + + const thrd = try std.Thread.spawn(.{}, testNotifier, .{}); + { + const notification = l.next().?; + try t.expectString("chan-1", notification.channel); + try t.expectString("pl-1", notification.payload); + } + + { + const notification = l.next().?; + try t.expectString("chan_2", notification.channel); + try t.expectString("pl-2", notification.payload); + } + + { + const notification = l.next().?; + try t.expectString("chan-1", notification.channel); + try t.expectString("", notification.payload); + } + + reset.set(); + try t.expectEqual(null, l.next()); + thrd.join(); +} + +fn testNotifier() void { + var c = t.connect(.{}); + defer c.deinit(); + _ = c.exec("select pg_notify($1, $2)", .{ "chan_x", "pl-x" }) catch unreachable; + _ = 
c.exec("select pg_notify($1, $2)", .{ "chan-1", "pl-1" }) catch unreachable; + _ = c.exec("select pg_notify($1, $2)", .{ "chan_2", "pl-2" }) catch unreachable; + _ = c.exec("select pg_notify($1, null)", .{"chan-1"}) catch unreachable; +} diff --git a/zig/pg/src/metrics.zig b/zig/pg/src/metrics.zig new file mode 100644 index 0000000..831f0f9 --- /dev/null +++ b/zig/pg/src/metrics.zig @@ -0,0 +1,56 @@ +const m = @import("metrics"); + +// This is an advanced usage of metrics.zig, largely done because we aren't +// using any vectored metrics and thus can do everything at comptime. +var metrics = Metrics{ + .queries = m.Counter(usize).Impl.init("pg_query", .{}), + .pool_empty = m.Counter(usize).Impl.init("pg_pool_empty", .{}), + .pool_dirty = m.Counter(usize).Impl.init("pg_pool_dirty", .{}), + .alloc_params = m.Counter(usize).Impl.init("pg_alloc_params", .{}), + .alloc_columns = m.Counter(usize).Impl.init("pg_alloc_columns", .{}), + .alloc_reader = m.Counter(usize).Impl.init("pg_alloc_reader", .{}), +}; + +const Metrics = struct { + queries: m.Counter(usize).Impl, + pool_empty: m.Counter(usize).Impl, + pool_dirty: m.Counter(usize).Impl, + alloc_params: m.Counter(usize).Impl, + alloc_columns: m.Counter(usize).Impl, + alloc_reader: m.Counter(usize).Impl, +}; + +pub fn write(writer: anytype) !void { + try metrics.queries.write(writer); + try metrics.pool_empty.write(writer); + try metrics.pool_dirty.write(writer); + try metrics.alloc_params.write(writer); + try metrics.alloc_columns.write(writer); + try metrics.alloc_reader.write(writer); +} + +pub fn query() void { + metrics.queries.incr(); +} + +pub fn poolEmpty() void { + metrics.pool_empty.incr(); +} + +pub fn poolDirty() void { + metrics.pool_dirty.incr(); +} + +pub fn allocParams(count: usize) void { + // this is the # of parameters, not the bytes allocated. + metrics.alloc_params.incrBy(count); +} + +pub fn allocColumns(count: usize) void { + // this is the # of columns, not the bytes allocated. 
+ metrics.alloc_columns.incrBy(count); +} + +pub fn allocReader(size: usize) void { + metrics.alloc_reader.incrBy(size); +} diff --git a/zig/pg/src/pg.zig b/zig/pg/src/pg.zig new file mode 100644 index 0000000..06b9279 --- /dev/null +++ b/zig/pg/src/pg.zig @@ -0,0 +1,43 @@ +const std = @import("std"); +const lib = @import("lib.zig"); + +pub const Row = lib.Row; +pub const Conn = lib.Conn; +pub const Pool = lib.Pool; +pub const Stmt = lib.Stmt; +pub const Result = lib.Result; +pub const Iterator = lib.Iterator; +pub const QueryRow = lib.QueryRow; +pub const Mapper = lib.Mapper; +pub const Binary = lib.Binary; + +pub const Listener = @import("listener.zig").Listener; + +pub const types = lib.types; +pub const DynamicValue = types.DynamicValue; +pub const Cidr = types.Cidr; +pub const Numeric = types.Numeric; +pub const Vector = types.Vector; +pub const Error = lib.proto.Error; +pub const printSSLError = lib.printSSLError; + +pub fn uuidToHex(uuid: []const u8) ![36]u8 { + return lib.types.UUID.toString(uuid); +} + +pub fn writeMetrics(writer: anytype) !void { + return @import("metrics.zig").write(writer); +} + +const t = lib.testing; +test { + try t.setup(); + std.testing.refAllDecls(@This()); +} + +test "pg: uuidToHex" { + try t.expectError(error.InvalidUUID, uuidToHex(&.{ 73, 190, 142, 9, 170, 250, 176, 16, 73, 21 })); + + const s = try uuidToHex(&.{ 183, 204, 40, 47, 236, 67, 73, 190, 142, 9, 170, 250, 176, 16, 73, 21 }); + try t.expectString("b7cc282f-ec43-49be-8e09-aafab0104915", &s); +} diff --git a/zig/pg/src/pool.zig b/zig/pg/src/pool.zig new file mode 100644 index 0000000..d256d6b --- /dev/null +++ b/zig/pg/src/pool.zig @@ -0,0 +1,616 @@ +const std = @import("std"); +const lib = @import("lib.zig"); + +const log = lib.log; +const Conn = lib.Conn; +const Result = lib.Result; +const SSLCtx = lib.SSLCtx; +const QueryRow = lib.QueryRow; +const QueryRowUnsafe = lib.QueryRowUnsafe; +const DynamicValue = lib.types.DynamicValue; +const Listener = 
@import("listener.zig").Listener; + +const Thread = std.Thread; +const Allocator = std.mem.Allocator; + +// Zig 0.16: Thread.Mutex/Condition moved to Io.Mutex/Condition which requires an Io instance. +// pg has no Io context, so fall back to POSIX pthreads directly. +const PthreadMutex = struct { + inner: std.c.pthread_mutex_t = std.c.PTHREAD_MUTEX_INITIALIZER, + + pub fn lock(m: *PthreadMutex) void { + _ = std.c.pthread_mutex_lock(&m.inner); + } + pub fn unlock(m: *PthreadMutex) void { + _ = std.c.pthread_mutex_unlock(&m.inner); + } + pub fn tryLock(m: *PthreadMutex) bool { + return @intFromEnum(std.c.pthread_mutex_trylock(&m.inner)) == 0; + } +}; + +const PthreadCondition = struct { + inner: std.c.pthread_cond_t = std.c.PTHREAD_COND_INITIALIZER, + + pub fn timedWait(cond: *PthreadCondition, mutex: *PthreadMutex, timeout_ns: u64) !void { + var ts: std.c.timespec = undefined; + _ = std.c.clock_gettime(.REALTIME, &ts); + const now_ns: u128 = @as(u128, @intCast(ts.sec)) * 1_000_000_000 + @as(u128, @intCast(ts.nsec)); + const deadline_ns: u128 = now_ns + timeout_ns; + const abs_time = std.c.timespec{ + .sec = @intCast(deadline_ns / 1_000_000_000), + .nsec = @intCast(deadline_ns % 1_000_000_000), + }; + const rc = std.c.pthread_cond_timedwait(&cond.inner, &mutex.inner, &abs_time); + if (@intFromEnum(rc) == @intFromEnum(std.c.E.TIMEDOUT)) return error.Timeout; + } + + pub fn signal(cond: *PthreadCondition) void { + _ = std.c.pthread_cond_signal(&cond.inner); + } + pub fn broadcast(cond: *PthreadCondition) void { + _ = std.c.pthread_cond_broadcast(&cond.inner); + } +}; + +fn nanoTimestamp() i128 { + var ts: std.c.timespec = undefined; + _ = std.c.clock_gettime(.REALTIME, &ts); + return @as(i128, ts.sec) * 1_000_000_000 + @as(i128, ts.nsec); +} + +fn threadSleep(ns: u64) void { + const ts = std.c.timespec{ + .sec = @intCast(ns / std.time.ns_per_s), + .nsec = @intCast(ns % std.time.ns_per_s), + }; + _ = std.c.nanosleep(&ts, null); +} + +pub const Pool = struct { + _opts: 
Opts, + _timeout: u64, + _conns: []*Conn, + _available: usize, + _missing: usize, + _allocator: Allocator, + _mutex: PthreadMutex, + _cond: PthreadCondition, + _ssl_ctx: ?*lib.SSLCtx, + _reconnector: Reconnector, + _arena: std.heap.ArenaAllocator, + + pub const Opts = struct { + size: u16 = 10, + auth: Conn.AuthOpts = .{}, + connect: Conn.Opts = .{}, + timeout: u32 = 10 * std.time.ms_per_s, + connect_on_init_count: ?u16 = null, + max_queries_per_conn: u64 = 0, // 0 = unlimited + max_conn_lifetime: i64 = 0, // 0 = unlimited (seconds) + }; + + pub const Stats = struct { + size: usize, + available: usize, + missing: usize, + in_use: usize, + }; + + pub fn initUri(allocator: Allocator, uri: std.Uri, opts: Opts) !*Pool { + var po = try lib.parseOpts(uri, allocator); + defer po.deinit(); + po.opts.size = opts.size; + po.opts.timeout = opts.timeout; + return Pool.init(allocator, po.opts); + } + + pub fn init(allocator: Allocator, opts: Opts) !*Pool { + var arena = std.heap.ArenaAllocator.init(allocator); + const aa = arena.allocator(); + errdefer arena.deinit(); + + const pool = try aa.create(Pool); + const size = opts.size; + const conns = try aa.alloc(*Conn, size); + + var opts_copy = opts; + var ssl_ctx: ?*SSLCtx = null; + if (comptime lib.has_openssl) { + switch (opts.connect.tls) { + .off => {}, + else => |tls_config| { + if (opts.connect.host) |h| { + opts_copy.connect._hostz = try aa.dupeZ(u8, h); + } + ssl_ctx = try lib.initializeSSLContext(tls_config); + }, + } + } + errdefer lib.freeSSLContext(ssl_ctx); + const connect_on_init_count = opts.connect_on_init_count orelse size; + + pool.* = .{ + ._cond = .{}, + ._mutex = .{}, + ._conns = conns, + ._arena = arena, + ._opts = opts_copy, + ._ssl_ctx = ssl_ctx, + ._missing = 0, + ._allocator = allocator, + ._available = connect_on_init_count, + ._reconnector = Reconnector.init(pool), + ._timeout = @as(u64, @intCast(opts.timeout)) * std.time.ns_per_ms, + }; + + var opened_connections: usize = 0; + errdefer { + for 
(0..opened_connections) |i| { + conns[i].deinit(); + } + } + + for (0..connect_on_init_count) |i| { + conns[i] = try newConnection(pool, true); + opened_connections += 1; + } + + const lazy_start_count = size - connect_on_init_count; + pool._missing = lazy_start_count; + for (0..lazy_start_count) |_| { + try pool._reconnector.reconnect(); + } + + return pool; + } + + pub fn deinit(self: *Pool) void { + self._reconnector.stop(); + const allocator = self._allocator; + for (self._conns) |conn| { + conn.deinit(); + allocator.destroy(conn); + } + lib.freeSSLContext(self._ssl_ctx); + self._arena.deinit(); + } + + pub fn acquire(self: *Pool) !*Conn { + const conns = self._conns; + const deadline = nanoTimestamp() + @as(i128, @intCast(self._timeout)); + + self._mutex.lock(); + errdefer self._mutex.unlock(); + + while (true) { + const available = self._available; + const missing = self._missing; + + if (available == 0) { + // Check if pool is completely exhausted + const total_alive = self._conns.len - missing; + if (total_alive == 0) { + return error.PoolExhausted; + } + + lib.metrics.poolEmpty(); + + // Calculate remaining timeout + const now = nanoTimestamp(); + if (now >= deadline) { + return error.Timeout; + } + const remaining_ns: u64 = @intCast(deadline - now); + + try self._cond.timedWait(&self._mutex, remaining_ns); + continue; + } + + const index = available - 1; + const conn = conns[index]; + self._available = index; + self._mutex.unlock(); + return conn; + } + } + + pub fn release(self: *Pool, conn: *Conn) void { + var conn_to_add = conn; + var needs_replace = conn._state != .idle; + + if (!needs_replace) { + // Check if connection should be rotated (query count or age) + const opts = self._opts; + if (opts.max_queries_per_conn > 0 and conn.queryCount() >= opts.max_queries_per_conn) { + needs_replace = true; + } else if (opts.max_conn_lifetime > 0 and conn.age() >= opts.max_conn_lifetime) { + needs_replace = true; + } + } + + if (needs_replace) { + if 
(conn._state != .idle) { + lib.metrics.poolDirty(); + } + conn.deinit(); + self._allocator.destroy(conn); + + conn_to_add = newConnection(self, true) catch |err1| { + self._mutex.lock(); + self._missing += 1; + self._mutex.unlock(); + + self._reconnector.reconnect() catch |err2| { + log.err("Re-opening connection failed ({}) and background reconnector failed to start ({})", .{ err1, err2 }); + }; + return; + }; + } + + var conns = self._conns; + self._mutex.lock(); + const available = self._available; + conns[available] = conn_to_add; + self._available = available + 1; + self._mutex.unlock(); + self._cond.signal(); + } + + pub fn newListener(self: *Pool) !Listener { + var listener = try Listener.open(self._allocator, self._opts.connect); + try listener.auth(self._opts.auth); + return listener; + } + + pub fn stats(self: *Pool) Stats { + self._mutex.lock(); + defer self._mutex.unlock(); + + const available = self._available; + const missing = self._missing; + const size = self._conns.len; + + return .{ + .size = size, + .available = available, + .missing = missing, + .in_use = size - available - missing, + }; + } + + pub fn exec(self: *Pool, sql: []const u8, values: anytype) !?i64 { + return self.execOpts(sql, values, .{}); + } + + pub fn execOpts(self: *Pool, sql: []const u8, values: anytype, opts: Conn.QueryOpts) !?i64 { + var conn = try self.acquire(); + defer self.release(conn); + return conn.execOpts(sql, values, opts); + } + + pub fn execManyDynamic(self: *Pool, sql: []const u8, rows: []const []const DynamicValue, opts: Conn.QueryOpts) !i64 { + var conn = try self.acquire(); + defer self.release(conn); + return conn.execManyDynamic(sql, rows, opts); + } + + pub fn query(self: *Pool, sql: []const u8, values: anytype) !*Result { + return self.queryOpts(sql, values, .{}); + } + + pub fn queryOpts(self: *Pool, sql: []const u8, values: anytype, opts_: Conn.QueryOpts) !*Result { + var opts = opts_; + opts.release_conn = true; + var conn = try self.acquire(); + 
errdefer self.release(conn); + return conn.queryOpts(sql, values, opts); + } + + pub fn row(self: *Pool, sql: []const u8, values: anytype) !?QueryRow { + return self.rowOpts(sql, values, .{}); + } + + pub fn rowUnsafe(self: *Pool, sql: []const u8, values: anytype) !?QueryRowUnsafe { + return self.rowUnsafeOpts(sql, values, .{}); + } + + pub fn rowOpts(self: *Pool, sql: []const u8, values: anytype, opts_: Conn.QueryOpts) !?QueryRow { + var opts = opts_; + opts.release_conn = true; + var conn = try self.acquire(); + return conn.rowOpts(sql, values, opts); + } + + pub fn rowUnsafeOpts(self: *Pool, sql: []const u8, values: anytype, opts_: Conn.QueryOpts) !?QueryRowUnsafe { + var opts = opts_; + opts.release_conn = true; + var conn = try self.acquire(); + return conn.rowUnsafeOpts(sql, values, opts); + } +}; + +const Reconnector = struct { + // number of connections that the pool is missing, i.e. how many need to be + // reconnected + count: usize, + + // when stop is called, this is set to true + stopped: bool, + + pool: *Pool, + mutex: PthreadMutex, + + // the thread, if any, that the monitor is running in + thread: ?Thread, + + fn init(pool: *Pool) Reconnector { + return .{ + .pool = pool, + .count = 0, + .mutex = .{}, + .stopped = false, + .thread = null, + }; + } + + fn run(self: *Reconnector) void { + const pool = self.pool; + const retry_delay = 2 * std.time.ns_per_s; + + self.mutex.lock(); + defer self.mutex.unlock(); + loop: while (self.count > 0) { + const stopped = self.stopped; + self.mutex.unlock(); + if (stopped == true) { + return; + } + + const conn = newConnection(pool, false) catch { + threadSleep(retry_delay); + self.mutex.lock(); + continue :loop; + }; + + // Decrement missing count when successfully recreated + pool._mutex.lock(); + std.debug.assert(pool._missing > 0); + pool._missing -= 1; + pool._mutex.unlock(); + + conn.release(); // inserts it into the pool + self.mutex.lock(); + self.count -= 1; + } + + self.thread.?.detach(); + self.thread = 
null; + } + + fn stop(self: *Reconnector) void { + self.mutex.lock(); + self.stopped = true; + self.mutex.unlock(); + if (self.thread) |thrd| { + thrd.join(); + } + } + + fn reconnect(self: *Reconnector) !void { + self.mutex.lock(); + defer self.mutex.unlock(); + self.count += 1; + if (self.thread == null) { + self.thread = try Thread.spawn(.{ .stack_size = 1024 * 1024 }, Reconnector.run, .{self}); + } + } +}; + +fn newConnection(pool: *Pool, log_failure: bool) !*Conn { + const opts = &pool._opts; + const allocator = pool._allocator; + + const conn = allocator.create(Conn) catch |err| { + if (log_failure) log.err("connect error: {}", .{err}); + return err; + }; + errdefer allocator.destroy(conn); + + conn.* = Conn.open(allocator, opts.connect) catch |err| { + if (log_failure) log.err("connect error: {}", .{err}); + return err; + }; + errdefer conn.deinit(); + + conn.auth(opts.auth) catch |err| { + if (log_failure) { + if (conn.err) |pg_err| { + log.err("connect error: {s}", .{pg_err.message}); + } else { + log.err("connect error: {}", .{err}); + } + } + return err; + }; + conn._pool = pool; + return conn; +} + +const t = lib.testing; +test "Pool" { + var pool = try Pool.init(t.allocator, .{ + .size = 2, + .auth = t.authOpts(.{}), + .connect_on_init_count = 1, + }); + defer pool.deinit(); + + { + const c1 = try pool.acquire(); + defer pool.release(c1); + _ = try c1.exec( + \\ drop table if exists pool_test; + \\ create table pool_test (id int not null) + , .{}); + } + + const t1 = try std.Thread.spawn(.{}, testPool, .{pool}); + const t2 = try std.Thread.spawn(.{}, testPool, .{pool}); + const t3 = try std.Thread.spawn(.{}, testPool, .{pool}); + + t1.join(); + t2.join(); + t3.join(); + + { + const c1 = try pool.acquire(); + defer c1.release(); + + const affected = try c1.exec("delete from pool_test", .{}); + try t.expectEqual(1500, affected.?); + } +} + +test "Pool: Release" { + var pool = try Pool.init(t.allocator, .{ + .size = 2, + .auth = .{ + .database = 
"postgres", + .username = "postgres", + .password = "postgres", + }, + }); + defer pool.deinit(); + + const c1 = try pool.acquire(); + c1._state = .query; + pool.release(c1); +} + +test "Pool: stats" { + var pool = try Pool.init(t.allocator, .{ + .size = 3, + .auth = t.authOpts(.{}), + }); + defer pool.deinit(); + + // Initial state: all connections available + { + const s = pool.stats(); + try t.expectEqual(3, s.size); + try t.expectEqual(3, s.available); + try t.expectEqual(0, s.missing); + try t.expectEqual(0, s.in_use); + } + + // Acquire one connection + const c1 = try pool.acquire(); + { + const s = pool.stats(); + try t.expectEqual(3, s.size); + try t.expectEqual(2, s.available); + try t.expectEqual(0, s.missing); + try t.expectEqual(1, s.in_use); + } + + // Acquire another + const c2 = try pool.acquire(); + { + const s = pool.stats(); + try t.expectEqual(3, s.size); + try t.expectEqual(1, s.available); + try t.expectEqual(0, s.missing); + try t.expectEqual(2, s.in_use); + } + + // Release one + pool.release(c1); + { + const s = pool.stats(); + try t.expectEqual(3, s.size); + try t.expectEqual(2, s.available); + try t.expectEqual(0, s.missing); + try t.expectEqual(1, s.in_use); + } + + // Release the other + pool.release(c2); + { + const s = pool.stats(); + try t.expectEqual(3, s.size); + try t.expectEqual(3, s.available); + try t.expectEqual(0, s.missing); + try t.expectEqual(0, s.in_use); + } +} + +test "Pool: exec" { + var pool = try Pool.init(t.allocator, .{ .size = 1, .auth = t.authOpts(.{}) }); + defer pool.deinit(); + + { + const n = try pool.exec("insert into simple_table values ($1), ($2), ($3)", .{ "pool_insert_args_a", "pool_insert_args_b", "pool_insert_args_c" }); + try t.expectEqual(3, n.?); + } + + { + // this makes sure the connection was returned to the pool + const n = try pool.exec("insert into simple_table values ($1)", .{"pool_insert_args_a"}); + try t.expectEqual(1, n.?); + } +} + +test "Pool: Query/Row" { + var pool = try 
Pool.init(t.allocator, .{ .size = 1, .auth = t.authOpts(.{}) }); + defer pool.deinit(); + + { + _ = try pool.exec("insert into all_types (id, col_int8, col_text) values ($1, $2, $3)", .{ 100, 1, "val-1" }); + _ = try pool.exec("insert into all_types (id, col_int8, col_text) values ($1, $2, $3)", .{ 101, 2, "val-2" }); + } + + for (0..3) |_| { + var result = try pool.query("select col_int8, col_text from all_types where id = any($1)", .{[2]i32{ 100, 101 }}); + defer result.deinit(); + + const row1 = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(1, row1.get(i64, 0)); + try t.expectString("val-1", row1.get([]u8, 1)); + + const row2 = (try result.nextUnsafe()) orelse unreachable; + try t.expectEqual(2, row2.get(i64, 0)); + try t.expectString("val-2", row2.get([]u8, 1)); + + try t.expectEqual(null, result.nextUnsafe()); + } + + for (0..3) |_| { + var row = try pool.rowUnsafe("select col_int8, col_text from all_types where id = $1", .{101}) orelse unreachable; + defer row.deinit() catch {}; + + try t.expectEqual(2, row.get(i64, 0)); + try t.expectString("val-2", row.get([]u8, 1)); + } +} + +test "Pool: Row error" { + var pool = try Pool.init(t.allocator, .{ .size = 1, .auth = t.authOpts(.{}) }); + defer pool.deinit(); + + _ = try pool.rowUnsafe("insert into all_types (id) values ($1)", .{200}); + + // This would segfault: + // https://github.com/karlseguin/pg.zig/issues/34 + try t.expectError(error.PG, pool.rowUnsafe("insert into all_types (id) values ($1)", .{200})); + + try t.expectEqual(1, pool._available); +} + +fn testPool(p: *Pool) void { + for (0..500) |i| { + const conn = p.acquire() catch unreachable; + _ = conn.exec("insert into pool_test (id) values ($1)", .{i}) catch unreachable; + conn.release(); + } +} diff --git a/zig/pg/src/proto.zig b/zig/pg/src/proto.zig new file mode 100644 index 0000000..b808a37 --- /dev/null +++ b/zig/pg/src/proto.zig @@ -0,0 +1,19 @@ +pub const AuthenticationRequest = 
@import("proto/authentication_request.zig").AuthenticationRequest; +pub const AuthenticationSASLContinue = @import("proto/AuthenticationSASLContinue.zig"); +pub const AuthenticationSASLFinal = @import("proto/AuthenticationSASLFinal.zig"); +pub const CommandComplete = @import("proto/CommandComplete.zig"); +pub const Describe = @import("proto/Describe.zig"); +pub const Error = @import("proto/Error.zig"); +pub const Execute = @import("proto/Execute.zig"); +pub const NotificationResponse = @import("proto/NotificationResponse.zig"); +pub const Parse = @import("proto/Parse.zig"); +pub const PasswordMessage = @import("proto/PasswordMessage.zig"); +pub const Query = @import("proto/Query.zig"); +pub const SASLInitialResponse = @import("proto/SASLInitialResponse.zig"); +pub const SASLResponse = @import("proto/SASLResponse.zig"); +pub const StartupMessage = @import("proto/StartupMessage.zig"); +pub const Sync = @import("proto/Sync.zig"); + +test { + @import("std").testing.refAllDecls(@This()); +} diff --git a/zig/pg/src/proto/AuthenticationSASLContinue.zig b/zig/pg/src/proto/AuthenticationSASLContinue.zig new file mode 100644 index 0000000..dc77131 --- /dev/null +++ b/zig/pg/src/proto/AuthenticationSASLContinue.zig @@ -0,0 +1,51 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); +const Reader = proto.Reader; + +// #2 - Server responds with this +const AuthenticationSASLContinue = @This(); + +data: []const u8, + +pub fn parse(data: []const u8) !AuthenticationSASLContinue { + var reader = Reader.init(data); + + if (try reader.int32() != 11) { + return error.NotSASLChallenge; + } + + return .{ + .data = reader.rest(), + }; +} + +const t = proto.testing; +test "AuthenticationSASLContinue: parse" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + { + // too short + try t.expectError(error.NoMoreData, AuthenticationSASLContinue.parse(buf.string())); + + try buf.write("123"); + try t.expectError(error.NoMoreData, 
AuthenticationSASLContinue.parse(buf.string())); + } + + { + // wrong special sasl type + buf.reset(); + try buf.writeIntBig(u32, 12); + try t.expectError(error.NotSASLChallenge, AuthenticationSASLContinue.parse(buf.string())); + } + + { + // success + buf.reset(); + try buf.writeIntBig(u32, 11); + try buf.write("r=a-nounce,s=the-S@lt,i=4096"); + + const c = try AuthenticationSASLContinue.parse(buf.string()); + try t.expectString("r=a-nounce,s=the-S@lt,i=4096", c.data); + } +} diff --git a/zig/pg/src/proto/AuthenticationSASLFinal.zig b/zig/pg/src/proto/AuthenticationSASLFinal.zig new file mode 100644 index 0000000..8d72069 --- /dev/null +++ b/zig/pg/src/proto/AuthenticationSASLFinal.zig @@ -0,0 +1,54 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const Reader = proto.Reader; + +// #4 - Server finalizes with this +const AuthenticationSASLFinal = @This(); + +data: []const u8, + +pub fn parse(data: []const u8) !AuthenticationSASLFinal { + var reader = Reader.init(data); + + if (try reader.int32() != 12) { + return error.NotSASLChallenge; + } + + // can't really parse this, since it technically depends on the SASL + // mechanism in use + return .{ + .data = reader.rest(), + }; +} + +const t = proto.testing; +test "AuthenticationSASLFinal: parse" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + { + // too short + try t.expectError(error.NoMoreData, AuthenticationSASLFinal.parse(buf.string())); + + try buf.write("123"); + try t.expectError(error.NoMoreData, AuthenticationSASLFinal.parse(buf.string())); + } + + { + // wrong special sasl type + buf.reset(); + try buf.writeIntBig(u32, 13); + try t.expectError(error.NotSASLChallenge, AuthenticationSASLFinal.parse(buf.string())); + } + + { + // success + buf.reset(); + try buf.writeIntBig(u32, 12); + try buf.write("some server data"); + + const final = try AuthenticationSASLFinal.parse(buf.string()); + try t.expectString("some server data", final.data); + } +} diff 
--git a/zig/pg/src/proto/CommandComplete.zig b/zig/pg/src/proto/CommandComplete.zig new file mode 100644 index 0000000..dcfd9dc --- /dev/null +++ b/zig/pg/src/proto/CommandComplete.zig @@ -0,0 +1,77 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); +const Reader = proto.Reader; + +const CommandComplete = @This(); + +tag: []const u8, + +pub fn parse(data: []const u8) !CommandComplete { + var reader = Reader.init(data); + return .{ + .tag = try reader.restAsString(), + }; +} + +// Finds the last number in the tag of the command completed. If no number is +// found, then 0 rows were affected. Commands like "create table" or "create +// role" have a tag that's just "CREATE XYZ". +// But update/delete/select/... have something like "delete #" +// "insert" is a bit more complicated, but the rows inserted is the last number +// so this works for it too. +pub fn rowsAffected(self: CommandComplete) ?i64 { + const tag = self.tag; + const end = tag.len; + var i: usize = end; + while (i > 0) : (i -= 1) { + const b = tag[i - 1]; + if (b < '0' or b > '9') { + break; + } + } + + if (i == end) { + return null; + } + + return std.fmt.parseInt(i64, tag[i..], 10) catch unreachable; +} + +const t = proto.testing; +test "CommandComplete: parse" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + { + // not a string (not null terminated) + try buf.write("123"); + try t.expectError(error.NotAString, CommandComplete.parse(buf.string())); + } + + { + // success + buf.reset(); + try buf.write("CREATE ROLE"); + try buf.writeByte(0); + + const c = try CommandComplete.parse(buf.string()); + try t.expectString("CREATE ROLE", c.tag); + } +} + +test "CommandComplete: rowsAffected" { + { + const c = CommandComplete{ .tag = "DROP ROLE" }; + try t.expectEqual(null, c.rowsAffected()); + } + + { + const c = CommandComplete{ .tag = "INSERT 392 1" }; + try t.expectEqual(1, c.rowsAffected()); + } + + { + const c = CommandComplete{ .tag = "DELETE 
9392" }; + try t.expectEqual(9392, c.rowsAffected()); + } +} diff --git a/zig/pg/src/proto/Describe.zig b/zig/pg/src/proto/Describe.zig new file mode 100644 index 0000000..cec2db3 --- /dev/null +++ b/zig/pg/src/proto/Describe.zig @@ -0,0 +1,64 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const Describe = @This(); + +name: []const u8 = "", +type: Type = .portal, + +pub const Type = enum { + portal, + prepared_statement, +}; + +pub fn write(self: Describe, buf: *proto.Buffer) !void { + // 4 + 1 + N + 1 + // len + $type + $name + 0 + const payload_len = 6 + self.name.len; + + // +1 for the type field, 'D' + const total_length = payload_len + 1; + + try buf.ensureTotalCapacity(total_length); + + var view = buf.skip(total_length) catch unreachable; + view.writeByte('D'); + view.writeIntBig(u32, @intCast(payload_len)); + view.writeByte(switch (self.type) { + .portal => 'P', + .prepared_statement => 'S', + }); + + view.write(self.name); + view.writeByte(0); +} + +const t = proto.testing; +const Reader = proto.Reader; +test "Describe: write portal no name" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const p = Describe{}; + try p.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('D', try reader.byte()); + try t.expectEqual(6, try reader.int32()); // payload length + try t.expectEqual('P', try reader.byte()); + try t.expectString("", try reader.restAsString()); +} + +test "Describe: write prepared statement with name" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const p = Describe{ .type = .prepared_statement, .name = "the-name" }; + try p.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('D', try reader.byte()); + try t.expectEqual(14, try reader.int32()); // payload length + try t.expectEqual('S', try reader.byte()); + try t.expectString("the-name", try reader.restAsString()); +} diff --git a/zig/pg/src/proto/Error.zig 
b/zig/pg/src/proto/Error.zig new file mode 100644 index 0000000..bb2a93f --- /dev/null +++ b/zig/pg/src/proto/Error.zig @@ -0,0 +1,127 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const Error = @This(); + +code: []const u8, +message: []const u8, +severity: []const u8, + +column: ?[]const u8 = null, +constraint: ?[]const u8 = null, +data_type_name: ?[]const u8 = null, +detail: ?[]const u8 = null, +file: ?[]const u8 = null, +hint: ?[]const u8 = null, +internal_position: ?[]const u8 = null, +internal_query: ?[]const u8 = null, +line: ?[]const u8 = null, +position: ?[]const u8 = null, +routine: ?[]const u8 = null, +schema: ?[]const u8 = null, +severity2: ?[]const u8 = null, +table: ?[]const u8 = null, +where: ?[]const u8 = null, + +pub fn isUnique(self: Error) bool { + return std.mem.eql(u8, self.code, "23505"); +} + +pub fn parse(data: []const u8) Error { + var err = Error{ + .code = "", + .message = "", + .severity = "", + }; + + var pos: usize = 0; + while (pos < data.len) { + const value_end = std.mem.indexOfScalarPos(u8, data, pos + 1, 0) orelse { + // TODO: should not happen + break; + }; + + const value = data[pos + 1 .. 
value_end]; + switch (data[pos]) { + 'S' => err.severity = value, + 'V' => err.severity2 = value, + 'C' => err.code = value, + 'M' => err.message = value, + 'D' => err.detail = value, + 'H' => err.hint = value, + 'P' => err.position = value, + 'p' => err.internal_position = value, + 'q' => err.internal_query = value, + 'W' => err.where = value, + 's' => err.schema = value, + 't' => err.table = value, + 'c' => err.column = value, + 'd' => err.data_type_name = value, + 'n' => err.constraint = value, + 'F' => err.file = value, + 'L' => err.line = value, + 'R' => err.routine = value, + else => unreachable, + } + pos = value_end + 1; + } + + return err; +} + +const t = proto.testing; +test "Error: parse" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + { + // only required + try buf.writeByte('C'); + try buf.write("10391A"); + try buf.writeByte(0); + + try buf.writeByte('M'); + try buf.write("The Message"); + try buf.writeByte(0); + + try buf.writeByte('S'); + try buf.write("FATAL"); + try buf.writeByte(0); + + const err = Error.parse(buf.string()); + try t.expectString("10391A", err.code); + try t.expectString("The Message", err.message); + try t.expectString("FATAL", err.severity); + } + + { + // all fields + const fields = [_]u8{ 'S', 'V', 'C', 'M', 'D', 'H', 'P', 'p', 'q', 'W', 's', 't', 'c', 'd', 'n', 'F', 'L', 'R' }; + for (fields) |field| { + try buf.writeByte(field); + try buf.writeByte(field); + try buf.write("-value"); + try buf.writeByte(0); + } + + const err = Error.parse(buf.string()); + try t.expectString("C-value", err.code); + try t.expectString("M-value", err.message); + try t.expectString("S-value", err.severity); + try t.expectString("V-value", err.severity2.?); + try t.expectString("D-value", err.detail.?); + try t.expectString("H-value", err.hint.?); + try t.expectString("P-value", err.position.?); + try t.expectString("p-value", err.internal_position.?); + try t.expectString("q-value", err.internal_query.?); + try 
t.expectString("W-value", err.where.?); + try t.expectString("s-value", err.schema.?); + try t.expectString("t-value", err.table.?); + try t.expectString("c-value", err.column.?); + try t.expectString("d-value", err.data_type_name.?); + try t.expectString("n-value", err.constraint.?); + try t.expectString("F-value", err.file.?); + try t.expectString("L-value", err.line.?); + try t.expectString("R-value", err.routine.?); + } +} diff --git a/zig/pg/src/proto/Execute.zig b/zig/pg/src/proto/Execute.zig new file mode 100644 index 0000000..a32f8fc --- /dev/null +++ b/zig/pg/src/proto/Execute.zig @@ -0,0 +1,59 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const Execute = @This(); + +portal: []const u8 = "", +// 0 == no limit, don't use a nullable, since that would imply +// that 0 means something else. +max_rows: u32 = 0, + +pub fn write(self: Execute, buf: *proto.Buffer) !void { + // 4 + N + 1 + 4 + // len + $portal + 0 + $max_rows + const payload_len = 9 + self.portal.len; + + // +1 for the type field, 'E' + const total_length = payload_len + 1; + + try buf.ensureTotalCapacity(total_length); + + var view = buf.skip(total_length) catch unreachable; + view.writeByte('E'); + view.writeIntBig(u32, @intCast(payload_len)); + view.write(self.portal); + view.writeByte(0); + view.writeIntBig(u32, self.max_rows); +} + +const t = proto.testing; +const Reader = proto.Reader; +test "Execute: write no name" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const e = Execute{}; + try e.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('E', try reader.byte()); + try t.expectEqual(9, try reader.int32()); // payload length + try t.expectString("", try reader.string()); + try t.expectEqual(0, try reader.int32()); + try t.expectString("", reader.rest()); +} + +test "Execute: write with name" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const p = Execute{ .portal = "a 
name", .max_rows = 500 }; + try p.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('E', try reader.byte()); + try t.expectEqual(15, try reader.int32()); // payload length + try t.expectString("a name", try reader.string()); + try t.expectEqual(500, try reader.int32()); + try t.expectString("", reader.rest()); +} diff --git a/zig/pg/src/proto/NotificationResponse.zig b/zig/pg/src/proto/NotificationResponse.zig new file mode 100644 index 0000000..55b2bfc --- /dev/null +++ b/zig/pg/src/proto/NotificationResponse.zig @@ -0,0 +1,32 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const Reader = proto.Reader; + +const NotificationResponse = @This(); + +process_id: u32, +channel: []const u8, +payload: []const u8, + +pub fn parse(data: []const u8) !NotificationResponse { + var reader = Reader.init(data); + return .{ .process_id = try reader.int32(), .channel = try reader.string(), .payload = try reader.string() }; +} + +const t = proto.testing; +test "NotificationResponse: parse" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + try buf.writeIntBig(u32, 912); + try buf.write("chan-1"); + try buf.writeByte(0); + try buf.write("payload-2"); + try buf.writeByte(0); + + const nr = try NotificationResponse.parse(buf.string()); + try t.expectEqual(912, nr.process_id); + try t.expectString("chan-1", nr.channel); + try t.expectString("payload-2", nr.payload); +} diff --git a/zig/pg/src/proto/Parse.zig b/zig/pg/src/proto/Parse.zig new file mode 100644 index 0000000..3c4b157 --- /dev/null +++ b/zig/pg/src/proto/Parse.zig @@ -0,0 +1,62 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const Parse = @This(); + +prepared_statement: []const u8 = "", +sql: []const u8, + +pub fn write(self: Parse, buf: *proto.Buffer) !void { + // 4 + N + 1 + S + 1 + 2 + // len + $name + 0 + $sql + 0 + u16(0) + const payload_len = 8 + self.prepared_statement.len + self.sql.len; + + // +1 for the type 
field, 'P' + const total_length = payload_len + 1; + + try buf.ensureTotalCapacity(total_length); + + var view = buf.skip(total_length) catch unreachable; + view.writeByte('P'); + view.writeIntBig(u32, @intCast(payload_len)); + view.write(self.prepared_statement); + view.writeByte(0); + view.write(self.sql); + view.writeByte(0); + // this is the # of parameters types that we plan on describing + view.write(&.{ 0, 0 }); // 0 as a u16 +} + +const t = proto.testing; +const Reader = proto.Reader; +test "Parse: write no name" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const p = Parse{ .sql = "select 1" }; + try p.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('P', try reader.byte()); + try t.expectEqual(16, try reader.int32()); // payload length + try t.expectString("", try reader.string()); + try t.expectString("select 1", try reader.string()); + try t.expectEqual(0, try reader.int16()); + try t.expectString("", reader.rest()); +} + +test "Parse: write with name" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const p = Parse{ .sql = "select 1", .prepared_statement = "a name" }; + try p.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('P', try reader.byte()); + try t.expectEqual(22, try reader.int32()); // payload length + try t.expectString("a name", try reader.string()); + try t.expectString("select 1", try reader.string()); + try t.expectEqual(0, try reader.int16()); + try t.expectString("", reader.rest()); +} diff --git a/zig/pg/src/proto/PasswordMessage.zig b/zig/pg/src/proto/PasswordMessage.zig new file mode 100644 index 0000000..1ab2cce --- /dev/null +++ b/zig/pg/src/proto/PasswordMessage.zig @@ -0,0 +1,38 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const PasswordMessage = @This(); + +password: []const u8, + +pub fn write(self: PasswordMessage, buf: *proto.Buffer) !void { + // +4 since the payload length 
includes the length itself + // +1 for null terminated string + const payload_len = self.password.len + 5; + + // +1 for the type field, 'p' + const total_length = payload_len + 1; + + try buf.ensureTotalCapacity(total_length); + + var view = buf.skip(total_length) catch unreachable; + view.writeByte('p'); + view.writeIntBig(u32, @intCast(payload_len)); + view.write(self.password); + view.writeByte(0); +} + +const t = proto.testing; +const Reader = proto.Reader; +test "PasswordMessage: write" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const pw = PasswordMessage{ .password = "gh@nim@" }; + try pw.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('p', try reader.byte()); + try t.expectEqual(12, try reader.int32()); // payload length + try t.expectString("gh@nim@", try reader.string()); +} diff --git a/zig/pg/src/proto/Query.zig b/zig/pg/src/proto/Query.zig new file mode 100644 index 0000000..cf7c24e --- /dev/null +++ b/zig/pg/src/proto/Query.zig @@ -0,0 +1,38 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const Query = @This(); + +sql: []const u8, + +pub fn write(self: Query, buf: *proto.Buffer) !void { + // 4 + S + 1 + // len + $sql + 0 + const payload_len = 5 + self.sql.len; + + // +1 for the type field, 'Q' + const total_length = payload_len + 1; + + try buf.ensureTotalCapacity(total_length); + + var view = buf.skip(total_length) catch unreachable; + view.writeByte('Q'); + view.writeIntBig(u32, @intCast(payload_len)); + view.write(self.sql); + view.writeByte(0); +} + +const t = proto.testing; +const Reader = proto.Reader; +test "Query: write" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const q = Query{ .sql = "select 1" }; + try q.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('Q', try reader.byte()); + try t.expectEqual(13, try reader.int32()); // payload length + try t.expectString("select 1", try 
reader.restAsString()); +} diff --git a/zig/pg/src/proto/SASLInitialResponse.zig b/zig/pg/src/proto/SASLInitialResponse.zig new file mode 100644 index 0000000..b551c15 --- /dev/null +++ b/zig/pg/src/proto/SASLInitialResponse.zig @@ -0,0 +1,48 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +//Client sends this based on getting a request.sasl +const SASLInitialResponse = @This(); + +response: []const u8, +mechanism: []const u8, + +pub fn write(self: SASLInitialResponse, buf: *proto.Buffer) !void { + // 4 + M + 1 + 4 + R + // len + $mechanism + 0 + $response.len + $response + const payload_len = 9 + self.mechanism.len + self.response.len; + + // + 1 for the leading 'p' + const total_length = payload_len + 1; + try buf.ensureTotalCapacity(total_length); + + // this nonsense is to skip the buffers bound checking, since we've already + // ensured the available capacity + var view = buf.skip(total_length) catch unreachable; + view.writeByte('p'); + view.writeIntBig(u32, @intCast(payload_len)); + view.write(self.mechanism); + view.writeByte(0); + view.writeIntBig(u32, @intCast(self.response.len)); + view.write(self.response); +} + +const t = proto.testing; +const Reader = proto.Reader; +test "SASLInitialResponse: write" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const s = SASLInitialResponse{ + .mechanism = "SCRAM-SHA-256", + .response = "a sasl response", + }; + try s.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('p', try reader.byte()); + try t.expectEqual(37, try reader.int32()); // payload length + try t.expectString("SCRAM-SHA-256", try reader.string()); + try t.expectEqual(15, try reader.int32()); // length of response + try t.expectString("a sasl response", reader.rest()); +} diff --git a/zig/pg/src/proto/SASLResponse.zig b/zig/pg/src/proto/SASLResponse.zig new file mode 100644 index 0000000..9901350 --- /dev/null +++ b/zig/pg/src/proto/SASLResponse.zig @@ -0,0 +1,41 @@ 
+const std = @import("std"); +const proto = @import("_proto.zig"); +const Reader = proto.Reader; + +// #3 - Client finalizes with this +const SASLResponse = @This(); + +data: []const u8, + +pub fn write(self: SASLResponse, buf: *proto.Buffer) !void { + // 4 + N + // len + $data + const payload_len = 4 + self.data.len; + + // + 1 for the leading 'p' + const total_length = payload_len + 1; + try buf.ensureTotalCapacity(total_length); + + // this nonsense is to skip the buffers bound checking, since we've already + // ensured the available capacity + var view = buf.skip(total_length) catch unreachable; + view.writeByte('p'); + view.writeIntBig(u32, @intCast(payload_len)); + view.write(self.data); +} + +const t = proto.testing; +test "SASLResponse: write" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const s = SASLResponse{ + .data = "the response", + }; + try s.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('p', try reader.byte()); + try t.expectEqual(16, try reader.int32()); // payload length + try t.expectString("the response", reader.rest()); +} diff --git a/zig/pg/src/proto/StartupMessage.zig b/zig/pg/src/proto/StartupMessage.zig new file mode 100644 index 0000000..11f15df --- /dev/null +++ b/zig/pg/src/proto/StartupMessage.zig @@ -0,0 +1,76 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); + +const StartupMessage = @This(); + +protocol: []const u8 = &[_]u8{ 0, 3, 0, 0 }, +username: []const u8, +database: []const u8, +application_name: ?[]const u8 = null, +params: ?std.StringHashMap([]const u8) = null, + +pub fn write(self: StartupMessage, buf: *proto.Buffer) !void { + // 4 + 4 + 4 + 1 + N + 1 + 8 + 1 + M + 1 + 1 = 25 + N + M + // len + protocol + "user" + 0 + $username + 0 "database" + 0 + $database + 0 + 0 + 0 + var payload_len = 25 + self.username.len + self.database.len; + if (self.params) |p| { + var it = p.iterator(); + while (it.next()) |kv| { + // +2 because both key and 
value are null-terminated + payload_len += kv.key_ptr.len + kv.value_ptr.len + 2; + } + } + if (self.application_name) |an| { + // +2 because both key and value are null-terminated + payload_len += "application_name".len + an.len + 2; + } + + try buf.ensureTotalCapacity(payload_len); + + // this nonsense is to skip the buffers bound checking, since we've already + // ensured the available capacity + var view = buf.skip(payload_len) catch unreachable; + view.writeIntBig(u32, @intCast(payload_len)); + view.write(self.protocol); + view.write(&[_]u8{ 'u', 's', 'e', 'r', 0 }); + view.write(self.username); + view.writeByte(0); + view.write(&[_]u8{ 'd', 'a', 't', 'a', 'b', 'a', 's', 'e', 0 }); + view.write(self.database); + view.writeByte(0); + if (self.application_name) |an| { + view.write("application_name"); + view.writeByte(0); + view.write(an); + view.writeByte(0); + } + if (self.params) |p| { + var it = p.iterator(); + while (it.next()) |kv| { + view.write(kv.key_ptr.*); + view.writeByte(0); + view.write(kv.value_ptr.*); + view.writeByte(0); + } + } + view.writeByte(0); +} + +const t = proto.testing; +const Reader = proto.Reader; +test "StartupMessage: write" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const s = StartupMessage{ .username = "leto", .database = "ghanima" }; + try s.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual(36, try reader.int32()); // payload length + try t.expectEqual(196608, try reader.int32()); // protocol version + try t.expectString("user", try reader.string()); + try t.expectString("leto", try reader.string()); + try t.expectString("database", try reader.string()); + try t.expectString("ghanima", try reader.string()); + try t.expectSlice(u8, &.{0}, reader.rest()); +} diff --git a/zig/pg/src/proto/Sync.zig b/zig/pg/src/proto/Sync.zig new file mode 100644 index 0000000..5287f5a --- /dev/null +++ b/zig/pg/src/proto/Sync.zig @@ -0,0 +1,24 @@ +const std = @import("std"); +const 
proto = @import("_proto.zig"); + +const Sync = @This(); +pub fn write(_: Sync, buf: *proto.Buffer) !void { + try buf.ensureTotalCapacity(5); + var view = buf.skip(5) catch unreachable; + view.write(&.{ 'S', 0, 0, 0, 4 }); +} + +const t = proto.testing; +const Reader = proto.Reader; +test "Sync: write" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + const s = Sync{}; + try s.write(&buf); + + var reader = Reader.init(buf.string()); + try t.expectEqual('S', try reader.byte()); + try t.expectEqual(4, try reader.int32()); // payload length + try t.expectString("", reader.rest()); +} diff --git a/zig/pg/src/proto/_proto.zig b/zig/pg/src/proto/_proto.zig new file mode 100644 index 0000000..bf0314a --- /dev/null +++ b/zig/pg/src/proto/_proto.zig @@ -0,0 +1,85 @@ +// Used internally by files in this folder like a "utils". Underscore filename +// because I like having this clearly separate in the file view. Would be nice +// if editors let you override file ordering. 
+const std = @import("std"); + +pub const Buffer = @import("buffer").Buffer; +pub const testing = @import("../lib.zig").testing; + +pub const Reader = struct { + pos: usize, + data: []const u8, + + pub fn init(data: []const u8) Reader { + return .{ + .pos = 0, + .data = data, + }; + } + + pub fn byte(self: *Reader) !u8 { + if (!self.hasAtLeast(1)) { + return error.NoMoreData; + } + const pos = self.pos; + const value = self.data[pos]; + self.pos = pos + 1; + return value; + } + + pub fn optionalString(self: *Reader) ?[]const u8 { + const pos = self.pos; + const data = self.data; + const index = std.mem.indexOfScalarPos(u8, data, pos, 0) orelse return null; + + const value = data[pos..index]; + self.pos = index + 1; // +1 to consume the null terminator + return value; + } + + pub fn string(self: *Reader) ![]const u8 { + if (!self.hasAtLeast(1)) { + return error.NoMoreData; + } + return self.optionalString() orelse return error.NotAString; + } + + pub fn int16(self: *Reader) !u16 { + if (!self.hasAtLeast(2)) { + return error.NoMoreData; + } + const pos = self.pos; + const end = pos + 2; + const value = std.mem.readInt(u16, self.data[pos..end][0..2], .big); + self.pos = end; + return value; + } + + pub fn int32(self: *Reader) !u32 { + if (!self.hasAtLeast(4)) { + return error.NoMoreData; + } + const pos = self.pos; + const end = pos + 4; + const value = std.mem.readInt(u32, self.data[pos..end][0..4], .big); + self.pos = end; + return value; + } + + // does not consume + pub fn rest(self: *Reader) []const u8 { + return self.data[self.pos..]; + } + + pub fn restAsString(self: *Reader) ![]const u8 { + const r = self.data[self.pos..]; + if (r[r.len - 1] != 0) { + return error.NotAString; + } + return r[0 .. 
r.len - 1]; + } + + pub fn hasAtLeast(self: Reader, n: usize) bool { + return self.pos + n <= self.data.len; + } +}; diff --git a/zig/pg/src/proto/authentication_request.zig b/zig/pg/src/proto/authentication_request.zig new file mode 100644 index 0000000..4b35838 --- /dev/null +++ b/zig/pg/src/proto/authentication_request.zig @@ -0,0 +1,143 @@ +const std = @import("std"); +const proto = @import("_proto.zig"); +const Reader = proto.Reader; + +// The server making an authentication request to the client. This is in response +// to a Startup message sent from the client. In my mind, this is really a +// "Response", but the documentation calls it a "Request" and, from the point of +// view of the server, that's what it is. +pub const AuthenticationRequest = union(enum) { + ok: void, + password: void, + md5: []const u8, + sasl: SASL, + + pub const SASL = struct { + scram_sha_256: bool = false, + scram_sha_256_plus: bool = false, + }; + + pub fn parse(data: []const u8) !AuthenticationRequest { + var reader = Reader.init(data); + + switch (try reader.int32()) { + 0 => return .{ .ok = {} }, // authentication ok + 3 => return .{ .password = {} }, // authentication requires a plain-text password + 5 => { + // 4 + 4 + 4 + // payload_len + 5 + $salt + if (data.len != 8) { + return error.NoMoreData; + } + return .{ .md5 = reader.rest() }; + }, + 10 => { + var sasl = SASL{}; + while (reader.optionalString()) |auth_mechanism| { + if (std.ascii.eqlIgnoreCase(auth_mechanism, "SCRAM-SHA-256")) { + sasl.scram_sha_256 = true; + } else if (std.ascii.eqlIgnoreCase(auth_mechanism, "SCRAM-SHA-256-PLUS")) { + sasl.scram_sha_256_plus = true; + } + } + return .{ .sasl = sasl }; + }, + else => return error.AuthNotSupported, + } + } +}; + +const t = proto.testing; +test "AuthenticationRequest: invalid" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + { + // empty + try t.expectError(error.NoMoreData, AuthenticationRequest.parse(buf.string())); + } + + { + // 
less than minimum length + try buf.write("123"); + try t.expectError(error.NoMoreData, AuthenticationRequest.parse(buf.string())); + } + + { + // unknown auth type + buf.reset(); + try buf.writeIntBig(u32, 99); + try t.expectError(error.AuthNotSupported, AuthenticationRequest.parse(buf.string())); + } +} + +test "AuthenticationRequest: ok" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + try buf.writeIntBig(u32, 0); + const request = try AuthenticationRequest.parse(buf.string()); + try t.expectEqual({}, request.ok); +} + +test "AuthenticationRequest: password" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + try buf.writeIntBig(u32, 3); + const request = try AuthenticationRequest.parse(buf.string()); + try t.expectEqual({}, request.password); +} + +test "AuthenticationRequest: md5" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + try buf.writeIntBig(u32, 5); + try buf.write("s@Lt"); + const request = try AuthenticationRequest.parse(buf.string()); + try t.expectString("s@Lt", request.md5); +} + +test "AuthenticationRequest: sasl with 1 mechanism" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + { + try buf.writeIntBig(u32, 10); + try buf.write("SCRAM-SHA-256"); + try buf.writeByte(0); + + const request = try AuthenticationRequest.parse(buf.string()); + try t.expectEqual(true, request.sasl.scram_sha_256); + try t.expectEqual(false, request.sasl.scram_sha_256_plus); + } + + { + buf.reset(); + try buf.writeIntBig(u32, 10); + try buf.write("SCRAM-SHA-256-PLUS"); + try buf.writeByte(0); + + const request = try AuthenticationRequest.parse(buf.string()); + try t.expectEqual(false, request.sasl.scram_sha_256); + try t.expectEqual(true, request.sasl.scram_sha_256_plus); + } +} + +test "AuthenticationRequest: sasl with multiple including unknown" { + var buf = try proto.Buffer.init(t.allocator, 128); + defer buf.deinit(); + + try buf.writeIntBig(u32, 
10); + try buf.write("SCRAM-SHA-256-PLUS"); + try buf.writeByte(0); + try buf.write("SCRAM-SHA-256"); + try buf.writeByte(0); + try buf.write("SCRAM-MD5"); + try buf.writeByte(0); + + const request = try AuthenticationRequest.parse(buf.string()); + try t.expectEqual(true, request.sasl.scram_sha_256); + try t.expectEqual(true, request.sasl.scram_sha_256_plus); +} diff --git a/zig/pg/src/reader.zig b/zig/pg/src/reader.zig new file mode 100644 index 0000000..0917dc9 --- /dev/null +++ b/zig/pg/src/reader.zig @@ -0,0 +1,727 @@ +const std = @import("std"); +const lib = @import("lib.zig"); +const builtin = @import("builtin"); + +const posix = std.posix; +const Conn = lib.Conn; +const Allocator = std.mem.Allocator; + +// to everyone else, this is our reader +pub const Reader = ReaderT(lib.Stream); + +const zero_timeval = std.mem.toBytes(posix.timeval{ .sec = 0, .usec = 0 }); + +// generic just for testing within this file +fn ReaderT(comptime T: type) type { + return struct { + // Whether or not we've put a timeout on the request. This helps avoid + // system calls when no timeout is set. + has_timeout: bool, + + // Provided when the reader was allocated (which is the allocator given + // when the connection/pool was created). Owns `static` and unless a query- + // specific allocator is provided, will be used for any dynamic allocations. + default_allocator: Allocator, + + // Current active allocator. This will normally reference `default_allocator` + // but a query can provide a specific allocator to use for the processing + // of said query. (via startFlow) + allocator: Allocator, + + // Exists for the lifetime of the reader, but normally references this, but + // for messages that don't fit, we'll allocate memory dynamically and + // eventually revert back to buf. 
+ static: []u8, + + // buffer to read into + buf: []u8, + + // start within buf of the next message + start: usize = 0, + + // position in buf that we have valid data up to + pos: usize = 0, + + stream: T, + + const Self = @This(); + + pub fn init(allocator: Allocator, size: usize, stream: T) !Self { + const static = try allocator.alloc(u8, size); + return .{ + .buf = static, + .stream = stream, + .static = static, + .has_timeout = false, + .allocator = allocator, + .default_allocator = allocator, + }; + } + + pub fn deinit(self: Self) void { + if (self.static.ptr != self.buf.ptr) { + self.allocator.free(self.buf); + } + self.default_allocator.free(self.static); + } + + // Between a call to startFlow and endFlow, the reader can re-use any + // dynamic buffer it creates. The idea behind this is that if reading 1 row + // requires more than static.len, other rows within the same result might + // as well. + pub fn startFlow(self: *Self, allocator: ?Allocator, timeout_ms: ?u32) !void { + if (timeout_ms) |ms| { + const timeval = std.mem.toBytes(posix.timeval{ + .sec = @intCast(@divTrunc(ms, 1000)), + .usec = @intCast(@mod(ms, 1000) * 1000), + }); + try posix.setsockopt(self.stream.socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &timeval); + self.has_timeout = true; + } else if (self.has_timeout) { + try posix.setsockopt(self.stream.socket, posix.SOL.SOCKET, posix.SO.RCVTIMEO, &zero_timeval); + self.has_timeout = false; + } + + self.allocator = allocator orelse self.default_allocator; + } + + pub fn endFlow(self: *Self) !void { + const buf = self.buf; + const allocator = self.allocator; + + self.allocator = self.default_allocator; + if (self.static.ptr == buf.ptr) { + // we never created a dynamic buffer + return; + } + + // even if the following fails, we want to free this + defer allocator.free(buf); + + // Normally, when a "flow" ends, we expect our read buffer to be empty. + // This is true because data from PG is normally only sent in response + // to a request. 
If we've ended our "flow", then we should have read + // everything from PG. But PG can occasionally send data on its own. + // So it's possible that we over-read and now our dynamic buffer has + // data unrelated to this flow. + const pos = self.pos; + const start = self.start; + const extra = pos - start; + + var new_buf: []u8 = undefined; + if (extra > self.static.len) { + // This is unusual. Not only did we overread, but we've overread so + // much that we can't use our static buffer. + + const default_allocator = self.default_allocator; + if (allocator.ptr == default_allocator.ptr) { + // The dynamic buffer was allocated with our default allocator, so + // we can keep it as-is + return; + } + + // This is the worst. We have extra data in our dynamic buffer AND + // we have a query-specific allocator. This data _cannot_ remain + // where it is (because we have no guarantee that the allocator is valid + // beyond this query). + // So we'll copy it to a new buffer using our default allocator. + new_buf = try default_allocator.dupe(u8, self.buf[start..pos]); + } else { + // We either have no extra data, or we have extra data, but it fits in + // our static buffer. Either way, we're reverting self.buf to self.static; + new_buf = self.static; + if (extra > 0) { + // We read extra data, copy this into our static buffer + @memcpy(new_buf[0..extra], self.buf[start..pos]); + } + } + + self.pos = extra; + self.start = 0; + self.buf = new_buf; + } + + // If you execute "select * from invalid_table", PostgreSQL will return + // an error early in the process of preparing the statement - as part + // of parsing the statement, it knows that "invalid_table" isn't a valid table. + + // But if you execute "create table already_exists", the error is only + // returned once you try to read the result. 
+ // + // This difference results in an inconsistent api: some errors are returned + // immediately by conn.query() and some errors are only returned when + // result.next() is first called. + // + // Here we attempt to fix this by eagerly reading the next message. If it's + // an error, we return it. If it isn't an error, we put it back for the next + // successful read. + pub fn peekForError(self: *Self) !?[]const u8 { + const message = self.buffered(self.pos, true) orelse try self.read(true); + return if (message.type == 'E') message.data else null; + } + + pub fn next(self: *Self) !Message { + return self.buffered(self.pos, false) orelse self.read(false); + } + + fn read(self: *Self, error_peek: bool) !Message { + var stream = self.stream; + // const spare = buf.len - pos; // how much space we have left in our buffer + + // Every PG message has 1 type byte followed by a 4 byte length prefix. + // Since the length prefix includes itself (but not the type byte) the + // minimum possible length is 4. We use 0 to denote "unknown". + var buf = self.buf; + var pos = self.pos; + var message_length: usize = 0; + + while (true) { + if (message_length == 0) { + // we don't yet know the length of this message + + const start = self.start; + + // how much of the next message we have + const current_length = pos - start; + + // we have enough data to figure the message length + if (current_length > 4) { + // + 1 for the type byte + message_length = std.mem.readInt(u32, buf[start + 1 .. 
start + 5][0..4], .big) + 1; + + if (message_length > buf.len) { + var new_buf: []u8 = undefined; + const allocator = self.allocator; + + if (buf.ptr == self.static.ptr) { + //currently using our static buffer, we need to allocate a larger one + new_buf = try allocator.alloc(u8, message_length); + @memcpy(new_buf[0..current_length], buf[start..pos]); + lib.metrics.allocReader(message_length); + } else { + // currently using a dynamically allocated buffer, we'll + // grow or allocate a larger one (which is what realloc does) + new_buf = try allocator.realloc(buf, message_length); + if (start > 0) { + std.mem.copyForwards(u8, new_buf[0..current_length], new_buf[start..pos]); + } + lib.metrics.allocReader(message_length - current_length); + } + + self.start = 0; + pos = current_length; + buf = new_buf; + self.buf = new_buf; + } else if (message_length > buf.len - start) { + // our buffer is big enough, but not from where we're currently starting + std.mem.copyForwards(u8, buf[0..current_length], buf[start..pos]); + pos = current_length; + self.start = 0; + } + } else if (buf.len - start < 5) { + // we don't even have enough space to read the 5 byte header + std.mem.copyForwards(u8, buf[0..current_length], buf[start..pos]); + pos = current_length; + self.start = 0; + } + } + + const n = try stream.read(buf[pos..]); + if (n == 0) { + return error.Closed; + } + pos += n; + if (self.buffered(pos, error_peek)) |msg| { + return msg; + } + } + } + + // checks and consume if we already have a message buffered + fn buffered(self: *Self, pos: usize, error_peek: bool) ?Message { + const start = self.start; + const available = pos - start; + + // we always need at least 5 bytes, 1 for the type and 4 for the length + if (available < 5) { + return null; + } + const buf = self.buf; + + const len_end = start + 5; + const len = std.mem.readInt(u32, buf[start + 1 .. 
len_end][0..4], .big); + + // +1 because the first byte, the message type, isn't included in the length + if (available < len + 1) { + return null; + } + + // -4 because the len includes the 4 byte length header itself + const end = len_end + len - 4; + + const message_type = buf[start]; + + if (error_peek == false or message_type == 'E') { + // how much extra data we already have + const extra = pos - end; + if (extra == 0) { + // we have no more data in the buffer, reset everything to the start + // so that we have the full buffer for future messages + self.pos = 0; + self.start = 0; + } else { + self.pos = pos; + self.start = end; + } + } else { + self.pos = pos; + } + + return .{ + .type = message_type, + .data = buf[len_end..end], + }; + } + }; +} + +pub const Message = struct { + type: u8, + data: []const u8, +}; + +const t = lib.testing; +test "Reader: next" { + const R = ReaderT(*t.Stream); + var s = t.Stream.init(); + defer s.deinit(); + + { + s.reset(); + s.add(&[_]u8{ 8, 0, 0, 0, 4 }); + var reader = R.init(t.allocator, 10, s) catch unreachable; + defer reader.deinit(); + const msg = try reader.next(); + try t.expectEqual(8, msg.type); + try t.expectSlice(u8, &[_]u8{}, msg.data); + } + + { + s.reset(); + s.add(&[_]u8{ 1, 0, 0, 0, 5, 2 }); + var reader = R.init(t.allocator, 10, s) catch unreachable; + defer reader.deinit(); + const msg = try reader.next(); + try t.expectEqual(1, msg.type); + try t.expectSlice(u8, &[_]u8{2}, msg.data); + } + + { + s.reset(); + s.add(&[_]u8{ 1, 0, 0, 0, 9, 1, 2, 3, 4, 19 }); + var reader = R.init(t.allocator, 10, s) catch unreachable; + defer reader.deinit(); + const msg = try reader.next(); + try t.expectEqual(1, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2, 3, 4, 19 }, msg.data); + // optimization, resets pos to 0 since we read an exact message + try t.expectEqual(0, reader.pos); + } + + { + // partial 2nd message, but closed without all the data + s.reset(); + s.add(&[_]u8{ 1, 0, 0, 0, 9, 1, 2, 3, 4, 19, 2 }); + var 
reader = R.init(t.allocator, 10, s) catch unreachable; + defer reader.deinit(); + const msg = try reader.next(); + try t.expectEqual(1, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2, 3, 4, 19 }, msg.data); + try t.expectError(error.Closed, reader.next()); + } + + { + // 2 full messages, 2nd message has no data + s.reset(); + s.add(&[_]u8{ 99, 0, 0, 0, 6, 200, 201, 2, 0, 0, 0, 4 }); + var reader = R.init(t.allocator, 20, s) catch unreachable; + defer reader.deinit(); + + const msg1 = try reader.next(); + try t.expectEqual(99, msg1.type); + try t.expectSlice(u8, &[_]u8{ 200, 201 }, msg1.data); + + const msg2 = try reader.next(); + try t.expectEqual(2, msg2.type); + try t.expectSlice(u8, &[_]u8{}, msg2.data); + } + + { + // 2 full messages, 2nd message has data + s.reset(); + s.add(&[_]u8{ 99, 0, 0, 0, 6, 200, 201, 3, 0, 0, 0, 7, 1, 8, 2 }); + var reader = R.init(t.allocator, 20, s) catch unreachable; + defer reader.deinit(); + + const msg1 = try reader.next(); + try t.expectEqual(99, msg1.type); + try t.expectSlice(u8, &[_]u8{ 200, 201 }, msg1.data); + + const msg2 = try reader.next(); + try t.expectEqual(3, msg2.type); + try t.expectSlice(u8, &[_]u8{ 1, 8, 2 }, msg2.data); + } + + { + // 2 full messages, split across packets + s.reset(); + s.add(&[_]u8{ 91, 0, 0, 0, 6, 200, 22, 4, 0, 0, 0, 5 }); + var reader = R.init(t.allocator, 20, s) catch unreachable; + defer reader.deinit(); + + const msg1 = try reader.next(); + try t.expectEqual(91, msg1.type); + try t.expectSlice(u8, &[_]u8{ 200, 22 }, msg1.data); + + s.add(&[_]u8{73}); + const msg2 = try reader.next(); + try t.expectEqual(4, msg2.type); + try t.expectSlice(u8, &[_]u8{73}, msg2.data); + } + + { + // not enough room in buffer for header of 2nd message + s.reset(); + s.add(&[_]u8{ 17, 0, 0, 0, 4, 5 }); + var reader = R.init(t.allocator, 8, s) catch unreachable; + defer reader.deinit(); + + const msg1 = try reader.next(); + try t.expectEqual(17, msg1.type); + try t.expectSlice(u8, &[_]u8{}, msg1.data); + + 
s.add(&[_]u8{ 0, 0, 0, 6, 10, 12 }); + const msg2 = try reader.next(); + try t.expectEqual(5, msg2.type); + try t.expectSlice(u8, &[_]u8{ 10, 12 }, msg2.data); + } + + { + // not enough room in buffer for header of 2nd message across multiple calls + s.reset(); + s.add(&[_]u8{ 17, 0, 0, 0, 5, 1, 200 }); + var reader = R.init(t.allocator, 8, s) catch unreachable; + defer reader.deinit(); + + const msg1 = try reader.next(); + try t.expectEqual(17, msg1.type); + try t.expectSlice(u8, &[_]u8{1}, msg1.data); + + s.add(&[_]u8{ 0, 0 }); + s.add(&[_]u8{0}); + s.add(&[_]u8{ 7, 10, 12, 14 }); + const msg2 = try reader.next(); + try t.expectEqual(200, msg2.type); + try t.expectSlice(u8, &[_]u8{ 10, 12, 14 }, msg2.data); + } +} + +// simulates message fragmentations +test "Reader: fuzz" { + const R = ReaderT(*t.Stream); + + var r = t.getRandom(); + const random = r.random(); + + const messages = [_]u8{ 1, 0, 0, 0, 4, 2, 0, 0, 0, 5, 1, 3, 0, 0, 0, 6, 1, 2, 4, 0, 0, 0, 24, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 5, 0, 0, 0, 8, 1, 2, 3, 4, 6, 0, 0, 0, 9, 1, 2, 3, 4, 5, 7, 0, 0, 0, 10, 1, 2, 3, 4, 5, 6, 8, 0, 0, 0, 11, 1, 2, 3, 4, 5, 6, 7, 9, 0, 0, 0, 25, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21 }; + + for (0..400) |_| { + var s = t.Stream.init(); + defer s.deinit(); + var reader = R.init(t.allocator, 12, s) catch unreachable; + defer reader.deinit(); + + for (0..4) |_| { + var buf: []const u8 = messages[0..]; + while (buf.len > 0) { + const l = random.uintAtMost(usize, buf.len - 1) + 1; + s.add(buf[0..l]); + buf = buf[l..]; + } + + var arena = std.heap.ArenaAllocator.init(t.allocator); + defer arena.deinit(); + + const allocator: ?Allocator = if (random.uintAtMost(usize, 1) == 1) arena.allocator() else null; + + try reader.startFlow(allocator, null); + defer reader.endFlow() catch unreachable; + + { + const msg = try reader.next(); + try t.expectEqual(1, msg.type); + try t.expectSlice(u8, &[_]u8{}, msg.data); + } + 
+ { + const msg = try reader.next(); + try t.expectEqual(2, msg.type); + try t.expectSlice(u8, &[_]u8{1}, msg.data); + } + + { + const msg = try reader.next(); + try t.expectEqual(3, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2 }, msg.data); + } + + { + const msg = try reader.next(); + try t.expectEqual(4, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 }, msg.data); + } + + { + const msg = try reader.next(); + try t.expectEqual(5, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2, 3, 4 }, msg.data); + } + + { + const msg = try reader.next(); + try t.expectEqual(6, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2, 3, 4, 5 }, msg.data); + } + + { + const msg = try reader.next(); + try t.expectEqual(7, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2, 3, 4, 5, 6 }, msg.data); + } + + { + const msg = try reader.next(); + try t.expectEqual(8, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2, 3, 4, 5, 6, 7 }, msg.data); + } + + { + const msg = try reader.next(); + try t.expectEqual(9, msg.type); + try t.expectSlice(u8, &[_]u8{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21 }, msg.data); + } + try t.expectError(error.Closed, reader.next()); + } + } +} + +test "Reader: dynamic" { + const R = ReaderT(*t.Stream); + var s = t.Stream.init(); + defer s.deinit(); + + { + // message bigger than static buffer + s.add(&[_]u8{ 200, 0, 0, 0, 14, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var reader = R.init(t.allocator, 10, s) catch unreachable; + defer reader.deinit(); + const msg = try reader.next(); + try t.expectEqual(200, msg.type); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, msg.data); + } + + { + // 2nd message bigger than static buffer + s.add(&[_]u8{ 199, 0, 0, 0, 6, 9, 8, 200, 0, 0, 0, 14, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var reader = R.init(t.allocator, 10, s) catch unreachable; + defer reader.deinit(); + + const msg1 = try reader.next(); + try t.expectEqual(199, 
msg1.type); + try t.expectSlice(u8, &.{ 9, 8 }, msg1.data); + + const msg2 = try reader.next(); + try t.expectEqual(200, msg2.type); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, msg2.data); + } + + { + // middle message bigger than static + s.add(&[_]u8{ 199, 0, 0, 0, 6, 9, 8, 200, 0, 0, 0, 14, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 198, 0, 0, 0, 5, 1 }); + var reader = R.init(t.allocator, 10, s) catch unreachable; + defer reader.deinit(); + + const msg1 = try reader.next(); + try t.expectEqual(199, msg1.type); + try t.expectSlice(u8, &.{ 9, 8 }, msg1.data); + + const msg2 = try reader.next(); + try t.expectEqual(200, msg2.type); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, msg2.data); + + const msg3 = try reader.next(); + try t.expectEqual(198, msg3.type); + try t.expectSlice(u8, &.{1}, msg3.data); + } +} + +test "Reader: start/endFlow basic" { + const R = ReaderT(*t.Stream); + var s = t.Stream.init(); + defer s.deinit(); + + // 1st message is bigger than static + s.add(&[_]u8{ 1, 0, 0, 0, 8, 1, 2, 3, 4 }); + + // 2nd message is bigger than first + s.add(&[_]u8{ 2, 0, 0, 0, 10, 1, 2, 3, 4, 5, 6 }); + + // 3rd message is smaller than 2nd (should re-use previous buffer) + s.add(&[_]u8{ 3, 0, 0, 0, 9, 1, 2, 3, 4, 5 }); + + var reader = R.init(t.allocator, 5, s) catch unreachable; + defer reader.deinit(); + + try reader.startFlow(null, null); + const msg1 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4 }, msg1.data); + + const msg2 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6 }, msg2.data); + + const msg3 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5 }, msg3.data); + reader.endFlow() catch unreachable; +} + +test "Reader: start/endFlow overread into static" { + const R = ReaderT(*t.Stream); + var s = t.Stream.init(); + defer s.deinit(); + + // 1st message is bigger than static + s.add(&[_]u8{ 1, 0, 0, 0, 8, 1, 2, 3, 4 }); + + // 2nd message is bigger than first + s.add(&[_]u8{ 2, 0, 0, 0, 10, 1, 
2, 3, 4, 5, 6 }); + + // 3rd message is smaller than 2nd (should re-use previous buffer) + s.add(&[_]u8{ 3, 0, 0, 0, 9, 1, 2, 3, 4, 5 }); + + // 4th message is overread and fits in static + s.add(&[_]u8{ 3, 0, 0, 0, 5, 255 }); + + var reader = R.init(t.allocator, 7, s) catch unreachable; + defer reader.deinit(); + + try reader.startFlow(null, null); + const msg1 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4 }, msg1.data); + + const msg2 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6 }, msg2.data); + + const msg3 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5 }, msg3.data); + reader.endFlow() catch unreachable; + + const msg4 = try reader.next(); + try t.expectSlice(u8, &.{255}, msg4.data); +} + +test "Reader: start/endFlow large overread" { + const R = ReaderT(*t.Stream); + var s = t.Stream.init(); + defer s.deinit(); + + // 1st message is bigger than static + s.add(&[_]u8{ 1, 0, 0, 0, 8, 1, 2, 3, 4 }); + + // 2nd message is bigger than first + s.add(&[_]u8{ 2, 0, 0, 0, 18, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 }); + + // 3rd message is smaller than 2nd (should re-use previous buffer) + s.add(&[_]u8{ 3, 0, 0, 0, 9, 1, 2, 3, 4, 5 }); + + // 4th message is huge + s.add(&[_]u8{ 4, 0, 0, 19, 140 } ++ "z" ** 5000); + + // 5th message is overread and does not fit into static + s.add(&[_]u8{ 5, 0, 0, 0, 11, 255, 250, 245, 240, 235, 230, 225 }); + + var reader = R.init(t.allocator, 7, s) catch unreachable; + defer reader.deinit(); + + try reader.startFlow(null, null); + const msg1 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4 }, msg1.data); + + const msg2 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 }, msg2.data); + + const msg3 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5 }, msg3.data); + + const msg4 = try reader.next(); + try t.expectSlice(u8, "z" ** 5000, msg4.data); + reader.endFlow() catch unreachable; + + const msg5 = try 
reader.next(); + try t.expectSlice(u8, &.{ 255, 250, 245, 240, 235, 230, 225 }, msg5.data); +} + +test "Reader: start/endFlow large overread with flow-specific allocator" { + defer t.reset(); + const R = ReaderT(*t.Stream); + var s = t.Stream.init(); + defer s.deinit(); + + // 1st message is bigger than static + s.add(&[_]u8{ 1, 0, 0, 0, 8, 1, 2, 3, 4 }); + + // 2nd message is bigger than first + s.add(&[_]u8{ 2, 0, 0, 0, 10, 1, 2, 3, 4, 5, 6 }); + + // 3rd message is smaller than 2nd (should re-use previous buffer) + s.add(&[_]u8{ 3, 0, 0, 0, 9, 1, 2, 3, 4, 5 }); + + // 4th message is overread and does not fit into static + s.add(&[_]u8{ 3, 0, 0, 0, 11, 255, 250, 245, 240, 235, 230, 225 }); + + var reader = R.init(t.allocator, 7, s) catch unreachable; + defer reader.deinit(); + + try reader.startFlow(t.arena.allocator(), null); + const msg1 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4 }, msg1.data); + + const msg2 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5, 6 }, msg2.data); + + const msg3 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4, 5 }, msg3.data); + reader.endFlow() catch unreachable; + + const msg4 = try reader.next(); + try t.expectSlice(u8, &.{ 255, 250, 245, 240, 235, 230, 225 }, msg4.data); +} + +test "Reader: startFlow with dynamic allocation into deinit " { + // This can happen on an error case, where we start a flow, but an error + // happens during processing, causing conn.deinit() to be called (say, when + // it's released back into the pool in an error state). 
+ defer t.reset(); + const R = ReaderT(*t.Stream); + var s = t.Stream.init(); + defer s.deinit(); + + // 1st message is bigger than static + s.add(&[_]u8{ 1, 0, 0, 0, 8, 1, 2, 3, 4 }); + + var reader = R.init(t.allocator, 7, s) catch unreachable; + defer reader.deinit(); + + try reader.startFlow(t.arena.allocator(), null); + const msg1 = try reader.next(); + try t.expectSlice(u8, &.{ 1, 2, 3, 4 }, msg1.data); +} diff --git a/zig/pg/src/result.zig b/zig/pg/src/result.zig new file mode 100644 index 0000000..7564b62 --- /dev/null +++ b/zig/pg/src/result.zig @@ -0,0 +1,2180 @@ +const std = @import("std"); +const lib = @import("lib.zig"); + +const types = lib.types; +const proto = lib.proto; +const Conn = lib.Conn; +const Allocator = std.mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + +pub const Result = struct { + number_of_columns: usize, + + // will be empty unless the query was executed with the column_names = true option + column_names: [][]const u8, + + _conn: *Conn, + _arena: *ArenaAllocator, + + // a sliced version of _state.oids (so we don't have to keep reslicing it to + // number_of_columns on each row) + _oids: []i32, + + // a sliced version of _state.values (so we don't have to keep reslicing it to + // number_of_columns on each row) + _values: []State.Value, + + // When true, result.deinit() will call conn.release() + // Used when the result came directly from the pool.query() helper. + _release_conn: bool, + + pub fn deinit(self: *const Result) void { + // value.data references the buffer of the reader, this buffer is potentially + // reused and potentially discarded. There are at least a few very good + // reasons why the least we can do is blank it out. 
+ for (self._values) |*value| { + value.data = &[_]u8{}; + } + + self._conn._reader.endFlow() catch { + // this can only fail in extreme conditions (OOM) and it will only impact + // the next query (and if the app is using the pool, the pool will try to + // recover from this anyways) + self._conn._state = .fail; + }; + + if (self._release_conn) { + self._conn.release(); + } + + const arena = self._arena; + const allocator = arena.child_allocator; + arena.deinit(); + allocator.destroy(arena); + } + + // Caller should typically call next() until null is returned. + // But in some cases, that might not be desirable. So they can + // "drain" to empty the rest of the result. + // I don't want to do this implicitly in deinit because it can fail + // and returning an error union in deinit is a pain for the caller. + pub fn drain(self: *Result) !void { + var conn = self._conn; + if (conn._state == .idle) { + return; + } + + while (true) { + const msg = try conn.read(); + switch (msg.type) { + 'C' => {}, // CommandComplete + 'D' => {}, // DataRow + 'Z' => return, + else => return error.UnexpectedDBMessage, + } + } + } + + pub fn next(self: *Result) !?Row { + return self._next(.safe); + } + pub fn nextUnsafe(self: *Result) !?RowUnsafe { + return self._next(.unsafe); + } + + fn _next(self: *Result, comptime fail_mode: lib.FailMode) !(if (fail_mode == .safe) ?Row else ?RowUnsafe) { + if (self._conn._state != .query) { + // Possibly weird state. Most likely cause is calling next() multiple times + // despite null being returned. + return null; + } + + const msg = try self._conn.read(); + switch (msg.type) { + 'D' => { + const data = msg.data; + // Since our Row API gets data by column #, we need to translate the column + // # to a slice within msg.data. We could do this on the fly within Row, + // but creating this mapping up front simplifies things and, in normal + // cases, performs best. "Normal case" here assumes that the client app + // is going to fetch most/all columns. 
+ + // first column starts at position 2 + var offset: usize = 2; + const values = self._values; + for (values) |*value| { + const data_start = offset + 4; + const length = std.mem.readInt(i32, data[offset..data_start][0..4], .big); + if (length == -1) { + value.is_null = true; + value.data = &[_]u8{}; + offset = data_start; + } else { + const data_end = data_start + @as(usize, @intCast(length)); + value.is_null = false; + value.data = data[data_start..data_end]; + offset = data_end; + } + } + + return .{ + .values = values, + .oids = self._oids, + ._result = self, + }; + }, + 'C' => { + try self._conn.readyForQuery(); + return null; + }, + else => return error.UnexpectedDBMessage, + } + } + + pub fn columnIndex(self: *const Result, column_name: []const u8) ?usize { + for (self.column_names, 0..) |n, i| { + if (std.mem.eql(u8, n, column_name)) { + return i; + } + } + return null; + } + + const MapperOpts = struct { + dupe: bool = false, + allocator: ?Allocator = null, + }; + + pub fn mapper(self: *Result, comptime T: type, opts: MapperOpts) Mapper(T) { + var column_indexes: [std.meta.fields(T).len]?usize = undefined; + + inline for (std.meta.fields(T), 0..) |field, i| { + column_indexes[i] = self.columnIndex(field.name); + } + + // if we're given an allocator, use that. + // if we're not given an allocator, but asked to dupe use our arena and thus + // tie the lifetime of the returned T to the lifetime of the DB result object. + var allocator: ?Allocator = null; + if (opts.allocator) |a| { + allocator = a; + } else if (opts.dupe) { + allocator = self._arena.allocator(); + } + + return .{ + .result = self, + .allocator = allocator, + .column_indexes = column_indexes, + }; + } + + // For every query, we need to store the type of each column (so we know + // how to parse the data). Optionally, we might need the name of each column. + // The connection has a default Result.State for a max # of columns, and we'll use + // that whenever we can. 
Otherwise, we'll create this dynamically. + pub const State = struct { + // The name for each returned column, we only populate this if we're told + // to (since it requires us to dupe the data) + names: [][]const u8, + + // This is different than the above. The above are set once per query + // from the RowDescription response of our Describe message. This is set for + // each DataRow message we receive. It maps a column position with the encoded + // value. + values: []Value, + + // The OID for each returned column + oids: []i32, + + pub const Value = struct { + is_null: bool, + data: []const u8, + }; + + pub fn init(allocator: Allocator, size: usize) !State { + const names = try allocator.alloc([]const u8, size); + errdefer allocator.free(names); + + const values = try allocator.alloc(Value, size); + errdefer allocator.free(values); + + const oids = try allocator.alloc(i32, size); + errdefer allocator.free(oids); + + return .{ + .names = names, + .values = values, + .oids = oids, + }; + } + + // Populates the State from the RowDescription payload + // We already read the number_of_columns from data, so we pass it in here + // We also already know that number_of_columns fits within our arrays + pub fn from(self: *State, number_of_columns: u16, data: []const u8, allocator: ?Allocator) !void { + // skip the column count, which we already know as number_of_columns + var pos: usize = 2; + + for (0..number_of_columns) |i| { + const end_pos = std.mem.indexOfScalarPos(u8, data, pos, 0) orelse return error.InvalidDataRow; + if (data.len < (end_pos + 19)) { + return error.InvalidDataRow; + } + if (allocator) |a| { + self.names[i] = try a.dupe(u8, data[pos..end_pos]); + } + + // skip the name null terminator (1) + // skip the table object_id this table belongs to (4) + // skip the attribute number of this table column (2) + pos = end_pos + 7; + + { + const end = pos + 4; + self.oids[i] = std.mem.readInt(i32, data[pos..end][0..4], .big); + pos = end; + } + + // skip data 
type size (2), type modifier (4) format code (2) + pos += 8; + } + } + + pub fn deinit(self: State, allocator: Allocator) void { + allocator.free(self.names); + allocator.free(self.values); + allocator.free(self.oids); + } + }; +}; + +pub const Row = RowT(.safe); +pub const RowUnsafe = RowT(.unsafe); + +pub fn RowT(comptime fail_mode: lib.FailMode) type { + return struct { + _result: *Result, + oids: []i32, + values: []Result.State.Value, + + const Self = @This(); + + pub fn get(self: *const Self, comptime T: type, col: usize) if (fail_mode == .safe) lib.TypeError!T else T { + const value = self.values[col]; + const TT = switch (@typeInfo(T)) { + .optional => |opt| { + if (value.is_null) { + return null; + } + const val = self.get(opt.child, col); + if (comptime fail_mode == .safe) { + return try val; + } + return val; + }, + .@"struct" => blk: { + if (@hasDecl(T, "fromPgzRow") == true) { + return T.fromPgzRow(value.data, self.oids[col]) catch |err| { + if (comptime fail_mode == .safe) { + return err; + } + std.debug.panic("PostgreSQL value of type {s} could not be read into a " ++ @typeName(T) ++ ".", .{types.oidToString(self.oids[col])}); + }; + } + break :blk T; + }, + else => blk: { + lib.verifyNotNull(fail_mode, T, value.is_null) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + break :blk T; + }, + }; + + return getScalar(fail_mode, TT, value.data, self.oids[col]); + } + + pub fn getCol(self: *const Self, comptime T: type, name: []const u8) if (fail_mode == .safe) lib.TypeError!T else T { + const col = self._result.columnIndex(name); + try lib.verifyColumnName(fail_mode, name, col != null); + return self.get(T, col.?); + } + + pub fn iterator(self: *const Self, comptime T: type, col: usize) if (fail_mode == .safe) lib.TypeError!Iterator(T) else IteratorUnsafe(T) { + const value = self.values[col]; + if (value.is_null) { + return IteratorT(fail_mode, T).asNull(); + } + return IteratorT(fail_mode, T).fromPgzRow(value.data, 
self.oids[col]) catch |err| { + if (comptime fail_mode == .safe) { + return err; + } + @panic("Could not get iterator of type " ++ @typeName(T) ++ " for row."); + }; + } + + pub fn iteratorCol(self: *const Self, comptime T: type, name: []const u8) if (fail_mode == .safe) lib.TypeError!Iterator(T) else IteratorUnsafe(T) { + const col = self._result.columnIndex(name); + try lib.verifyColumnName(fail_mode, name, col != null); + return self.iterator(T, col.?); + } + + pub fn record(self: *const Self, col: usize) RecordT(fail_mode) { + const data = self.values[col].data; + const number_of_columns = std.mem.readInt(i32, data[0..4], .big); + return .{ + .data = data[4..], + .number_of_columns = @intCast(number_of_columns), + }; + } + + pub fn recordCol(self: *const Self, name: []const u8) if (fail_mode == .safe) lib.TypeError!Record else RecordUnsafe { + const col = self._result.columnIndex(name); + try lib.verifyColumnName(fail_mode, name, col != null); + return self.record(col.?); + } + + const ToOpts = struct { + dupe: bool = false, + map: Mapping = .ordinal, + allocator: ?Allocator = null, + + const Mapping = enum { + name, + ordinal, + }; + }; + + pub fn to(self: *const Self, T: type, opts: ToOpts) !T { + // if we're given an allocator, use that. + // if we're not given an allocator, but asked to dupe use our arena and thus + // tie the lifetime of the returned T to the lifetime of the DB result object. + var allocator: ?Allocator = null; + if (opts.allocator) |a| { + allocator = a; + } else if (opts.dupe) { + allocator = self._result._arena.allocator(); + } + + return switch (opts.map) { + .ordinal => self.toUsingOrdinal(T, allocator), + .name => return self.toUsingName(T, allocator), + }; + } + + fn toUsingOrdinal(self: *const Self, T: type, allocator: ?Allocator) !T { + var value: T = undefined; + inline for (std.meta.fields(T), 0..)
|field, column_index| { + @field(value, field.name) = try self.mapColumn(&field, column_index, allocator); + } + return value; + } + + fn toUsingName(self: *const Self, T: type, allocator: ?Allocator) !T { + var value: T = undefined; + const result = self._result; + inline for (std.meta.fields(T)) |field| { + const name = field.name; + @field(value, name) = try self.mapColumn(&field, result.columnIndex(name), allocator); + } + return value; + } + + fn mapColumn(self: *const Self, field: *const std.builtin.Type.StructField, optional_column_index: ?usize, allocator: ?Allocator) !field.type { + const T = field.type; + const column_index = optional_column_index orelse { + if (field.default_value_ptr) |dflt| { + return @as(*align(1) const field.type, @ptrCast(dflt)).*; + } + return error.FieldColumnMismatch; + }; + + if (comptime isSlice(T)) |S| { + const slice = blk: { + if (@typeInfo(T) == .optional) { + break :blk self.get(?Iterator(S), column_index) orelse return null; + } else { + break :blk self.get(Iterator(S), column_index); + } + }; + return try slice.alloc(allocator orelse return error.AllocatorRequiredForSliceMapping); + } + + const value = self.get(field.type, column_index); + const a = allocator orelse return value; + return mapValue(T, if (comptime fail_mode == .safe) try value else value, a); + } + + /// Write a single column's value as JSON to a buffer. + /// Handles all Postgres types: int, float, numeric, bool, text, jsonb, arrays, timestamps. + /// Returns the number of bytes written. 
+ pub fn writeJsonValue(self: *const Self, col: usize, buf: []u8) usize { + const value = self.values[col]; + if (value.is_null) { + if (buf.len < 4) return 0; + @memcpy(buf[0..4], "null"); + return 4; + } + + const data = value.data; + const oid = self.oids[col]; + var pos: usize = 0; + + switch (oid) { + // Integers + 21 => { // int2 + const v = std.mem.readInt(i16, data[0..2], .big); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch return 0; + pos += s.len; + }, + 23 => { // int4 + const v = std.mem.readInt(i32, data[0..4], .big); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch return 0; + pos += s.len; + }, + 20 => { // int8 + const v = std.mem.readInt(i64, data[0..8], .big); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch return 0; + pos += s.len; + }, + // Floats + 700 => { // float4 + const n = std.mem.readInt(i32, data[0..4], .big); + const v: f32 = @bitCast(n); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch return 0; + pos += s.len; + }, + 701 => { // float8 + const n = std.mem.readInt(i64, data[0..8], .big); + const v: f64 = @bitCast(n); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch return 0; + pos += s.len; + }, + // Numeric/Decimal — decode via Numeric type then format + 1700 => { + const numeric = types.Numeric.decode(fail_mode, data, oid) catch { + @memcpy(buf[pos..][0..4], "null"); + return pos + 4; + }; + const v = numeric.toFloat(); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch return 0; + pos += s.len; + }, + // Bool + 16 => { + const v = data[0] != 0; + const s = if (v) "true" else "false"; + @memcpy(buf[pos..][0..s.len], s); + pos += s.len; + }, + // JSON — already valid JSON, pass through + 114 => { + @memcpy(buf[pos..][0..data.len], data); + pos += data.len; + }, + // JSONB — strip version byte, pass through + 3802 => { + const json_data = data[1..]; + @memcpy(buf[pos..][0..json_data.len], json_data); + pos += json_data.len; + }, + // Integer arrays + 1005, 1007, 1016 
=> { // int2[], int4[], int8[] + pos += writeIntArrayJson(data, oid, buf[pos..]); + }, + // Text arrays + 1009, 1015 => { // text[], varchar[] + pos += writeTextArrayJson(data, buf[pos..]); + }, + // Timestamp — format as ISO 8601 + 1114, 1184 => { // timestamp, timestamptz + const usec = std.mem.readInt(i64, data[0..8], .big); + // Postgres epoch is 2000-01-01, Unix is 1970-01-01 + const pg_epoch_offset: i64 = 946684800; + const unix_sec = @divTrunc(usec, 1_000_000) + pg_epoch_offset; + buf[pos] = '"'; + pos += 1; + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{unix_sec}) catch return 0; + pos += s.len; + buf[pos] = '"'; + pos += 1; + }, + // Everything else — check for pgvector, then fall back to quoted string + else => { + // pgvector support — dynamic OID, check at runtime + if (types.Vector.oid_decimal != 0 and oid == types.Vector.oid_decimal) { + const vec = types.Vector.decode(data); + pos += vec.writeJson(buf[pos..]); + return pos; + } + // Default: quote as string (SIMD-accelerated escape scan) + buf[pos] = '"'; + pos += 1; + pos += simdJsonEscape(data, buf[pos..]); + buf[pos] = '"'; + pos += 1; + }, + } + return pos; + } + + /// Write an entire row as a JSON object: {"col1":val1,"col2":val2,...} + /// Requires column_names to be populated (queryOpts with .column_names = true). + pub fn writeJsonRow(self: *const Self, col_names: []const []const u8, buf: []u8) usize { + var pos: usize = 0; + buf[pos] = '{'; + pos += 1; + + const ncols = @min(col_names.len, self.values.len); + for (0..ncols) |i| { + if (i > 0) { + buf[pos] = ','; + pos += 1; + } + // Column name + buf[pos] = '"'; + pos += 1; + @memcpy(buf[pos..][0..col_names[i].len], col_names[i]); + pos += col_names[i].len; + buf[pos] = '"'; + pos += 1; + buf[pos] = ':'; + pos += 1; + // Value + pos += self.writeJsonValue(i, buf[pos..]); + } + + buf[pos] = '}'; + pos += 1; + return pos; + } + }; +} + +/// SIMD-accelerated JSON string escaping. 
+/// Scans 16 bytes at a time for chars needing escape (", \, control chars). +/// Falls back to scalar for remainder and when escapes are found. +fn simdJsonEscape(data: []const u8, buf: []u8) usize { + const simd_width = 16; + var pos: usize = 0; + var i: usize = 0; + + // SIMD fast path: check 16 bytes at a time for escape-needing chars + while (i + simd_width <= data.len and pos + simd_width + 16 <= buf.len) { + const chunk: @Vector(simd_width, u8) = data[i..][0..simd_width].*; + // Check for chars needing escape: control (<0x20), quote, backslash + const ctrl_mask = chunk < @as(@Vector(simd_width, u8), @splat(0x20)); + const quote_mask = chunk == @as(@Vector(simd_width, u8), @splat('"')); + const bslash_mask = chunk == @as(@Vector(simd_width, u8), @splat('\\')); + const any_ctrl = @reduce(.Or, ctrl_mask); + const any_quote = @reduce(.Or, quote_mask); + const any_bslash = @reduce(.Or, bslash_mask); + + if (!any_ctrl and !any_quote and !any_bslash) { + // Fast path: no escaping needed, bulk copy + @memcpy(buf[pos..][0..simd_width], data[i..][0..simd_width]); + pos += simd_width; + i += simd_width; + } else { + // Slow path: at least one char needs escaping, do scalar + for (data[i..][0..simd_width]) |ch| { + if (ch == '"' or ch == '\\') { + buf[pos] = '\\'; + pos += 1; + } + if (ch >= 0x20) { + buf[pos] = ch; + pos += 1; + } + } + i += simd_width; + } + } + + // Scalar remainder + while (i < data.len and pos + 2 < buf.len) { + const ch = data[i]; + if (ch == '"' or ch == '\\') { + buf[pos] = '\\'; + pos += 1; + } + if (ch >= 0x20) { + buf[pos] = ch; + pos += 1; + } + i += 1; + } + + return pos; +} + +/// Parse Postgres binary int array and write as JSON: [1,2,3] +fn writeIntArrayJson(data: []const u8, oid: i32, buf: []u8) usize { + // Postgres binary array format: + // 4 bytes: ndim, 4 bytes: flags, 4 bytes: element OID + // per dimension: 4 bytes length, 4 bytes lower bound + // then: per element: 4 bytes length (or -1 for null), then data + if (data.len < 12) 
{ + @memcpy(buf[0..2], "[]"); + return 2; + } + + const ndim = std.mem.readInt(i32, data[0..4], .big); + if (ndim == 0) { + @memcpy(buf[0..2], "[]"); + return 2; + } + + // Element size based on array OID + const elem_size: usize = switch (oid) { + 1005 => 2, // int2[] + 1007 => 4, // int4[] + 1016 => 8, // int8[] + else => 4, + }; + + const nelems = std.mem.readInt(i32, data[12..16], .big); + var pos: usize = 0; + buf[pos] = '['; + pos += 1; + + var offset: usize = 20; // skip header (12 + 4 length + 4 lower bound) + var i: i32 = 0; + while (i < nelems and offset + 4 <= data.len) : (i += 1) { + if (i > 0) { + buf[pos] = ','; + pos += 1; + } + const elem_len = std.mem.readInt(i32, data[offset..][0..4], .big); + offset += 4; + if (elem_len == -1) { + @memcpy(buf[pos..][0..4], "null"); + pos += 4; + } else if (elem_size == 2 and offset + 2 <= data.len) { + const v = std.mem.readInt(i16, data[offset..][0..2], .big); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch break; + pos += s.len; + offset += 2; + } else if (elem_size == 4 and offset + 4 <= data.len) { + const v = std.mem.readInt(i32, data[offset..][0..4], .big); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch break; + pos += s.len; + offset += 4; + } else if (elem_size == 8 and offset + 8 <= data.len) { + const v = std.mem.readInt(i64, data[offset..][0..8], .big); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch break; + pos += s.len; + offset += 8; + } else { + offset += @intCast(@as(u32, @bitCast(elem_len))); + } + } + + buf[pos] = ']'; + pos += 1; + return pos; +} + +/// Parse Postgres binary text array and write as JSON: ["a","b","c"] +fn writeTextArrayJson(data: []const u8, buf: []u8) usize { + if (data.len < 12) { + @memcpy(buf[0..2], "[]"); + return 2; + } + + const ndim = std.mem.readInt(i32, data[0..4], .big); + if (ndim == 0) { + @memcpy(buf[0..2], "[]"); + return 2; + } + + const nelems = std.mem.readInt(i32, data[12..16], .big); + var pos: usize = 0; + buf[pos] = 
'['; + pos += 1; + + var offset: usize = 20; + var i: i32 = 0; + while (i < nelems and offset + 4 <= data.len) : (i += 1) { + if (i > 0) { + buf[pos] = ','; + pos += 1; + } + const elem_len = std.mem.readInt(i32, data[offset..][0..4], .big); + offset += 4; + if (elem_len == -1) { + @memcpy(buf[pos..][0..4], "null"); + pos += 4; + } else { + const slen: usize = @intCast(@as(u32, @bitCast(elem_len))); + if (offset + slen > data.len) break; + buf[pos] = '"'; + pos += 1; + for (data[offset..][0..slen]) |ch| { + if (pos + 2 >= buf.len) break; + if (ch == '"' or ch == '\\') { + buf[pos] = '\\'; + pos += 1; + } + buf[pos] = ch; + pos += 1; + } + buf[pos] = '"'; + pos += 1; + offset += slen; + } + } + + buf[pos] = ']'; + pos += 1; + return pos; +} + +fn isSlice(comptime T: type) ?type { + switch (@typeInfo(T)) { + .pointer => |ptr| { + if (ptr.size != .slice) { + compileHaltGetError(T); + } + return if (ptr.child == u8) null else ptr.child; + }, + .optional => |opt| return isSlice(opt.child), + else => return null, + } +} + +fn mapValue(comptime T: type, value: T, allocator: Allocator) !T { + switch (@typeInfo(T)) { + .optional => |opt| { + if (value) |v| { + return try mapValue(opt.child, v, allocator); + } + return null; + }, + else => {}, + } + + if (T == []u8 or T == []const u8) { + return try allocator.dupe(u8, value); + } + + if (std.meta.hasFn(T, "pgzMoveOwner")) { + return value.pgzMoveOwner(allocator); + } + + return value; +} + +pub fn Mapper(comptime T: type) type { + return struct { + result: *Result, + allocator: ?Allocator, + column_indexes: [std.meta.fields(T).len]?usize, + + const Self = @This(); + + pub fn next(self: *const Self) !?T { + const row = (try self.result.next()) orelse return null; + + var value: T = undefined; + + const allocator = self.allocator; + inline for (std.meta.fields(T), self.column_indexes) |field, optional_column_index| { + @field(value, field.name) = try row.mapColumn(&field, optional_column_index, allocator); + } + return value; 
+ } + }; +} + +pub const QueryRow = QueryRowT(.safe); +pub const QueryRowUnsafe = QueryRowT(.unsafe); + +pub fn QueryRowT(comptime fail_mode: lib.FailMode) type { + return struct { + row: RowT(fail_mode), + result: *Result, + + const Self = @This(); + + pub fn get(self: *const Self, comptime T: type, col: usize) if (fail_mode == .safe) lib.TypeError!T else T { + return self.row.get(T, col); + } + + pub fn getCol(self: *const Self, comptime T: type, name: []const u8) if (fail_mode == .safe) lib.TypeError!T else T { + return self.row.getCol(T, name); + } + + pub fn iterator(self: *const Self, comptime T: type, col: usize) if (fail_mode == .safe) lib.TypeError!Iterator(T) else IteratorUnsafe(T) { + return self.row.iterator(T, col); + } + + pub fn iteratorCol(self: *const Self, comptime T: type, name: []const u8) if (fail_mode == .safe) lib.TypeError!Iterator(T) else IteratorUnsafe(T) { + return self.row.iteratorCol(T, name); + } + + pub fn record(self: *const Self, col: usize) RecordT(fail_mode) { + return self.row.record(col); + } + + pub fn recordCol(self: *const Self, name: []const u8) if (fail_mode == .safe) lib.TypeError!Record else RecordUnsafe { + return self.row.recordCol(name); + } + pub fn to(self: *const Self, T: type, opts: Row.ToOpts) !T { + return self.row.to(T, opts); + } + + pub fn deinit(self: *Self) !void { + // this is unfortunate + try self.result.drain(); + self.result.deinit(); + } + }; +} + +pub fn Iterator(comptime T: type) type { + return IteratorT(.safe, T); +} +pub fn IteratorUnsafe(comptime T: type) type { + return IteratorT(.unsafe, T); +} +pub fn IteratorT(comptime fail_mode: lib.FailMode, comptime T: type) type { + return struct { + is_null: bool, + _len: usize, + _pos: usize, + _data: []const u8, + _decoder: *const fn (data: []const u8) ItemType(), + + fn ItemType() type { + return switch (@typeInfo(T)) { + .optional => |opt| opt.child, + else => T, + }; + } + + const Self = @This(); + + pub fn len(self: Self) usize { + return 
self._len; + } + + fn asNull() Self { + return .{ + .is_null = true, + ._len = 0, + ._pos = 0, + ._data = &.{}, + ._decoder = struct { + fn noop(_: []const u8) ItemType() { + unreachable; + } + }.noop, + }; + } + + // used internally by row.get(Iterator(T)) + fn fromPgzRow(data: []const u8, oid: i32) !Self { + const TT = switch (@typeInfo(T)) { + .optional => |opt| opt.child, + else => T, + }; + + const decoder = switch (TT) { + u8 => blk: { + lib.verifyDecodeType(fail_mode, []u8, &.{types.CharArray.oid.decimal}, oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + break :blk &types.Char.decodeKnown; + }, + i16 => blk: { + lib.verifyDecodeType(fail_mode, []i16, &.{types.Int16Array.oid.decimal}, oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + break :blk &types.Int16.decodeKnown; + }, + i32 => blk: { + lib.verifyDecodeType(fail_mode, []i32, &.{types.Int32Array.oid.decimal}, oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + break :blk &types.Int32.decodeKnown; + }, + i64 => switch (oid) { + types.TimestampArray.oid.decimal => &types.Timestamp.decodeKnown, + types.TimestampTzArray.oid.decimal => &types.Timestamp.decodeKnown, + types.Int64Array.oid.decimal => &types.Int64.decodeKnown, + else => std.debug.panic("{d} oid cannot target i64 iterator", .{oid}), + }, + f32 => blk: { + lib.verifyDecodeType(fail_mode, []f32, &.{types.Float32Array.oid.decimal}, oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + break :blk &types.Float32.decodeKnown; + }, + f64 => switch (oid) { + types.Float64Array.oid.decimal => &types.Float64.decodeKnown, + types.NumericArray.oid.decimal => &types.Numeric.decodeKnownToFloat, + else => std.debug.panic("{d} oid cannot target f64 iterator", .{oid}), + }, + bool => blk: { + lib.verifyDecodeType(fail_mode, []bool, &.{types.BoolArray.oid.decimal}, oid) catch |err| { + if (comptime fail_mode == 
.unsafe) unreachable; + return err; + }; + break :blk &types.Bool.decodeKnown; + }, + []const u8 => switch (oid) { + types.JSONBArray.oid.decimal => &types.JSONB.decodeKnown, + else => &types.Bytea.decodeKnown, + }, + []u8 => switch (oid) { + types.JSONBArray.oid.decimal => &types.JSONB.decodeKnownMutable, + else => &types.Bytea.decodeKnownMutable, + }, + types.Numeric => blk: { + lib.verifyDecodeType(fail_mode, []f64, &.{types.NumericArray.oid.decimal}, oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + break :blk &types.Numeric.decodeKnown; + }, + types.Cidr => blk: { + lib.verifyDecodeType(fail_mode, []types.Cidr, &.{ types.CidrArray.oid.decimal, types.CidrArray.inet_oid.decimal }, oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + break :blk &types.Cidr.decodeKnown; + }, + else => switch (@typeInfo(TT)) { + .@"enum" => blk: { + lib.verifyDecodeType(fail_mode, []const u8, &.{types.StringArray.oid.decimal}, oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + break :blk &EnumDecoder(TT).decodeKnown; + }, + else => compileHaltGetError(T), + }, + }; + + if (data.len == 12) { + // we have an empty array + return .{ + .is_null = false, + ._len = 0, + ._pos = 0, + ._data = &[_]u8{}, + ._decoder = decoder, + }; + } + + // minimum size for 1 empty array + lib.assert(data.len >= 20); + const dimensions = std.mem.readInt(i32, data[0..4], .big); + lib.assert(dimensions == 1); + + const has_nulls = std.mem.readInt(i32, data[4..8][0..4], .big); + lib.assert(has_nulls == 0 or @typeInfo(T) == .optional); + + // const oid = std.mem.readInt(i32, data[8..12][0..4], .big); + const l = std.mem.readInt(i32, data[12..16][0..4], .big); + // const lower_bound = std.mem.readInt(i32, data[16..20][0..4], .big); + + return .{ + .is_null = false, + ._len = @intCast(l), + ._pos = 0, + ._data = data[20..], + ._decoder = decoder, + }; + } + + pub fn pgzMoveOwner(self: Self, 
allocator: Allocator) !Self { + return .{ + .is_null = false, + ._len = self._len, + ._pos = self._pos, + ._data = try allocator.dupe(u8, self._data), + ._decoder = self._decoder, + }; + } + + // Should only be called if the Iterator was created with row.to(...) + // or a result mapper AND an explicit allocator was given + pub fn deinit(self: *const Self, allocator: Allocator) void { + allocator.free(self._data); + } + + pub fn next(self: *Self) ?T { + const pos = self._pos; + const data = self._data; + if (pos == data.len) { + return null; + } + + // TODO: for fixed length types, we don't need to decode the length + const len_end = pos + 4; + const value_len = std.mem.readInt(i32, data[pos..len_end][0..4], .big); + + const data_end = len_end + @as(usize, @intCast(value_len)); + lib.assert(data.len >= data_end); + + self._pos = data_end; + return self._decoder(data[len_end..data_end]); + } + + pub fn alloc(self: *const Self, allocator: Allocator) ![]T { + const into = try allocator.alloc(T, self._len); + try self.fillAlloc(true, into, allocator); + return into; + } + + pub fn fill(self: *const Self, into: []T) void { + self.fillAlloc(false, into, undefined) catch unreachable; + } + + fn fillAlloc(self: *const Self, comptime should_dupe: bool, into: []T, allocator: Allocator) !void { + const data = self._data; + const decoder = self._decoder; + + var pos: usize = 0; + const limit = @min(into.len, self._len); + for (0..limit) |i| { + // TODO: for fixed length types, we don't need to decode the length + const len_end = pos + 4; + const data_len = std.mem.readInt(i32, data[pos..len_end][0..4], .big); + + if ((comptime @typeInfo(T) == .optional) and data_len == -1) { + pos = len_end; + into[i] = null; + } else { + pos = len_end + @as(usize, @intCast(data_len)); + if (comptime should_dupe and (T == []u8 or T == []const u8)) { + into[i] = try allocator.dupe(u8, decoder(data[len_end..pos])); + } else { + into[i] = decoder(data[len_end..pos]); + } + } + } + } + }; +} + +fn 
EnumDecoder(comptime T: type) type { + return struct { + pub fn decodeKnown(data: []const u8) T { + return std.meta.stringToEnum(T, data).?; + } + }; +} + +fn compileHaltGetError(comptime T: type) noreturn { + @compileError("cannot get value of type " ++ @typeName(T)); +} + +pub const Record = RecordT(.safe); +pub const RecordUnsafe = RecordT(.unsafe); + +pub fn RecordT(comptime fail_mode: lib.FailMode) type { + return struct { + data: []const u8, + number_of_columns: usize, + + const Self = @This(); + + pub fn next(self: *Self, comptime T: type) if (fail_mode == .safe) lib.TypeError!T else T { + var data = self.data; + + // at least 4 bytes for the type and 4 bytes for the length + lib.assert(data.len >= 8); + + const oid = std.mem.readInt(i32, data[0..4], .big); + + data = data[4..]; + const len = std.mem.readInt(i32, data[0..4], .big); + + const TT = switch (@typeInfo(T)) { + .optional => |opt| blk: { + if (len == -1) return null; + break :blk opt.child; + }, + else => T, + }; + + // end of the data for this "column" + const end = @as(usize, @intCast(len)) + 4; + + // the rest of the data + self.data = data[end..]; + + // start at 4 to skip the length which we already read + return getScalar(fail_mode, TT, data[4..end], oid); + } + }; +} + +fn getScalar(comptime fail_mode: lib.FailMode, comptime T: type, data: []const u8, oid: i32) if (fail_mode == .safe) lib.TypeError!T else T { + switch (T) { + u8 => return types.Char.decode(fail_mode, data, oid), + i16 => return types.Int16.decode(fail_mode, data, oid), + i32 => return types.Int32.decode(fail_mode, data, oid), + i64 => return types.Int64.decode(fail_mode, data, oid), + f32 => return types.Float32.decode(fail_mode, data, oid), + f64 => return types.Float64.decode(fail_mode, data, oid), + bool => return types.Bool.decode(fail_mode, data, oid), + []const u8 => return types.Bytea.decode(data, oid), + []u8 => return @constCast(types.Bytea.decode(data, oid)), + types.Numeric => return
types.Numeric.decode(fail_mode, data, oid), + types.Cidr => return types.Cidr.decode(fail_mode, data, oid), + else => switch (@typeInfo(T)) { + .@"enum" => { + const str = types.Bytea.decode(data, oid); + return std.meta.stringToEnum(T, str).?; + }, + else => compileHaltGetError(T), + }, + } +} + +const t = lib.testing; + +test "Result: ints" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::smallint, $2::int, $3::bigint"; + + { + // int max + var result = try c.query(sql, .{ @as(i16, 32767), @as(i32, 2147483647), @as(i64, 9223372036854775807) }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectEqual(32767, row.get(i16, 0)); + try t.expectEqual(2147483647, row.get(i32, 1)); + try t.expectEqual(9223372036854775807, row.get(i64, 2)); + + try t.expectEqual(32767, row.get(?i16, 0)); + try t.expectEqual(2147483647, row.get(?i32, 1)); + try t.expectEqual(9223372036854775807, row.get(?i64, 2)); + + try t.expectEqual(null, result.next()); + } + + { + // int min + var result = try c.query(sql, .{ @as(i16, -32768), @as(i32, -2147483648), @as(i64, -9223372036854775808) }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectEqual(-32768, row.get(i16, 0)); + try t.expectEqual(-2147483648, row.get(i32, 1)); + try t.expectEqual(-9223372036854775808, row.get(i64, 2)); + try result.drain(); + } + + { + // int null + var result = try c.query(sql, .{ null, null, null }); + defer result.deinit(); + defer result.drain() catch unreachable; + const row = (try result.nextUnsafe()).?; + try t.expectEqual(null, row.get(?i16, 0)); + try t.expectEqual(null, row.get(?i32, 1)); + try t.expectEqual(null, row.get(?i64, 2)); + } + + { + // uint within limit + var result = try c.query(sql, .{ @as(u16, 32767), @as(u32, 2147483647), @as(u64, 9223372036854775807) }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectEqual(32767, row.get(i16, 0)); + try t.expectEqual(2147483647, 
row.get(i32, 1)); + try t.expectEqual(9223372036854775807, row.get(i64, 2)); + + try t.expectEqual(32767, row.get(?i16, 0)); + try t.expectEqual(2147483647, row.get(?i32, 1)); + try t.expectEqual(9223372036854775807, row.get(?i64, 2)); + try result.drain(); + } + + { + // u16 outside of limit + try t.expectError(error.IntWontFit, c.query(sql, .{ @as(u16, 32768), @as(u32, 0), @as(u64, 0) })); + // u32 outside of limit + try t.expectError(error.IntWontFit, c.query(sql, .{ @as(u16, 0), @as(u32, 2147483648), @as(u64, 0) })); + // u64 outside of limit + try t.expectError(error.IntWontFit, c.query(sql, .{ @as(u16, 0), @as(u32, 0), @as(u64, 9223372036854775808) })); + } +} + +test "Result: floats" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::float4, $2::float8"; + + { + // positive float + var result = try c.query(sql, .{ @as(f32, 1.23456), @as(f64, 1093.229183) }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectEqual(1.23456, row.get(f32, 0)); + try t.expectEqual(1093.229183, row.get(f64, 1)); + + try t.expectEqual(1.23456, row.get(?f32, 0)); + try t.expectEqual(1093.229183, row.get(?f64, 1)); + + try t.expectEqual(null, result.next()); + } + + { + // negative float + var result = try c.query(sql, .{ @as(f32, -392.31), @as(f64, -99991.99992) }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectEqual(-392.31, row.get(f32, 0)); + try t.expectEqual(-99991.99992, row.get(f64, 1)); + try t.expectEqual(null, result.next()); + } + + { + // null float + var result = try c.query(sql, .{ null, null }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectEqual(null, row.get(?f32, 0)); + try t.expectEqual(null, row.get(?f64, 1)); + try t.expectEqual(null, result.next()); + } +} + +test "Result: bool" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::bool"; + + { + // true + var result = try c.query(sql, .{true}); + defer 
result.deinit(); + defer result.drain() catch unreachable; + const row = (try result.nextUnsafe()).?; + try t.expectEqual(true, row.get(bool, 0)); + try t.expectEqual(true, row.get(?bool, 0)); + try t.expectEqual(null, result.next()); + } + + { + // false + var result = try c.query(sql, .{false}); + defer result.deinit(); + defer result.drain() catch unreachable; + const row = (try result.nextUnsafe()).?; + try t.expectEqual(false, row.get(bool, 0)); + try t.expectEqual(false, row.get(?bool, 0)); + try t.expectEqual(null, result.next()); + } + + { + // null + var result = try c.query(sql, .{null}); + defer result.deinit(); + defer result.drain() catch unreachable; + const row = (try result.nextUnsafe()).?; + try t.expectEqual(null, row.get(?bool, 0)); + try t.expectEqual(null, result.next()); + } +} + +test "Result: text and bytea" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::text, $2::bytea"; + + { + // empty + var result = try c.query(sql, .{ "", "" }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectString("", row.get([]u8, 0)); + try t.expectString("", row.get(?[]u8, 0).?); + try t.expectString("", row.get([]u8, 1)); + try t.expectString("", row.get(?[]u8, 1).?); + try result.drain(); + } + + { + // not empty + var result = try c.query(sql, .{ "it's over 9000!!!", "i will Not fear" }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectString("it's over 9000!!!", row.get([]u8, 0)); + try t.expectString("it's over 9000!!!", row.get(?[]const u8, 0).?); + try t.expectString("i will Not fear", row.get([]const u8, 1)); + try t.expectString("i will Not fear", row.get(?[]u8, 1).?); + try result.drain(); + } + + { + // as an array + var result = try c.query(sql, .{ [_]u8{ 'a', 'c', 'b' }, [_]u8{ 'z', 'z', '3' } }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectString("acb", row.get([]const u8, 0)); + try t.expectString("acb", row.get(?[]u8, 
0).?); + try t.expectString("zz3", row.get([]const u8, 1)); + try t.expectString("zz3", row.get(?[]u8, 1).?); + try result.drain(); + } + + { + // as a slice + const s1 = try t.allocator.alloc(u8, 4); + defer t.allocator.free(s1); + @memcpy(s1, "Leto"); + + var result = try c.query(sql, .{ s1, constString() }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectString("Leto", row.get([]u8, 0)); + try t.expectString("Leto", row.get(?[]u8, 0).?); + try t.expectString("Ghanima", row.get([]u8, 1)); + try t.expectString("Ghanima", row.get(?[]u8, 1).?); + try result.drain(); + } + + { + // null + var result = try c.query(sql, .{ null, null }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectEqual(null, row.get(?[]u8, 0)); + try t.expectEqual(null, row.get(?[]u8, 1)); + try result.drain(); + } +} + +fn constString() []const u8 { + return "Ghanima"; +} + +test "Result: optional" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::int, $2::int"; + + { + // int max + var result = try c.query(sql, .{ @as(?i32, 321), @as(?i32, null) }); + defer result.deinit(); + const row = (try result.nextUnsafe()).?; + try t.expectEqual(321, row.get(i32, 0)); + + try t.expectEqual(321, row.get(?i32, 0)); + try t.expectEqual(null, row.get(?i32, 1)); + try t.expectEqual(null, result.next()); + } +} + +test "Result: iterator" { + var c = t.connect(.{}); + defer c.deinit(); + + { + // empty row.iterator() + var result = try c.query("select $1::int[]", .{[_]i32{}}); + defer result.deinit(); + var row = (try result.nextUnsafe()).?; + + var iterator = row.iterator(i32, 0); + try t.expectEqual(0, iterator.len()); + + try t.expectEqual(null, iterator.next()); + try t.expectEqual(null, iterator.next()); + + const a = try iterator.alloc(t.allocator); + try t.expectEqual(0, a.len); + try result.drain(); + } + + { + // empty row.get() + var result = try c.query("select $1::int[]", .{[_]i32{}}); + defer 
result.deinit(); + var row = (try result.nextUnsafe()).?; + + var iterator = row.get(Iterator(i32), 0); + try t.expectEqual(0, iterator.len()); + + try t.expectEqual(null, iterator.next()); + try t.expectEqual(null, iterator.next()); + + const a = try iterator.alloc(t.allocator); + try t.expectEqual(0, a.len); + try result.drain(); + } + + { + // one: row.iterator + var result = try c.query("select $1::int[]", .{[_]i32{9}}); + defer result.deinit(); + var row = (try result.nextUnsafe()).?; + + var iterator = row.iterator(i32, 0); + try t.expectEqual(1, iterator.len()); + + try t.expectEqual(9, iterator.next()); + try t.expectEqual(null, iterator.next()); + + const arr = try iterator.alloc(t.allocator); + defer t.allocator.free(arr); + try t.expectEqual(1, arr.len); + try t.expectSlice(i32, &.{9}, arr); + try result.drain(); + } + + { + // one: row.get + var result = try c.query("select $1::int[]", .{[_]i32{9}}); + defer result.deinit(); + var row = (try result.nextUnsafe()).?; + + var iterator = row.get(Iterator(i32), 0); + try t.expectEqual(1, iterator.len()); + + try t.expectEqual(9, iterator.next()); + try t.expectEqual(null, iterator.next()); + + const arr = try iterator.alloc(t.allocator); + defer t.allocator.free(arr); + try t.expectEqual(1, arr.len); + try t.expectSlice(i32, &.{9}, arr); + try result.drain(); + } + + { + // fill + var result = try c.query("select $1::int[]", .{[_]i32{ 0, -19 }}); + defer result.deinit(); + var row = (try result.nextUnsafe()).?; + + var iterator = row.iterator(i32, 0); + try t.expectEqual(2, iterator.len()); + + try t.expectEqual(0, iterator.next()); + try t.expectEqual(-19, iterator.next()); + try t.expectEqual(null, iterator.next()); + + var arr1: [2]i32 = undefined; + iterator.fill(&arr1); + try t.expectSlice(i32, &.{ 0, -19 }, &arr1); + try result.drain(); + + // smaller + var arr2: [1]i32 = undefined; + iterator.fill(&arr2); + try t.expectSlice(i32, &.{0}, &arr2); + try result.drain(); + } +} + +test "Result: null 
iterator" { + var c = t.connect(.{}); + defer c.deinit(); + + { + // null int + var result = try c.query("select $1::int[]", .{null}); + defer result.deinit(); + + var row = (try result.nextUnsafe()).?; + + var iterator = row.iterator(i32, 0); + try t.expectEqual(true, iterator.is_null); + try t.expectEqual(null, iterator.next()); + try result.drain(); + } + + { + // null text + var result = try c.query("select $1::text[]", .{null}); + defer result.deinit(); + + var row = (try result.nextUnsafe()).?; + + var iterator = row.iterator([]u8, 0); + try t.expectEqual(true, iterator.is_null); + try t.expectEqual(null, iterator.next()); + try result.drain(); + } +} + +test "Result: int[]" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::smallint[], $2::int[], $3::bigint[]"; + + var result = try c.query(sql, .{ [_]i16{ -303, 9449, 2 }, [_]i32{ -3003, 49493229, 0 }, [_]i64{ 944949338498392, -2 } }); + defer result.deinit(); + + var row = (try result.nextUnsafe()).?; + + const v1 = try row.iterator(i16, 0).alloc(t.allocator); + defer t.allocator.free(v1); + try t.expectSlice(i16, &.{ -303, 9449, 2 }, v1); + + const v2 = try row.iterator(i32, 1).alloc(t.allocator); + defer t.allocator.free(v2); + try t.expectSlice(i32, &.{ -3003, 49493229, 0 }, v2); + + const v3 = try row.iterator(i64, 2).alloc(t.allocator); + defer t.allocator.free(v3); + try t.expectSlice(i64, &.{ 944949338498392, -2 }, v3); +} + +test "Result: float[]" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::float4[], $2::float8[]"; + + var result = try c.query(sql, .{ [_]f32{ 1.1, 0, -384.2 }, [_]f64{ -888585.123322, 0.001 } }); + defer result.deinit(); + + var row = (try result.nextUnsafe()).?; + + const v1 = try row.iterator(f32, 0).alloc(t.allocator); + defer t.allocator.free(v1); + try t.expectSlice(f32, &.{ 1.1, 0, -384.2 }, v1); + + const v2 = try row.iterator(f64, 1).alloc(t.allocator); + defer t.allocator.free(v2); + try t.expectSlice(f64, &.{ 
-888585.123322, 0.001 }, v2); +} + +test "Result: bool[]" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::bool[]"; + + var result = try c.query(sql, .{[_]bool{ true, false, false }}); + defer result.deinit(); + + var row = (try result.nextUnsafe()).?; + + const v1 = try row.iterator(bool, 0).alloc(t.allocator); + defer t.allocator.free(v1); + try t.expectSlice(bool, &.{ true, false, false }, v1); +} + +test "Result: text[] & bytea[]" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::text[], $2::bytea[]"; + + var arr1 = [_]u8{ 0, 1, 2 }; + var arr2 = [_]u8{255}; + var result = try c.query(sql, .{ [_][]const u8{ "over", "9000" }, [_][]u8{ &arr1, &arr2 } }); + defer result.deinit(); + + var row = (try result.nextUnsafe()).?; + + const v1 = try row.iterator([]u8, 0).alloc(t.allocator); + defer { + t.allocator.free(v1[0]); + t.allocator.free(v1[1]); + t.allocator.free(v1); + } + try t.expectString("over", v1[0]); + try t.expectString("9000", v1[1]); + try t.expectEqual(2, v1.len); + + const v2 = try row.iterator([]const u8, 1).alloc(t.allocator); + defer { + t.allocator.free(v2[0]); + t.allocator.free(v2[1]); + t.allocator.free(v2); + } + try t.expectString(&arr1, v2[0]); + try t.expectString(&arr2, v2[1]); + try t.expectEqual(2, v2.len); +} + +test "Result: text[] alloc dupes" { + var c = t.connect(.{}); + defer c.deinit(); + + var arr1: [][]const u8 = undefined; + var arr2: [][]const u8 = undefined; + defer { + for (arr1) |str| { + t.allocator.free(str); + } + t.allocator.free(arr1); + + for (arr2) |str| { + t.allocator.free(str); + } + t.allocator.free(arr2); + } + + { + var row = (try c.rowUnsafe("select array['Leto', 'Test']::text[]", .{})) orelse unreachable; + defer row.deinit() catch {}; + arr1 = try row.iterator([]const u8, 0).alloc(t.allocator); + } + + { + var row = (try c.rowUnsafe("select array['Ghanima', 'Goku']::text[]", .{})) orelse unreachable; + defer row.deinit() catch {}; + arr2 = try 
row.iterator([]const u8, 0).alloc(t.allocator); + } + + try t.expectStringSlice(&.{ "Leto", "Test" }, arr1); + try t.expectStringSlice(&.{ "Ghanima", "Goku" }, arr2); +} + +test "Result: UUID" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::uuid, $2::uuid"; + var result = try c.query(sql, .{ "fcbebf0f-b996-43b9-9818-672bc689cda8", &[_]u8{ 174, 47, 71, 95, 128, 112, 65, 183, 186, 51, 134, 187, 168, 137, 123, 222 } }); + defer result.deinit(); + + const row = (try result.nextUnsafe()).?; + try t.expectSlice(u8, &.{ 252, 190, 191, 15, 185, 150, 67, 185, 152, 24, 103, 43, 198, 137, 205, 168 }, row.get([]u8, 0)); + try t.expectSlice(u8, &.{ 174, 47, 71, 95, 128, 112, 65, 183, 186, 51, 134, 187, 168, 137, 123, 222 }, row.get([]u8, 1)); +} + +test "Result: lsn" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::pg_lsn + 1"; + var result = try c.query(sql, .{32788447688}); + defer result.deinit(); + + const row = (try result.nextUnsafe()).?; + try t.expectEqual(32788447689, row.get(i64, 0)); +} + +test "Row: column names" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select 923 as id, 'Leto' as name"; + var row = (try c.rowUnsafeOpts(sql, .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + try t.expectEqual(923, row.getCol(i32, "id")); + try t.expectString("Leto", row.getCol([]u8, "name")); +} + +test "Result: mutable []u8" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select 'Leto'"; + var row = (try c.rowUnsafe(sql, .{})).?; + defer row.deinit() catch {}; + + var name = row.get([]u8, 0); + name[3] = '!'; + try t.expectString("Let!", name); +} + +test "Result: mutable [][]u8" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select array['Leto', 'Test']::text[]"; + var row = (try c.rowUnsafe(sql, .{})).?; + defer row.deinit() catch {}; + + var values = try row.iterator([]u8, 0).alloc(t.allocator); + defer { + t.allocator.free(values[0]); + 
t.allocator.free(values[1]); + t.allocator.free(values); + } + values[0][0] = 'n'; + try t.expectString("neto", values[0]); + try t.expectString("Test", values[1]); +} + +test "Row.to: ordinal" { + const User = struct { + id: i32, + active: bool, + name: []const u8, + note: ?[]const u8, + choice: Choice, + + const Choice = enum { + blue, + green, + red, + }; + }; + + var c = t.connect(.{}); + defer c.deinit(); + + { + // null, no dupe + var row = (try c.rowUnsafe("select 1::integer, true, 'teg', null::text, 'blue'", .{})).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{}); + try t.expectEqual(1, user.id); + try t.expectEqual(true, user.active); + try t.expectString("teg", user.name); + try t.expectEqual(null, user.note); + try t.expectEqual(.blue, user.choice); + } + + { + // not null, no dupe + var row = (try c.rowUnsafe("select 2::integer, false, 'ghanima', 'n1', 'red'", .{})).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{}); + try t.expectEqual(2, user.id); + try t.expectEqual(false, user.active); + try t.expectString("ghanima", user.name); + try t.expectString("n1", user.note.?); + try t.expectEqual(.red, user.choice); + } + + { + // null, dupe with internal arena + var row = (try c.rowUnsafe("select 1::integer, true, 'teg', null::text, 'red'", .{})).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .dupe = true }); + try t.expectEqual(1, user.id); + try t.expectEqual(true, user.active); + try t.expectString("teg", user.name); + try t.expectEqual(null, user.note); + try t.expectEqual(.red, user.choice); + } + + { + // not null, dupe with internal arena + var row = (try c.rowUnsafe("select 2::integer, false, 'ghanima', 'n1', 'red'", .{})).?; + const user = try row.to(User, .{ .dupe = true }); + defer row.deinit() catch {}; + + try t.expectEqual(2, user.id); + try t.expectEqual(false, user.active); + try t.expectString("ghanima", user.name); + try t.expectString("n1", user.note.?); + try 
t.expectEqual(.red, user.choice); + } + + { + // null, dupe with explicit allocator + var row = (try c.rowUnsafe("select 1::integer, true, 'teg', null::text, 'red'", .{})).?; + const user = try row.to(User, .{ .allocator = t.allocator }); + row.deinit() catch {}; + + defer t.allocator.free(user.name); + try t.expectEqual(1, user.id); + try t.expectEqual(true, user.active); + try t.expectString("teg", user.name); + try t.expectEqual(null, user.note); + try t.expectEqual(.red, user.choice); + } + + { + // not null, dupe with explicit allocator + var row = (try c.rowUnsafe("select 2::integer, false, 'ghanima', 'n1', 'red'", .{})).?; + + const user = try row.to(User, .{ .allocator = t.allocator }); + row.deinit() catch {}; + + defer t.allocator.free(user.name); + defer t.allocator.free(user.note.?); + + try t.expectEqual(2, user.id); + try t.expectEqual(false, user.active); + try t.expectString("ghanima", user.name); + try t.expectString("n1", user.note.?); + try t.expectEqual(.red, user.choice); + } +} + +test "Row.to: name no map" { + const User = struct { + id: i32 = 9876, + active: bool, + name: []const u8, + note: ?[]const u8 = null, + }; + + var c = t.connect(.{}); + defer c.deinit(); + + { + // null, no dupe + var row = (try c.rowUnsafeOpts("select 1 as id, true as active, 'teg' as name, null as note", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .map = .name }); + try t.expectEqual(1, user.id); + try t.expectEqual(true, user.active); + try t.expectString("teg", user.name); + try t.expectEqual(null, user.note); + } + + { + // default values are used if there is no matching column + // and extra columns are ignored + var row = (try c.rowUnsafeOpts("select 2 as id, false as active, 'ghanima' as name, 'x123' as other", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .map = .name }); + try t.expectEqual(2, user.id); + try t.expectEqual(false, user.active); + try 
t.expectString("ghanima", user.name); + try t.expectEqual(null, user.note); + } + + { + // nullable fields are nulled if no column + // and extra columns are ignored + var row = (try c.rowUnsafeOpts("select false as active, 'ghanima' as name, 'x123' as other", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .map = .name }); + try t.expectEqual(9876, user.id); + try t.expectEqual(false, user.active); + try t.expectString("ghanima", user.name); + try t.expectEqual(null, user.note); + } + + { + // error on missing column with non-default value + var row = (try c.rowUnsafeOpts("select 1 as id", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + try t.expectError(error.FieldColumnMismatch, row.to(User, .{ .map = .name })); + } + + { + // not null, no dupe + var row = (try c.rowUnsafeOpts("select 2::integer as id, false as active, 'ghanima' as name, 'n1' as note", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .map = .name }); + try t.expectEqual(2, user.id); + try t.expectEqual(false, user.active); + try t.expectString("ghanima", user.name); + try t.expectString("n1", user.note.?); + } + + { + // null, dupe with internal arena + var row = (try c.rowUnsafeOpts("select 1::integer as id, true as active, 'teg' as name, null::text as note", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .dupe = true, .map = .name }); + try t.expectEqual(1, user.id); + try t.expectEqual(true, user.active); + try t.expectString("teg", user.name); + try t.expectEqual(null, user.note); + } + + { + // not null, dupe with internal arena + var row = (try c.rowUnsafeOpts("select 2::integer as id, false as active, 'ghanima' as name, 'n1' as note", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .dupe = true, .map = .name }); + try t.expectEqual(2, user.id); + 
try t.expectEqual(false, user.active); + try t.expectString("ghanima", user.name); + try t.expectString("n1", user.note.?); + } + + { + // null, dupe with explicit allocator + var row = (try c.rowUnsafeOpts("select 1::integer as id, true as active, 'teg' as name, null::text as note", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .allocator = t.allocator, .map = .name }); + defer t.allocator.free(user.name); + try t.expectEqual(1, user.id); + try t.expectEqual(true, user.active); + try t.expectString("teg", user.name); + try t.expectEqual(null, user.note); + } + + { + // not null, dupe with explicit allocator + var row = (try c.rowUnsafeOpts("select 5::integer as id, false as active, 'ghanima' as name, 'n1' as note", .{}, .{ .column_names = true })).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{ .allocator = t.allocator, .map = .name }); + defer t.allocator.free(user.name); + defer t.allocator.free(user.note.?); + + try t.expectEqual(5, user.id); + try t.expectEqual(false, user.active); + try t.expectString("ghanima", user.name); + try t.expectString("n1", user.note.?); + } +} + +test "Result.Mapper" { + var c = t.connect(.{}); + defer c.deinit(); + + { + // mapper with missing column and non-default field + var result = try c.queryOpts("select 1", .{}, .{ .column_names = true }); + defer result.deinit(); + const mapper = result.mapper(struct { id: i32 }, .{}); + try t.expectError(error.FieldColumnMismatch, mapper.next()); + try result.drain(); + } + + // null, no dupe + try expectResultMapper(&c, "select 1 as id, true as active, 'teg' as name, null as note", .{ + .id = 1, + .active = true, + .name = "teg", + .note = null, + }, .{}); + + // default values are used if there is no matching column + // and extra columns are ignored + try expectResultMapper(&c, "select 2 as id, false as active, 'ghanima' as name, 'x123' as other", .{ + .id = 2, + .active = false, + .name = "ghanima", + .note = null, + }, 
.{}); + + // nullable fields are nulled if no column + // and extra columns are ignored + try expectResultMapper(&c, "select false as active, 'ghanima' as name, 'x123' as other", .{ + .id = 9876, + .active = false, + .name = "ghanima", + .note = null, + }, .{}); + + // not null, no dupe + try expectResultMapper(&c, "select 2::integer as id, false as active, 'ghanima' as name, 'n1' as note", .{ + .id = 2, + .active = false, + .name = "ghanima", + .note = "n1", + }, .{}); + + // null, dupe with internal arena + try expectResultMapper(&c, "select 1::integer as id, true as active, 'teg' as name, null::text as note", .{ + .id = 1, + .active = true, + .name = "teg", + .note = null, + }, .{ .dupe = true }); + + // not null, dupe with internal arena + try expectResultMapper(&c, "select 3::integer as id, false as active, 'ghanima' as name, 'n1' as note", .{ + .id = 3, + .active = false, + .name = "ghanima", + .note = "n1", + }, .{ .dupe = true }); + + // null, dupe with explicit allocator + try expectResultMapper(&c, "select 4::integer as id, true as active, 'teg' as name, null::text as note", .{ + .id = 4, + .active = true, + .name = "teg", + .note = null, + }, .{ .allocator = t.allocator }); + + // not null, dupe with explicit allocator + try expectResultMapper(&c, "select 5::integer as id, false as active, 'ghanima' as name, 'n1' as note", .{ + .id = 5, + .active = false, + .name = "ghanima", + .note = "n1", + }, .{ .allocator = t.allocator }); +} + +test "Row.to: iterator" { + const User = struct { + parents: Iterator(i32), + tags: ?Iterator([]const u8), + }; + + defer t.reset(); + var c = t.connect(.{}); + defer c.deinit(); + + { + var row = (try c.rowUnsafe("select array[1, 99]::integer[], null", .{})).?; + defer row.deinit() catch {}; + + const user = try row.to(User, .{}); + try t.expectSlice(i32, &.{ 1, 99 }, try user.parents.alloc(t.arena.allocator())); + try t.expectEqual(null, user.tags); + } + + { + var row = (try c.rowUnsafe("select array[0]::integer[], 
array['over', '9000']::text[]", .{})).?; + const user = try row.to(User, .{ .allocator = t.allocator }); + row.deinit() catch {}; + + defer user.parents.deinit(t.allocator); + defer user.tags.?.deinit(t.allocator); + + try t.expectSlice(i32, &.{0}, try user.parents.alloc(t.arena.allocator())); + try t.expectStringSlice(&.{ "over", "9000" }, try user.tags.?.alloc(t.arena.allocator())); + } + + { + // dupe with result arena + var result = try c.query( + \\ select array[0]::integer[], array['over']::text[] + \\ union all + \\ select array[1]::integer[], array['9000']::text[] + , .{}); + + const user1 = try (try result.nextUnsafe()).?.to(User, .{ .dupe = true }); + const user2 = try (try result.nextUnsafe()).?.to(User, .{ .dupe = true }); + try t.expectEqual(null, try result.nextUnsafe()); + defer result.deinit(); + + try t.expectSlice(i32, &.{0}, try user1.parents.alloc(t.arena.allocator())); + try t.expectStringSlice(&.{"over"}, try user1.tags.?.alloc(t.arena.allocator())); + + try t.expectSlice(i32, &.{1}, try user2.parents.alloc(t.arena.allocator())); + try t.expectStringSlice(&.{"9000"}, try user2.tags.?.alloc(t.arena.allocator())); + } + + { + // dupe with explicit arena + var result = try c.query( + \\ select array[0]::integer[], array['over']::text[] + \\ union all + \\ select array[1]::integer[], array['9000']::text[] + , .{}); + + const user1 = try (try result.nextUnsafe()).?.to(User, .{ .allocator = t.allocator }); + const user2 = try (try result.nextUnsafe()).?.to(User, .{ .allocator = t.allocator }); + try t.expectEqual(null, try result.nextUnsafe()); + result.deinit(); + + defer user1.tags.?.deinit(t.allocator); + defer user1.parents.deinit(t.allocator); + defer user2.tags.?.deinit(t.allocator); + defer user2.parents.deinit(t.allocator); + + try t.expectSlice(i32, &.{0}, try user1.parents.alloc(t.arena.allocator())); + try t.expectStringSlice(&.{"over"}, try user1.tags.?.alloc(t.arena.allocator())); + + try t.expectSlice(i32, &.{1}, try 
user2.parents.alloc(t.arena.allocator())); + try t.expectStringSlice(&.{"9000"}, try user2.tags.?.alloc(t.arena.allocator())); + } +} + +test "Row.to: array" { + const User = struct { + parents: []i32, + tags: ?[][]const u8, + choices: ?[]Choice, + + const Choice = enum { + red, + blue, + green, + }; + }; + + defer t.reset(); + var c = t.connect(.{}); + defer c.deinit(); + + { + var row = (try c.rowUnsafe("select array[1, 99]::integer[], array['over', '9000']::text[], array['red', 'green']::text[]", .{})).?; + const user = try row.to(User, .{ .allocator = t.allocator }); + row.deinit() catch {}; + + defer { + t.allocator.free(user.tags.?[0]); + t.allocator.free(user.tags.?[1]); + t.allocator.free(user.tags.?); + t.allocator.free(user.parents); + t.allocator.free(user.choices.?); + } + try t.expectSlice(i32, &.{ 1, 99 }, user.parents); + try t.expectStringSlice(&.{ "over", "9000" }, user.tags.?); + try t.expectSlice(User.Choice, &.{ .red, .green }, user.choices.?); + } + + { + var row = (try c.rowUnsafe("select array[1, 99]::integer[], null::text[], null::text[]", .{})).?; + const user = try row.to(User, .{ .allocator = t.allocator }); + row.deinit() catch {}; + + defer { + t.allocator.free(user.parents); + } + try t.expectSlice(i32, &.{ 1, 99 }, user.parents); + try t.expectEqual(null, user.tags); + try t.expectEqual(null, user.choices); + } +} + +test "Result: safe" { + var c = t.connect(.{}); + defer c.deinit(); + const sql = "select $1::int, $2::int"; + + { + var result = try c.query(sql, .{ @as(?i32, 321), @as(?i32, null) }); + defer result.deinit(); + const row = (try result.next()).?; + try t.expectEqual(321, try row.get(i32, 0)); + try t.expectEqual(error.InvalidType, row.get(bool, 0)); + + try t.expectEqual(321, try row.get(?i32, 0)); + try t.expectEqual(null, try row.get(?i32, 1)); + try t.expectEqual(null, result.next()); + } +} + +fn expectResultMapper(conn: *Conn, sql: []const u8, expected: anytype, opts: Result.MapperOpts) !void { + const User = struct 
{ + id: i32 = 9876, + active: bool, + name: []const u8, + note: ?[]const u8 = null, + }; + + var result = try conn.queryOpts(sql, .{}, .{ .column_names = true }); + defer result.deinit(); + var mapper = result.mapper(User, opts); + + const user = (try mapper.next()) orelse unreachable; + try t.expectEqual(expected.id, user.id); + try t.expectEqual(expected.active, user.active); + try t.expectString(expected.name, user.name); + if (opts.allocator) |a| { + a.free(user.name); + } + if (@TypeOf(expected.note) == @TypeOf(null)) { + try t.expectEqual(null, user.note); + } else { + try t.expectString(expected.note, user.note.?); + if (opts.allocator) |a| { + a.free(user.note.?); + } + } + + try t.expectEqual(null, mapper.next()); +} diff --git a/zig/pg/src/stmt.zig b/zig/pg/src/stmt.zig new file mode 100644 index 0000000..c7ea239 --- /dev/null +++ b/zig/pg/src/stmt.zig @@ -0,0 +1,407 @@ +const std = @import("std"); +const lib = @import("lib.zig"); +const Buffer = @import("buffer").Buffer; + +const types = lib.types; +const Conn = lib.Conn; +const Result = lib.Result; +const Allocator = std.mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + +pub const Stmt = struct { + buf: *Buffer, + + opts: Conn.QueryOpts, + + conn: *Conn, + // Executing a stmt may or may not require allocations. It depends on the + // number of columns, number of parameters, size of the SQL, size of the + // serialized values and our configuration (e.g. how big + // our write buffer is). + arena: *ArenaAllocator, + + // Every call to stmt.bind increments this value. Important because the Bind + // message contains all the parameter meta data first, then the serialized + // values. So when we bind a parameter, we need to jump around our buf payload + // based on the param_index * $some_offset. + param_index: u16, + + // Number of parameters in the query. 
+ param_count: u16, + + // Offset of the current Bind message inside the write buffer so we can + // patch its length without scanning arbitrary payload bytes. + bind_start: usize, + + // The type of each parameter, which postgresql tells us after we send it the + // SQL and ask for a description. `param_oids.len` can be greater than + // `param_count` because we initially use the conn._param_oids which is + // globally configured. + param_oids: []i32, + + // Number of columns in the result + column_count: u16, + + // Information about the columns in the result, which postgresql tells us after + // we send it the SQL and ask for a description. The slices in this structure + // can be larger than `column_count` because we initially use conn._result_state + // which is globally configured. + result_state: Result.State, + + // Name of the prepared statement. Empty == unnamed, so it won't be cached + // by the server + name: []const u8, + + pub fn init(conn: *Conn, opts: Conn.QueryOpts) !Stmt { + const base_allocator = opts.allocator orelse conn._allocator; + const arena = try base_allocator.create(ArenaAllocator); + arena.* = ArenaAllocator.init(base_allocator); + + return .{ + .conn = conn, + .opts = opts, + .buf = &conn._buf, + .arena = arena, + .param_index = 0, + .param_count = 0, + .bind_start = 0, + .param_oids = conn._param_oids, + .column_count = 0, + .result_state = conn._result_state, + .name = opts.cache_name orelse "", + }; + } + + pub fn fromDescribe(conn: *Conn, describe: *Describe, opts: Conn.QueryOpts) !Stmt { + const base_allocator = opts.allocator orelse conn._allocator; + const arena = try base_allocator.create(ArenaAllocator); + arena.* = ArenaAllocator.init(base_allocator); + + return .{ + .conn = conn, + .opts = opts, + .buf = &conn._buf, + .arena = arena, + .param_index = 0, + .param_count = @intCast(describe.param_oids.len), + .bind_start = 0, + .param_oids = describe.param_oids, + .column_count = @intCast(describe.result_state.oids.len), + 
.result_state = describe.result_state, + .name = opts.cache_name.?, + }; + } + + // Should only be called in an error case. In a normal case, where + // stmt.execute() returns a result, stmt.deinit() must not be called (all + // ownership is passed to the result). + pub fn deinit(self: *Stmt) void { + self.conn._reader.endFlow() catch { + // this can only fail in extreme conditions (OOM) and it will only impact + // the next query (and if the app is using the pool, the pool will try to + // recover from this anyways) + self.conn._state = .fail; + }; + + const arena = self.arena; + const allocator = arena.child_allocator; + arena.deinit(); + allocator.destroy(arena); + } + + // When describe_allocator != null, we intend to cache the query information + // (in conn.__prepared_statements). + pub fn prepare(self: *Stmt, sql: []const u8, describe_allocator: ?Allocator) !void { + var conn = self.conn; + const opts = &self.opts; + const statement_arena = self.arena.allocator(); + + try conn._reader.startFlow(statement_arena, opts.timeout); + + var buf = self.buf; + buf.reset(); + + const name = self.name; + + // This function will issue 3 commands: Parse, Describe, Sync + // We need the response from describe to put together our Bind message. + // Specifically, describe will tell us the type of the return columns, and + // in Bind, we tell the server how we want it to encode each column (text + // or binary) and to do that, we need to know what they are. 
+ { + // Build the payload from our 3 commands + + // We can calculate exactly how many bytes our 3 messages are going to be + // and make sure our buffer is big enough, thus avoiding some unnecessary + // bounds checking + const bind_payload_len = 8 + sql.len + name.len; + const describe_payload_len = 6 + name.len; + const sync_payload_len = 4; + + // the +3 is for the leading message-type byte of each of the 3 messages + const total_length = 3 + bind_payload_len + describe_payload_len + sync_payload_len; + + try buf.ensureTotalCapacity(total_length); + var view = buf.skip(total_length) catch unreachable; + + // PARSE + view.writeByte('P'); + view.writeIntBig(u32, @intCast(bind_payload_len)); + view.write(name); + view.writeByte(0); + view.write(sql); + // null terminate sql string, and we'll be specifying 0 parameter types + view.write(&.{ 0, 0, 0 }); + + // DESCRIBE + view.writeByte('D'); + view.writeIntBig(u32, @intCast(describe_payload_len)); + view.writeByte('S'); // Describe a prepared statement + view.write(name); + view.writeByte(0); // null terminate our name + + // SYNC + view.write(&.{ 'S', 0, 0, 0, 4 }); + try conn.write(buf.string()); + } + + // no longer idle, we're now in a query + conn._state = .query; + + // First message we expect back is a ParseComplete, which has no data. + { + // If Parse fails, then the server won't reply to our other messages + // (i.e. Describe) and it'll immediately send a ReadyForQuery. 
+ const msg = conn.read() catch |err| { + conn.readyForQuery() catch {}; + return err; + }; + + if (msg.type != '1') { + return conn.unexpectedDBMessage(); + } + } + + var param_count: u16 = 0; + + { + // we expect a ParameterDescription message + const msg = try conn.read(); + if (msg.type != 't') { + return conn.unexpectedDBMessage(); + } + + var param_oids = self.param_oids; + const data = msg.data; + param_count = std.mem.readInt(u16, data[0..2], .big); + if (describe_allocator) |da| { + // If we plan on caching this prepared statement, then we need + // to allocate a new param_oids list which will outlive this + // statement + param_oids = try da.alloc(i32, param_count); + self.param_oids = param_oids; + } else if (param_count > param_oids.len) { + lib.metrics.allocParams(param_count); + param_oids = try statement_arena.alloc(i32, param_count); + self.param_oids = param_oids; + } + + var pos: usize = 2; + for (0..param_count) |i| { + const end = pos + 4; + param_oids[i] = std.mem.readInt(i32, data[pos..end][0..4], .big); + pos = end; + } + self.param_count = param_count; + } + + { + // We now expect an answer to our describe message. + // This is either going to be a RowDescription, or a NoData. NoData means + // our statement doesn't return any data. Either way, we're going to use + // this information when we generate our Bind message, next. 
+ const msg = try conn.read(); + switch (msg.type) { + 'n' => {}, // no data, column_count = 0 + 'T' => { + var state = self.result_state; + const data = msg.data; + const column_count = std.mem.readInt(u16, data[0..2], .big); + if (describe_allocator) |da| { + // If we plan on caching this prepared statement, then we need + // to allocate a new Result.State which will outlive this + // statement + state = try Result.State.init(da, column_count); + self.result_state = state; + } else if (column_count > state.oids.len) { + lib.metrics.allocColumns(column_count); + // we have more columns than self.result_state can hold, so we + // need to create a new Result.State specifically for this statement + state = try Result.State.init(statement_arena, column_count); + self.result_state = state; + } + const a: ?Allocator = if (opts.column_names) (describe_allocator orelse statement_arena) else null; + try state.from(column_count, data, a); + self.column_count = column_count; + }, + else => return conn.unexpectedDBMessage(), + } + } + + return self.prepareForBind(param_count); + } + + // We need to call Bind for every value we're binding. Rather than having + // to check "is this the first call to bind" each time, we make it the caller's + // responsibility to "prepareForBind" upfront. 
+ pub fn prepareForBind(self: *Stmt, param_count: u16) !void { + try self.conn.readyForQuery(); + + var buf = self.buf; + buf.resetRetainingCapacity(); + try self.startBindMessage(param_count); + } + + pub fn startBindMessage(self: *Stmt, param_count: u16) !void { + self.param_index = 0; + + var buf = self.buf; + const name = self.name; + self.bind_start = buf.len(); + + // Bind command = 'B' + // 4 byte length placeholder - 0, 0, 0, 0 + // portal name (empty string, length 0) - 0 + // prepared statement name + null terminator + // parameter format count + parameter formats + value count + const bind_prefix_len = 1 + 4 + 1 + name.len + 1 + 2 + (param_count * 2) + 2; + try buf.ensureUnusedCapacity(bind_prefix_len); + + // We reserved the fixed-size prefix above, so these append operations can + // skip their own bounds checks. + buf.writeAssumeCapacity(&.{ 'B', 0, 0, 0, 0, 0 }); + + buf.writeAssumeCapacity(name); + buf.writeByteAssumeCapacity(0); + + // number of parameter format codes we're sending + try buf.writeIntBig(u16, param_count); + + // the format (text or binary) of each parameter. We'll default to text + // for now, and fill this in as we get the data + try buf.writeByteNTimes(0, param_count * 2); + + // number of parameter values we're sending + try buf.writeIntBig(u16, param_count); + } + + pub fn bind(self: *Stmt, value: anytype) !void { + const name = self.name; + + const param_index = self.param_index; + lib.assert(param_index < self.param_count); + + // We tell PostgreSQL the format (text or binary) of each parameter. This + // information is near the start of the Bind message: it starts 9 bytes plus + // the statement name length past bind_start, and each value is 2 bytes. 
+ const format_offset = self.bind_start + 9 + (param_index * 2) + name.len; + + try types.bindValue(@TypeOf(value), self.param_oids[param_index], value, self.buf, format_offset); + self.param_index = param_index + 1; + } + + pub fn bindDynamic(self: *Stmt, value: types.DynamicValue) !void { + const name = self.name; + + const param_index = self.param_index; + lib.assert(param_index < self.param_count); + + const format_offset = self.bind_start + 9 + (param_index * 2) + name.len; + + try types.bindDynamicValue(self.param_oids[param_index], value, self.buf, format_offset); + self.param_index = param_index + 1; + } + + pub fn finishExecuteMessage(self: *Stmt, append_sync: bool) !void { + lib.assert(self.param_index == self.param_count); + + const buf = self.buf; + + try lib.types.resultEncoding(self.result_state.oids[0..self.column_count], buf); + const bind_end = buf.len(); + var bind_len: [4]u8 = undefined; + std.mem.writeInt(u32, &bind_len, @intCast(bind_end - self.bind_start - 1), .big); + @memcpy(buf.buf[self.bind_start + 1 .. self.bind_start + 5], &bind_len); + + try buf.write(&.{ + 'E', + 0, + 0, + 0, + 9, + 0, + 0, + 0, + 0, + 0, + }); + if (append_sync) { + try buf.write(&.{ 'S', 0, 0, 0, 4 }); + } + } + + pub fn execute(self: *Stmt) !*Result { + const buf = self.buf; + const conn = self.conn; + + try self.finishExecuteMessage(true); + + try conn.write(buf.string()); + + { + const msg = conn.read() catch |err| { + conn.readyForQuery() catch {}; + return err; + }; + if (msg.type != '2') { + // expecting a BindComplete + return conn.unexpectedDBMessage(); + } + } + + try conn.peekForError(); + + // our call to readyForQuery above changed the state, but as far as we're + // concerned, we're still doing the query. 
+ conn._state = .query; + + lib.metrics.query(); + + const opts = &self.opts; + const state = self.result_state; + const column_count = self.column_count; + + const arena = self.arena; + + // Put result on the heap largely for the QueryRow (created via the + // conn.row(...) helper). This allows QueryRow.result and QueryRow.row._result + // to reference the result, which isn't otherwise owned. + const result = try arena.allocator().create(Result); + result.* = .{ + ._conn = conn, + ._arena = self.arena, + ._release_conn = opts.release_conn, + ._oids = state.oids[0..column_count], + ._values = state.values[0..column_count], + .column_names = if (opts.column_names) state.names[0..column_count] else &[_][]const u8{}, + .number_of_columns = column_count, + }; + return result; + } + + pub const Describe = struct { + param_oids: []i32, + arena: ArenaAllocator, + result_state: Result.State, + }; +}; diff --git a/zig/pg/src/stream.zig b/zig/pg/src/stream.zig new file mode 100644 index 0000000..e1ff254 --- /dev/null +++ b/zig/pg/src/stream.zig @@ -0,0 +1,338 @@ +const std = @import("std"); +const lib = @import("lib.zig"); + +const openssl = lib.openssl; + +const posix = std.posix; + +const Conn = lib.Conn; +const Allocator = std.mem.Allocator; + +const DEFAULT_HOST = "127.0.0.1"; + +// `-Diouring=true` selects a per-connection io_uring transport on +// Linux. It's incompatible with the OpenSSL TLS path in this PR +// (plaintext only), so prefer TLS if both are configured. 
+pub const Stream = if (lib.has_openssl) + TLSStream +else if (lib.has_iouring) + IoUringStream +else + PlainStream; + +const TLSStream = struct { + valid: bool, + ssl: ?*openssl.SSL, + socket: posix.socket_t, + + pub fn connect(allocator: Allocator, opts: Conn.Opts, ctx_: ?*openssl.SSL_CTX) !Stream { + const plain = try PlainStream.connect(allocator, opts, null); + errdefer plain.close(); + + const sock_fd = plain.socket; + + var ssl: ?*openssl.SSL = null; + if (ctx_) |ctx| { + // PostgreSQL TLS starts off as a plain connection which we upgrade + try writeSocket(sock_fd, &.{ 0, 0, 0, 8, 4, 210, 22, 47 }); + var buf = [1]u8{0}; + _ = try readSocket(sock_fd, &buf); + if (buf[0] != 'S') { + return error.SSLNotSupportedByServer; + } + + ssl = openssl.SSL_new(ctx) orelse return error.SSLNewFailed; + errdefer openssl.SSL_free(ssl); + + if (opts.host) |host| { + if (isHostName(host)) { + // don't send this for an ip address + var owned = false; + const h = opts._hostz orelse blk: { + owned = true; + break :blk try allocator.dupeZ(u8, host); + }; + + defer if (owned) { + allocator.free(h); + }; + + if (openssl.SSL_set_tlsext_host_name(ssl, h.ptr) != 1) { + return error.SSLHostNameFailed; + } + } + switch (opts.tls) { + .verify_full => openssl.SSL_set_verify(ssl, openssl.SSL_VERIFY_PEER, null), + else => {}, + } + } + + if (openssl.SSL_set_fd(ssl, if (@import("builtin").os.tag == .windows) @intCast(@intFromPtr(sock_fd)) else sock_fd) != 1) { + return error.SSLSetFdFailed; + } + + { + const ret = openssl.SSL_connect(ssl); + if (ret != 1) { + const verification_code = openssl.SSL_get_verify_result(ssl); + if (comptime lib._stderr_tls) { + lib.printSSLError(); + } + if (verification_code != openssl.X509_V_OK) { + if (comptime lib._stderr_tls) { + std.debug.print("ssl verification error: {s}\n", .{openssl.X509_verify_cert_error_string(verification_code)}); + } + return error.SSLCertificationVerificationError; + } + return error.SSLConnectFailed; + } + } + } + + return .{ + 
.ssl = ssl, + .valid = true, + .socket = sock_fd, + }; + } + + pub fn close(self: *Stream) void { + if (self.ssl) |ssl| { + if (self.valid) { + _ = openssl.SSL_shutdown(ssl); + self.valid = false; + } + openssl.SSL_free(ssl); + } + _ = std.c.close(self.socket); + } + + pub fn writeAll(self: *Stream, data: []const u8) !void { + if (self.ssl) |ssl| { + const result = openssl.SSL_write(ssl, data.ptr, @intCast(data.len)); + if (result <= 0) { + self.valid = false; + return error.SSLWriteFailed; + } + return; + } + return writeSocket(self.socket, data); + } + + pub fn read(self: *Stream, buf: []u8) !usize { + if (self.ssl) |ssl| { + var read_len: usize = undefined; + const result = openssl.SSL_read_ex(ssl, buf.ptr, @intCast(buf.len), &read_len); + if (result <= 0) { + self.valid = false; + return error.SSLReadFailed; + } + return read_len; + } + + return readSocket(self.socket, buf); + } +}; + +const PlainStream = struct { + socket: posix.socket_t, + + pub fn connect(_: Allocator, opts: Conn.Opts, _: anytype) !PlainStream { + const sock_fd = blk: { + const host = opts.host orelse DEFAULT_HOST; + if (host.len > 0 and host[0] == '/') { + if (comptime std.Io.net.has_unix_sockets == false or std.posix.AF == void) { + return error.UnixPathNotSupported; + } + break :blk try connectUnixSocket(host); + } + const port = opts.port orelse 5432; + break :blk try tcpConnectToHost(host, port); + }; + errdefer _ = std.c.close(sock_fd); + + return .{ + .socket = sock_fd, + }; + } + + pub fn close(self: *const PlainStream) void { + _ = std.c.close(self.socket); + } + + pub fn writeAll(self: *const PlainStream, data: []const u8) !void { + return writeSocket(self.socket, data); + } + + pub fn read(self: *const PlainStream, buf: []u8) !usize { + return readSocket(self.socket, buf); + } +}; + +// Per-connection io_uring transport. Each connection owns a small ring; +// every writeAll / read submits a single SEND / RECV SQE and waits for +// the matching CQE. 
This is intentionally the simplest possible shape +// — no SQPOLL, no multi-shot, no fd registration. Per-op cost is +// roughly 2 enter syscalls (submit + wait) vs 1 send/recv syscall on +// the blocking path, so we expect a small regression for low-latency +// loopback queries unless other factors dominate. +// +// The point of this PR is to land the abstraction; future work can +// turn on SQPOLL or batch outbound flushes. +const IoUringStream = if (@import("builtin").os.tag == .linux) struct { + socket: posix.socket_t, + // Pointer-stable ring so the kernel's internal references survive + // the `Stream` value being copied around inside Conn. + ring: *std.os.linux.IoUring, + allocator: Allocator, + + const RING_ENTRIES: u16 = 8; + + pub fn connect(allocator: Allocator, opts: Conn.Opts, _: anytype) !IoUringStream { + // Reuse the existing blocking connect path (getaddrinfo + + // connect). We only redirect read/write through io_uring. + const plain = try PlainStream.connect(allocator, opts, null); + errdefer plain.close(); + + const ring = try allocator.create(std.os.linux.IoUring); + errdefer allocator.destroy(ring); + ring.* = try std.os.linux.IoUring.init(RING_ENTRIES, 0); + errdefer ring.deinit(); + + return .{ + .socket = plain.socket, + .ring = ring, + .allocator = allocator, + }; + } + + pub fn close(self: *IoUringStream) void { + self.ring.deinit(); + self.allocator.destroy(self.ring); + _ = std.c.close(self.socket); + } + + pub fn writeAll(self: *IoUringStream, data: []const u8) !void { + var remaining = data; + while (remaining.len > 0) { + const sqe = try self.ring.get_sqe(); + sqe.prep_send(self.socket, remaining, 0); + sqe.user_data = 1; + _ = try self.ring.submit(); + const cqe = try self.ring.copy_cqe(); + if (cqe.res < 0) return error.BrokenPipe; + const n: usize = @intCast(cqe.res); + if (n == 0) return error.BrokenPipe; + remaining = remaining[n..]; + } + } + + pub fn read(self: *IoUringStream, buf: []u8) !usize { + const sqe = try 
self.ring.get_sqe(); + sqe.prep_recv(self.socket, buf, 0); + sqe.user_data = 2; + _ = try self.ring.submit(); + const cqe = try self.ring.copy_cqe(); + if (cqe.res <= 0) return error.ConnectionResetByPeer; + return @intCast(cqe.res); + } +} else struct { + // Non-Linux stub. `lib.has_iouring` is gated on os.tag == .linux, + // so the Stream alias never resolves to this on macOS/Windows; the + // type still has to exist so the alias compiles. + socket: posix.socket_t = undefined, + pub fn connect(_: Allocator, _: Conn.Opts, _: anytype) !@This() { + return error.IoUringUnsupported; + } + pub fn close(_: *@This()) void {} + pub fn writeAll(_: *@This(), _: []const u8) !void { + return error.IoUringUnsupported; + } + pub fn read(_: *@This(), _: []u8) !usize { + return error.IoUringUnsupported; + } +}; + +fn readSocket(fd: posix.socket_t, buf: []u8) !usize { + const n = posix.read(fd, buf) catch return error.ConnectionResetByPeer; + if (n == 0) return error.ConnectionResetByPeer; + return n; +} + +fn writeSocket(fd: posix.socket_t, data: []const u8) !void { + var remaining = data; + while (remaining.len > 0) { + const n = write(fd, remaining.ptr, remaining.len); + if (n <= 0) return error.BrokenPipe; + remaining = remaining[@intCast(n)..]; + } +} + +fn isHostName(host: []const u8) bool { + if (std.mem.indexOfScalar(u8, host, ':') != null) { + // IPv6 + return false; + } + return std.mem.indexOfNone(u8, host, "0123456789.") != null; +} + +const builtin = @import("builtin"); + +extern "c" fn socket(domain: c_int, socket_type: c_int, protocol: c_int) c_int; +extern "c" fn connect(sockfd: c_int, addr: *const anyopaque, addrlen: u32) c_int; +extern "c" fn write(fd: c_int, buf: [*]const u8, nbytes: usize) isize; + +const SockaddrUn = switch (builtin.os.tag) { + .driverkit, .ios, .maccatalyst, .macos, .tvos, .visionos, .watchos => extern struct { + len: u8 = 0, + family: u8 = 1, + path: [104]u8 = [_]u8{0} ** 104, + }, + else => extern struct { + family: u16 = 1, + path: [108]u8 
= [_]u8{0} ** 108, + }, +}; + +fn connectUnixSocket(path: []const u8) !posix.socket_t { + const fd = socket(std.c.AF.UNIX, std.c.SOCK.STREAM, 0); + if (fd < 0) return error.SystemResources; + errdefer _ = std.c.close(fd); + var addr: SockaddrUn = .{}; + if (path.len >= addr.path.len) return error.NameTooLong; + if (comptime builtin.os.tag.isDarwin()) addr.len = @as(u8, @sizeOf(SockaddrUn)); + @memcpy(addr.path[0..path.len], path); + if (connect(fd, &addr, @sizeOf(SockaddrUn)) < 0) return error.ConnectionRefused; + return fd; +} + +fn tcpConnectToHost(host: []const u8, port: u16) !posix.socket_t { + var host_buf: [1025]u8 = std.mem.zeroes([1025]u8); + if (host.len > 1024) return error.NameTooLong; + @memcpy(host_buf[0..host.len], host); + var port_buf: [8]u8 = undefined; + const port_str = std.fmt.bufPrintZ(&port_buf, "{d}", .{port}) catch unreachable; + const hints = std.c.addrinfo{ + .flags = .{}, + .family = std.c.AF.UNSPEC, + .socktype = std.c.SOCK.STREAM, + .protocol = 0, + .addrlen = 0, + .canonname = null, + .addr = null, + .next = null, + }; + var result: ?*std.c.addrinfo = null; + if (@intFromEnum(std.c.getaddrinfo(host_buf[0..host.len :0].ptr, port_str.ptr, &hints, &result)) != 0) + return error.UnknownHostName; + defer std.c.freeaddrinfo(result.?); + var it = result; + while (it) |info| : (it = info.next) { + const addr = info.addr orelse continue; + const fd = socket(@intCast(info.family), @intCast(info.socktype), @intCast(info.protocol)); + if (fd < 0) continue; + if (connect(fd, addr, @intCast(info.addrlen)) >= 0) return fd; + _ = std.c.close(fd); + } + return error.ConnectionRefused; +} diff --git a/zig/pg/src/t.zig b/zig/pg/src/t.zig new file mode 100644 index 0000000..e2a4c14 --- /dev/null +++ b/zig/pg/src/t.zig @@ -0,0 +1,241 @@ +const std = @import("std"); + +const Allocator = std.mem.Allocator; +const Conn = @import("conn.zig").Conn; + +pub const allocator = std.testing.allocator; + +pub var arena = std.heap.ArenaAllocator.init(allocator); + +pub 
fn reset() void { + _ = arena.reset(.free_all); +} + +// std.testing.expectEqual won't coerce expected to actual, which is a problem +// when expected is frequently a comptime. +// https://github.com/ziglang/zig/issues/4437 +pub fn expectEqual(expected: anytype, actual: anytype) !void { + try std.testing.expectEqual(@as(@TypeOf(actual), expected), actual); +} +pub fn expectDelta(expected: anytype, actual: anytype, delta: anytype) !void { + expectEqual(true, expected - delta <= actual) catch |err| { + std.debug.print("{d} !~ {d}", .{ expected, actual }); + return err; + }; + expectEqual(true, expected + delta >= actual) catch |err| { + std.debug.print("{d} !~ {d}", .{ expected, actual }); + return err; + }; +} +pub const expectError = std.testing.expectError; +pub const expectSlice = std.testing.expectEqualSlices; +pub const expectString = std.testing.expectEqualStrings; +pub fn expectStringSlice(expected: []const []const u8, actual: [][]const u8) !void { + try expectEqual(expected.len, actual.len); + for (expected, actual) |e, a| { + try expectString(e, a); + } +} + +pub fn getRandom() std.Random.DefaultPrng { + var seed: u64 = undefined; + std.posix.getrandom(std.mem.asBytes(&seed)) catch unreachable; + return std.Random.DefaultPrng.init(seed); +} + +pub fn setup() !void { + var c = connect(.{}); + defer c.deinit(); + _ = c.exec( + \\ drop user if exists pgz_user_nopass; + \\ drop user if exists pgz_user_clear; + \\ drop user if exists pgz_user_scram_sha256; + \\ drop user if exists pgz_user_ssl; + \\ create user pgz_user_nopass; + \\ create user pgz_user_clear with password 'pgz_user_clear_pw'; + \\ create user pgz_user_scram_sha256 with password 'pgz_user_scram_sha256_pw'; + \\ create user pgz_user_ssl with password 'pgz_user_ssl_pw'; + , .{}) catch |err| try fail(c, err); + + _ = c.exec( + \\ drop table if exists simple_table; + \\ create table simple_table (value text); + , .{}) catch |err| try fail(c, err); + + _ = c.exec( + \\ drop type if exists custom_enum 
cascade; + \\ create type custom_enum as enum ('val1', 'val2'); + , .{}) catch |err| try fail(c, err); + + _ = c.exec( + \\ drop table if exists all_types; + \\ create table all_types ( + \\ id integer primary key, + \\ col_int2 smallint, + \\ col_int4 integer, + \\ col_int8 bigint, + \\ col_float4 float4, + \\ col_float8 float8, + \\ col_bool bool, + \\ col_text text, + \\ col_bytea bytea, + \\ col_int2_arr smallint[], + \\ col_int4_arr integer[], + \\ col_int8_arr bigint[], + \\ col_float4_arr float4[], + \\ col_float8_arr float[], + \\ col_bool_arr bool[], + \\ col_text_arr text[], + \\ col_bytea_arr bytea[], + \\ col_enum custom_enum, + \\ col_enum_arr custom_enum[], + \\ col_uuid uuid, + \\ col_uuid_arr uuid[], + \\ col_numeric numeric, + \\ col_numeric_arr numeric[], + \\ col_timestamp timestamp, + \\ col_timestamp_arr timestamp[], + \\ col_json json, + \\ col_json_arr json[], + \\ col_jsonb jsonb, + \\ col_jsonb_arr jsonb[], + \\ col_char char, + \\ col_char_arr char[], + \\ col_charn char(3), + \\ col_charn_arr char(2)[], + \\ col_timestamptz timestamptz, + \\ col_timestamptz_arr timestamptz[], + \\ col_cidr cidr, + \\ col_cidr_arr cidr[], + \\ col_inet inet, + \\ col_inet_arr inet[], + \\ col_macaddr macaddr, + \\ col_macaddr_arr macaddr[], + \\ col_macaddr8 macaddr8, + \\ col_macaddr8_arr macaddr8[] + \\ ); + , .{}) catch |err| try fail(c, err); +} + +// Dummy net.Stream, lets us setup data to be read and capture data that is written. 
+pub const Stream = struct { + closed: bool, + _read_index: usize, + socket: c_int = 0, + _to_read: std.ArrayList(u8), + _received: std.ArrayList(u8), + + pub fn init() *Stream { + const s = allocator.create(Stream) catch unreachable; + s.* = .{ + .closed = false, + ._read_index = 0, + ._to_read = .empty, + ._received = .empty, + }; + return s; + } + + pub fn deinit(self: *Stream) void { + self._to_read.deinit(allocator); + self._received.deinit(allocator); + allocator.destroy(self); + } + + pub fn reset(self: *Stream) void { + self._read_index = 0; + self._to_read.clearRetainingCapacity(); + self._received.clearRetainingCapacity(); + } + + pub fn received(self: *Stream) []const u8 { + return self._received.items; + } + + pub fn add(self: *Stream, value: []const u8) void { + self._to_read.appendSlice(allocator, value) catch unreachable; + } + + pub fn read(self: *Stream, buf: []u8) !usize { + std.debug.assert(!self.closed); + + const read_index = self._read_index; + const items = self._to_read.items; + + if (read_index == items.len) { + return 0; + } + if (buf.len == 0) { + return 0; + } + + // let's fragment this message + const left_to_read = items.len - read_index; + const max_can_read = if (buf.len < left_to_read) buf.len else left_to_read; + + const to_read = max_can_read; + var data = items[read_index..(read_index + to_read)]; + if (data.len > buf.len) { + // we have more data than we have space in buf (our target) + // we'll give it what it can take + data = data[0..buf.len]; + } + self._read_index = read_index + data.len; + + @memcpy(buf[0..data.len], data); + return data.len; + } + + // store messages that are written to the stream + pub fn writeAll(self: *Stream, data: []const u8) !void { + self._received.appendSlice(allocator, data) catch unreachable; + } + + pub fn close(self: *Stream) void { + self.closed = true; + } +}; + +pub fn connect(opts: anytype) Conn { + const T = @TypeOf(opts); + + var c = Conn.open(allocator, .{ + .tls = if (@hasField(T, 
"tls")) opts.tls else .off, + .host = if (@hasField(T, "host")) opts.host else "localhost", + .read_buffer = if (@hasField(T, "read_buffer")) opts.read_buffer else 2000, + }) catch unreachable; + + c.auth(authOpts(opts)) catch |err| { + if (c.err) |pg| { + @panic(pg.message); + } + @panic(@errorName(err)); + }; + return c; +} + +pub fn authOpts(opts: anytype) Conn.AuthOpts { + const T = @TypeOf(opts); + return .{ + .database = if (@hasField(T, "database")) opts.database else "postgres", + .username = if (@hasField(T, "username")) opts.username else "postgres", + .password = if (@hasField(T, "password")) opts.password else "postgres", + }; +} + +pub fn fail(c: Conn, err: anyerror) !void { + if (c.err) |pg_err| { + std.debug.print("PG ERROR: {s}\n", .{pg_err.message}); + } + return err; +} + +pub fn scalar(c: *Conn, sql: []const u8) i32 { + var result = c.query(sql, .{}) catch unreachable; + defer result.deinit(); + + const row = (result.nextUnsafe() catch unreachable).?; + const value = row.get(i32, 0); + result.drain() catch unreachable; + return value; +} diff --git a/zig/pg/src/types.zig b/zig/pg/src/types.zig new file mode 100644 index 0000000..388c01a --- /dev/null +++ b/zig/pg/src/types.zig @@ -0,0 +1,1643 @@ +const std = @import("std"); +const lib = @import("lib.zig"); +const buffer = @import("buffer"); + +// These are nested inside the Types structure so that we can generate an +// oid => encoding mapping. See the oidEncoding function. 
+pub const OID = struct { + decimal: i32, + encoded: [4]u8, + + pub fn make(decimal: i32) OID { + var encoded: [4]u8 = undefined; + std.mem.writeInt(i32, &encoded, decimal, .big); + return .{ + .decimal = decimal, + .encoded = encoded, + }; + } +}; + +pub const text_encoding = [2]u8{ 0, 0 }; +pub const binary_encoding = [2]u8{ 0, 1 }; + +pub const DynamicValue = union(enum) { + null, + bool: bool, + i64: i64, + f64: f64, + text: []const u8, +}; + +// Any "decodeKnown" you see is just an optimization to avoid extra assertions +// when decoding an individual array value. Once we know the array type, we don't +// need to assert the oid of each individual value. + +// Every supported type is here. This includes the format we want to +// encode/decode (text or binary), and the logic for encoding and decoding. + +pub const Cidr = @import("types/cidr.zig").Cidr; +pub const Numeric = @import("types/numeric.zig").Numeric; +pub const Vector = @import("types/vector.zig").Vector; +pub const Char = struct { + // A blank-padded char + pub const oid = OID.make(1042); + const encoding = &binary_encoding; + + fn encode(value: u8, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Char.encoding, format_pos); + try buf.write(&.{ 0, 0, 0, 1 }); // length of our data + return buf.writeByte(value); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) u8 else lib.TypeError!u8 { + lib.verifyDecodeType(fail_mode, u8, &.{Char.oid.decimal}, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return data[0]; + } + + pub fn decodeKnown(data: []const u8) u8 { + return data[0]; + } +}; + +pub const Int16 = struct { + pub const oid = OID.make(21); + const encoding = &binary_encoding; + + fn encode(value: i16, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Int16.encoding, format_pos); + try buf.write(&.{ 0, 0, 0, 2 }); // length of our data + return 
buf.writeIntBig(i16, value); + } + + fn encodeUnsigned(value: u16, buf: *buffer.Buffer, format_pos: usize) !void { + if (value > 32767) return error.UnsignedIntWouldBeTruncated; + return Int16.encode(@intCast(value), buf, format_pos); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) i16 else lib.TypeError!i16 { + lib.verifyDecodeType(fail_mode, i16, &.{Int16.oid.decimal}, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return Int16.decodeKnown(data); + } + + pub fn decodeKnown(data: []const u8) i16 { + return std.mem.readInt(i16, data[0..2], .big); + } +}; + +pub const Int32 = struct { + pub const oid = OID.make(23); + const encoding = &binary_encoding; + + fn encode(value: i32, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Int32.encoding, format_pos); + try buf.write(&.{ 0, 0, 0, 4 }); // length of our data + return buf.writeIntBig(i32, value); + } + + fn encodeUnsigned(value: u32, buf: *buffer.Buffer, format_pos: usize) !void { + if (value > 2147483647) return error.UnsignedIntWouldBeTruncated; + return Int32.encode(@intCast(value), buf, format_pos); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) i32 else lib.TypeError!i32 { + lib.verifyDecodeType(fail_mode, i32, &.{ Int32.oid.decimal, Xid.oid.decimal }, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return Int32.decodeKnown(data); + } + + pub fn decodeKnown(data: []const u8) i32 { + return std.mem.readInt(i32, data[0..4], .big); + } +}; + +pub const Int64 = struct { + pub const oid = OID.make(20); + const encoding = &binary_encoding; + + fn encode(value: i64, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Int64.encoding, format_pos); + try buf.write(&.{ 0, 0, 0, 8 }); // length of our data + return buf.writeIntBig(i64, value); + } + + fn 
encodeUnsigned(value: u64, buf: *buffer.Buffer, format_pos: usize) !void { + if (value > 9223372036854775807) return error.UnsignedIntWouldBeTruncated; + return Int64.encode(@intCast(value), buf, format_pos); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) i64 else lib.TypeError!i64 { + switch (data_oid) { + Timestamp.oid.decimal, TimestampTz.oid.decimal => return Timestamp.decodeKnown(data), + else => { + lib.verifyDecodeType(fail_mode, i64, &.{ Int64.oid.decimal, PgLSN.oid.decimal, Xid8.oid.decimal }, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return Int64.decodeKnown(data); + }, + } + } + + pub fn decodeKnown(data: []const u8) i64 { + return std.mem.readInt(i64, data[0..8], .big); + } +}; + +pub const Timestamp = struct { + pub const oid = OID.make(1114); + const encoding = &binary_encoding; + const us_from_epoch_to_y2k = 946_684_800_000_000; + + fn encode(value: i64, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Timestamp.encoding, format_pos); + try buf.write(&.{ 0, 0, 0, 8 }); // length of our data + return buf.writeIntBig(i64, value - us_from_epoch_to_y2k); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) i64 else lib.TypeError!i64 { + lib.verifyDecodeType(fail_mode, i64, &.{ Timestamp.oid.decimal, TimestampTz.oid.decimal }, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return std.mem.readInt(i64, data[0..8], .big) + us_from_epoch_to_y2k; + } + + pub fn decodeKnown(data: []const u8) i64 { + return std.mem.readInt(i64, data[0..8], .big) + us_from_epoch_to_y2k; + } +}; + +pub const TimestampTz = struct { + pub const oid = OID.make(1184); + const encoding = &binary_encoding; +}; + +pub const Float32 = struct { + pub const oid = OID.make(700); + const encoding = &binary_encoding; + + fn encode(value: f32, buf: 
*buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Float32.encoding, format_pos); + try buf.write(&.{ 0, 0, 0, 4 }); // length of our data + const tmp: *i32 = @ptrCast(@constCast(&value)); + return buf.writeIntBig(i32, tmp.*); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) f32 else lib.TypeError!f32 { + lib.verifyDecodeType(fail_mode, f32, &.{Float32.oid.decimal}, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return Float32.decodeKnown(data); + } + + pub fn decodeKnown(data: []const u8) f32 { + const n = std.mem.readInt(i32, data[0..4], .big); + const tmp: *f32 = @ptrCast(@constCast(&n)); + return tmp.*; + } +}; + +pub const Float64 = struct { + pub const oid = OID.make(701); + const encoding = &binary_encoding; + + fn encode(value: f64, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Float64.encoding, format_pos); + + try buf.write(&.{ 0, 0, 0, 8 }); // length of our data + // not sure if this is the best option... 
+ const tmp: *i64 = @ptrCast(@constCast(&value)); + return buf.writeIntBig(i64, tmp.*); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) f64 else lib.TypeError!f64 { + switch (data_oid) { + Numeric.oid.decimal => { + const numeric = Numeric.decode(fail_mode, data, data_oid); + if (comptime fail_mode == .unsafe) { + return numeric.toFloat(); + } + return (try numeric).toFloat(); + }, + else => { + lib.verifyDecodeType(fail_mode, f64, &.{Float64.oid.decimal}, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return Float64.decodeKnown(data); + }, + } + } + + pub fn decodeKnown(data: []const u8) f64 { + const n = std.mem.readInt(i64, data[0..8], .big); + const tmp: *f64 = @ptrCast(@constCast(&n)); + return tmp.*; + } +}; + +pub const Bool = struct { + pub const oid = OID.make(16); + const encoding = &binary_encoding; + + fn encode(value: bool, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Bool.encoding, format_pos); + try buf.write(&.{ 0, 0, 0, 1 }); // length of our data + return buf.writeByte(if (value) 1 else 0); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) bool else lib.TypeError!bool { + lib.verifyDecodeType(fail_mode, bool, &.{Bool.oid.decimal}, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return decodeKnown(data); + } + + pub fn decodeKnown(data: []const u8) bool { + return data[0] == 1; + } +}; + +pub const String = struct { + pub const oid = OID.make(25); + // https://www.postgresql.org/message-id/CAMovtNoHFod2jMAKQjjxv209PCTJx5Kc66anwWvX0mEiaXwgmA%40mail.gmail.com + // says using the text format for text-like things is faster. There were + // some other threads that discussed solutions, but it isn't clear if it was + // ever fixed. 
+ const encoding = &text_encoding; + + fn encode(value: []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(String.encoding, format_pos); + var view = try buf.skip(4 + value.len); + view.writeIntBig(i32, @intCast(value.len)); + view.write(value); + } +}; + +pub const Bytea = struct { + pub const oid = OID.make(17); + const encoding = &binary_encoding; + + fn encode(value: []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(Bytea.encoding, format_pos); + var view = try buf.skip(4 + value.len); + view.writeIntBig(i32, @intCast(value.len)); + view.write(value); + } + + pub fn decode(data: []const u8, data_oid: i32) []const u8 { + switch (data_oid) { + JSONB.oid.decimal => return JSONB.decodeKnown(data), + else => return data, + } + } + + pub fn decodeKnown(data: []const u8) []const u8 { + return data; + } + + pub fn decodeKnownMutable(data: []const u8) []u8 { + // we know the underlying []u8 is mutable, it comes from our Reader + return @constCast(data); + } +}; + +pub const UUID = struct { + pub const oid = OID.make(2950); + const encoding = &binary_encoding; + + fn encode(value: []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(UUID.encoding, format_pos); + var view = try buf.skip(20); + view.write(&.{ 0, 0, 0, 16 }); + switch (value.len) { + 16 => view.write(value), + 36 => view.write(&(try UUID.toBytes(value))), + else => return error.InvalidUUID, + } + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) []const u8 else lib.TypeError![]const u8 { + lib.verifyDecodeType(fail_mode, []const u8, &.{UUID.oid.decimal}, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return data; + } + + const hex = "0123456789abcdef"; + const encoded_pos = [16]u8{ 0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34 }; + const hex_to_nibble = [256]u8{ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + }; + + pub fn toString(uuid: []const u8) ![36]u8 { + if (uuid.len != 16) { + return error.InvalidUUID; + } + + var out: [36]u8 = undefined; + out[8] = '-'; + out[13] = '-'; + out[18] = '-'; + out[23] = '-'; + + inline for (encoded_pos, 0..) 
|i, j| { + out[i + 0] = hex[uuid[j] >> 4]; + out[i + 1] = hex[uuid[j] & 0x0f]; + } + return out; + } + + pub fn toBytes(str: []const u8) ![16]u8 { + if (str.len != 36 or str[8] != '-' or str[13] != '-' or str[18] != '-' or str[23] != '-') { + return error.InvalidUUID; + } + + var out: [16]u8 = undefined; + inline for (encoded_pos, 0..) |i, j| { + const hi = hex_to_nibble[str[i + 0]]; + const lo = hex_to_nibble[str[i + 1]]; + if (hi == 0xff or lo == 0xff) { + return error.InvalidUUID; + } + out[j] = hi << 4 | lo; + } + return out; + } +}; + +pub const PgLSN = struct { + pub const oid = OID.make(3220); + const encoding = &binary_encoding; +}; + +pub const Xid = struct { + pub const oid = OID.make(28); + const encoding = &binary_encoding; +}; + +pub const Xid8 = struct { + pub const oid = OID.make(5069); + const encoding = &binary_encoding; +}; + +pub const MacAddr = struct { + pub const oid = OID.make(829); + const encoding = &binary_encoding; + + fn encode(value: []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + if (value.len != 6) { + // assume this is a text representation + return String.encode(value, buf, format_pos); + } + buf.writeAt(MacAddr.encoding, format_pos); + var view = try buf.skip(4 + value.len); + view.writeIntBig(i32, @intCast(value.len)); + view.write(value); + } +}; + +pub const MacAddr8 = struct { + pub const oid = OID.make(774); + const encoding = &binary_encoding; + + fn encode(value: []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + if (value.len != 8) { + // assume this is a text representation + return String.encode(value, buf, format_pos); + } + buf.writeAt(MacAddr8.encoding, format_pos); + var view = try buf.skip(4 + value.len); + view.writeIntBig(i32, @intCast(value.len)); + view.write(value); + } +}; + +pub const JSON = struct { + pub const oid = OID.make(114); + const encoding = &binary_encoding; + + fn encodeBytes(value: []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(JSON.encoding, 
format_pos); + var view = try buf.skip(4 + value.len); + view.writeIntBig(i32, @intCast(value.len)); + view.write(value); + } + + fn encode(value: anytype, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(JSON.encoding, format_pos); + const state = try Encode.variableLengthStart(buf); + try std.json.Stringify.value(value, .{}, &buf.interface); + Encode.variableLengthFill(buf, state); + } +}; + +pub const JSONB = struct { + pub const oid = OID.make(3802); + const encoding = &binary_encoding; + + fn encodeBytes(value: []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(JSONB.encoding, format_pos); + var view = try buf.skip(5 + value.len); + // + 1 for the version + view.writeIntBig(i32, @intCast(value.len + 1)); + view.writeByte(1); // jsonb version + view.write(value); + } + + fn encode(value: anytype, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(JSON.encoding, format_pos); + const state = try Encode.variableLengthStart(buf); + try buf.writeByte(1); // jsonb version + try std.json.Stringify.value(value, .{}, &buf.interface); + Encode.variableLengthFill(buf, state); + } + + fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) []const u8 else lib.TypeError![]const u8 { + lib.verifyDecodeType(fail_mode, []const u8, &.{JSONB.oid.decimal}, data_oid) catch |err| { + if (comptime fail_mode == .unsafe) unreachable; + return err; + }; + return JSONB.decodeKnown(data); + } + + pub fn decodeKnown(data: []const u8) []const u8 { + return data[1..]; + } + + pub fn decodeKnownMutable(data: []const u8) []u8 { + // we know the underlying []u8 is mutable, it comes from our Reader + return @constCast(data[1..]); + } +}; + +pub const Int16Array = struct { + pub const oid = OID.make(1005); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Int16.oid.encoded, oid_pos); + return Encode.writeIntArray(i16, values, 
buf); + } + + fn encodeUnsigned(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + for (values) |v| { + if (v > 32767) return error.UnsignedIntWouldBeTruncated; + } + buf.writeAt(&Int16.oid.encoded, oid_pos); + return Encode.writeIntArray(i16, values, buf); + } +}; + +pub const Int32Array = struct { + pub const oid = OID.make(1007); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Int32.oid.encoded, oid_pos); + return Encode.writeIntArray(i32, values, buf); + } + + fn encodeUnsigned(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + for (values) |v| { + if (v > 2147483647) return error.UnsignedIntWouldBeTruncated; + } + buf.writeAt(&Int32.oid.encoded, oid_pos); + return Encode.writeIntArray(i32, values, buf); + } +}; + +pub const Int64Array = struct { + pub const oid = OID.make(1016); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Int64.oid.encoded, oid_pos); + return Encode.writeIntArray(i64, values, buf); + } + + fn encodeUnsigned(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + for (values) |v| { + if (v > 9223372036854775807) return error.UnsignedIntWouldBeTruncated; + } + buf.writeAt(&Int64.oid.encoded, oid_pos); + return Encode.writeIntArray(i64, values, buf); + } +}; + +pub const TimestampArray = struct { + pub const oid = OID.make(1115); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Timestamp.oid.encoded, oid_pos); + try writeTimestampArray(values, buf); + } +}; + +pub const TimestampTzArray = struct { + pub const oid = OID.make(1185); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&TimestampTz.oid.encoded, oid_pos); + try writeTimestampArray(values, buf); + } +}; + +fn writeTimestampArray(values: 
anytype, buf: *buffer.Buffer) !void { + const us_from_epoch_to_y2k = 946_684_800_000_000; + + // at most, every value is 12 bytes, 4 byte length + 8 byte value + var view = try buf.skip(12 * values.len); + + const nullables = @typeInfo(@TypeOf(values)).pointer.child == ?i64; + var null_count: usize = 0; + + for (values) |value| { + var v: i64 = undefined; + if (comptime nullables) { + v = value orelse { + null_count += 1; + view.write(&.{ 255, 255, 255, 255 }); // null, + continue; + }; + } else v = value; + view.write(&.{ 0, 0, 0, 8 }); // length of value + view.writeIntBig(i64, v - us_from_epoch_to_y2k); + } + + if (comptime nullables) { + buf.truncate(null_count * 8); + } +} + +pub const Float32Array = struct { + pub const oid = OID.make(1021); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Float32.oid.encoded, oid_pos); + return writeFloatArray(f32, i32, values, buf); + } +}; + +pub const Float64Array = struct { + pub const oid = OID.make(1022); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Float64.oid.encoded, oid_pos); + return writeFloatArray(f64, i64, values, buf); + } +}; + +fn writeFloatArray(comptime F: type, comptime I: type, values: anytype, buf: *buffer.Buffer) !void { + // The most space this can take, + 4 for the length; + var view = try buf.skip((@sizeOf(I) + 4) * values.len); + + const nullables = @typeInfo(@typeInfo(@TypeOf(values)).pointer.child) == .optional; + var null_count: usize = 0; + + for (values) |value| { + var v: F = undefined; + if (comptime nullables) { + v = value orelse { + null_count += 1; + view.write(&.{ 255, 255, 255, 255 }); // null, + continue; + }; + } else v = value; + + const tmp: *I = @ptrCast(@constCast(&v)); + view.write(&.{ 0, 0, 0, @sizeOf(I) }); //length + view.writeIntBig(I, tmp.*); + } + + if (comptime nullables) { + buf.truncate(null_count * 
@sizeOf(I)); + } +} + +pub const BoolArray = struct { + pub const oid = OID.make(1000); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Bool.oid.encoded, oid_pos); + // at most every value takes 5 bytes, 4 for the length, 1 for the value + var view = try buf.skip(5 * values.len); + + const nullables = @typeInfo(@typeInfo(@TypeOf(values)).pointer.child) == .optional; + var null_count: usize = 0; + + for (values) |value| { + var v: bool = undefined; + if (comptime nullables) { + v = value orelse { + null_count += 1; + view.write(&.{ 255, 255, 255, 255 }); // null, + continue; + }; + } else v = value; + + // each value is prefixed with a 4 byte length + if (v) { + view.write(&.{ 0, 0, 0, 1, 1 }); + } else { + view.write(&.{ 0, 0, 0, 1, 0 }); + } + } + + if (comptime nullables) { + buf.truncate(null_count); + } + } +}; + +pub const NumericArray = struct { + pub const oid = OID.make(1231); + const encoding = &binary_encoding; + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Numeric.oid.encoded, oid_pos); + + for (values) |value| { + try Numeric.encodeBuf(value, buf); + } + } +}; + +pub const CidrArray = struct { + pub const oid = OID.make(651); + pub const inet_oid = OID.make(1041); + const encoding = &binary_encoding; +}; + +pub const MacAddrArray = struct { + pub const oid = OID.make(1040); + const encoding = &binary_encoding; + + fn encode(values: []const []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + // This has challenges. Do we have a binary representation or a text representation? + // Or maybe we have a mix (maybe we shouldn't support that)? + // We handle this with UUID by converting the text representation to binary + // but it's harder with MacAddr because it supports 7 different text representations + // and I don't really want this library to become a text parsing library which attempts + // to mimic what PostgreSQL does. 
+ // So we're going to send a text-encoded array with text values, which means + // we need to convert any binary representation to text (which is a lot easier). + + // The worst-case scenario is that each value takes 17 bytes. This is the + // most verbose text-encoded value. When we encode a binary value as text + // we'll use the most compact (12 bytes), but we might be given a 17-byte + // text-encoded value, which we'll write as-is + var l: usize = 0; + for (values) |v| { + // binary values will be encoded in a 12-byte text representation + l += if (v.len == 6) 12 else v.len; + } + + return Encode.writeTextEncodedArray(values, l, buf, format_pos, MacAddrArray.writeOneAsText); + } + + fn writeOneAsText(value: []const u8, buf: *buffer.Buffer) void { + if (value.len == 6) { + buf.interface.print("{x:0>2}{x:0>2}{x:0>2}{x:0>2}{x:0>2}{x:0>2}", .{ value[0], value[1], value[2], value[3], value[4], value[5] }) catch unreachable; + } else { + buf.writeAssumeCapacity(value); + } + } +}; + +pub const MacAddr8Array = struct { + pub const oid = OID.make(775); + const encoding = &binary_encoding; + + fn encode(values: []const []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + // See comments in MacAddrArray.encode + var l: usize = 0; + for (values) |v| { + // binary values will be encoded in a 16-byte text representation + l += if (v.len == 8) 16 else v.len; + } + + return Encode.writeTextEncodedArray(values, l, buf, format_pos, MacAddr8Array.writeOneAsText); + } + + fn writeOneAsText(value: []const u8, buf: *buffer.Buffer) void { + if (value.len == 8) { + buf.interface.print("{x:0>2}{x:0>2}{x:0>2}{x:0>2}{x:0>2}{x:0>2}{x:0>2}{x:0>2}", .{ value[0], value[1], value[2], value[3], value[4], value[5], value[6], value[7] }) catch unreachable; + } else { + buf.writeAssumeCapacity(value); + } + } +}; + +pub const ByteaArray = struct { + pub const oid = OID.make(1001); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) 
!void { + buf.writeAt(&Bytea.oid.encoded, oid_pos); + return Encode.writeByteArray(values, buf); + } +}; + +pub const StringArray = struct { + pub const oid = OID.make(1009); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&String.oid.encoded, oid_pos); + return Encode.writeByteArray(values, buf); + } + + fn encodeEnum(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&String.oid.encoded, oid_pos); + for (values.*) |value| { + const str = @tagName(value); + try buf.writeIntBig(i32, @intCast(str.len)); + try buf.write(str); + } + } +}; + +pub const UUIDArray = struct { + pub const oid = OID.make(2951); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&UUID.oid.encoded, oid_pos); + + const T = @typeInfo(@TypeOf(values)).pointer.child; + const TT = switch (@typeInfo(T)) { + .optional => |opt| opt.child, + else => T, + }; + const nullables = @typeInfo(T) == .optional; + + var null_count: usize = 0; + + // at most every value is 20 bytes, 4 byte length + 16 byte value + var view = try buf.skip(20 * values.len); + for (values) |value| { + var v: TT = undefined; + if (comptime nullables) { + v = value orelse { + null_count += 1; + view.write(&.{ 255, 255, 255, 255 }); + continue; + }; + } else v = value; + + view.write(&.{ 0, 0, 0, 16 }); // length of value + switch (v.len) { + 16 => view.write(v), + 36 => view.write(&(try UUID.toBytes(v))), + else => return error.InvalidUUID, + } + } + + if (comptime nullables) { + buf.truncate(null_count * 16); + } + } +}; + +pub const JSONArray = struct { + pub const oid = OID.make(199); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&JSON.oid.encoded, oid_pos); + return Encode.writeByteArray(values, buf); + } +}; + +pub const JSONBArray = struct { + pub const oid = 
OID.make(3807); + const encoding = &binary_encoding; + + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&JSONB.oid.encoded, oid_pos); + if (@typeInfo(@typeInfo(@TypeOf(values)).pointer.child) == .optional) { + return encodeNullables(values, buf); + } + + // every value has a 5 byte prefix, a 4 byte length and a 1 byte version + var len = values.len * 5; + for (values) |value| { + len += value.len; + } + + var view = try buf.skip(len); + for (values) |value| { + // + 1 for the version + view.writeIntBig(i32, @intCast(value.len + 1)); + view.writeByte(1); // version + view.write(value); + } + } + + fn encodeNullables(values: []const ?[]const u8, buf: *buffer.Buffer) !void { + // every value has a 5 byte prefix, a 4 byte length and a 1 byte version + var len = values.len * 5; + for (values) |value| { + if (value) |v| { + len += v.len; + } + } + + var view = try buf.skip(len); + for (values) |value| { + if (value) |v| { + // + 1 for the version + view.writeIntBig(i32, @intCast(v.len + 1)); + view.writeByte(1); // version + view.write(v); + } else { + view.write(&.{ 255, 255, 255, 255 }); // null, + } + } + } +}; + +pub const CharArray = struct { + pub const oid = OID.make(1014); + const encoding = &binary_encoding; + + // This is for a char[] bound to a []u8 + fn encodeOne(values: []const u8, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Char.oid.encoded, oid_pos); + + // every value has a 5 byte prefix, a 4 byte length and a 1 byte char + const len = values.len * 5; + var view = try buf.skip(len); + for (values) |value| { + view.write(&.{ 0, 0, 0, 1 }); + view.writeByte(value); + } + } + + // This is for a char[] bound to a [][]u8 + fn encode(values: anytype, buf: *buffer.Buffer, oid_pos: usize) !void { + buf.writeAt(&Char.oid.encoded, oid_pos); + return Encode.writeByteArray(values, buf); + } +}; + +// Return the encoding we want PG to use for a particular OID +fn resultEncodingFor(oid: i32) *const [2]u8 { + inline 
for (@typeInfo(@This()).@"struct".decls) |decl| { + const S = @field(@This(), decl.name); + if (@typeInfo(@TypeOf(S)) == .type and @hasField(S, "oid")) { + if (oid == S.oid.decimal) { + return S.encoding; + } + } + } + // default to text encoding + return &binary_encoding; +} + +pub const Encode = struct { + // helpers for encoding data (or part of the data) + pub fn writeIntArray(comptime T: type, values: anytype, buf: *buffer.Buffer) !void { + const size = @sizeOf(T); + // at most, every value is a 4 byte length + the size of the underlying int + var view = try buf.skip((size + 4) * values.len); + + const nullables = @typeInfo(@typeInfo(@TypeOf(values)).pointer.child) == .optional; + var null_count: usize = 0; + + var value_len: [4]u8 = undefined; + std.mem.writeInt(i32, &value_len, @intCast(size), .big); + + for (values) |value| { + var v: T = undefined; + if (comptime nullables) { + v = value orelse { + null_count += 1; + view.write(&.{ 255, 255, 255, 255 }); // null, + continue; + }; + } else v = value; + view.write(&value_len); + view.writeIntBig(T, v); + } + + if (comptime nullables) { + buf.truncate(null_count * size); + } + } + + pub fn writeByteArray(values: anytype, buf: *buffer.Buffer) !void { + if (@typeInfo(@typeInfo(@TypeOf(values)).pointer.child) == .optional) { + return writeNullableByteArray(values, buf); + } + // each value has a 4 byte length prefix + var len = values.len * 4; + for (values) |value| { + len += value.len; + } + + var view = try buf.skip(len); + for (values) |value| { + view.writeIntBig(i32, @intCast(value.len)); + view.write(value); + } + } + + pub fn writeNullableByteArray(values: []const ?[]const u8, buf: *buffer.Buffer) !void { + // each value has a 4 byte length prefix + var len = values.len * 4; + for (values) |value| { + if (value) |v| { + len += v.len; + } + } + + var view = try buf.skip(len); + for (values) |value| { + if (value) |v| { + view.writeIntBig(i32, @intCast(v.len)); + view.write(v); + } else { + view.write(&.{ 
255, 255, 255, 255 }); + } + } + } + + pub fn variableLengthStart(buf: *buffer.Buffer) !usize { + try buf.write(&.{ 0, 0, 0, 0 }); // length placeholder + return buf.len(); + } + + pub fn variableLengthFill(buf: *buffer.Buffer, pos: usize) void { + const len = buf.len() - pos; + var encoded_len: [4]u8 = undefined; + std.mem.writeInt(i32, &encoded_len, @intCast(len), .big); + buf.writeAt(&encoded_len, pos - 4); + } + + pub fn writeTextEncodedArray(values: []const []const u8, values_len: usize, buf: *buffer.Buffer, format_pos: usize, writeFn: *const fn ([]const u8, *buffer.Buffer) void) !void { + buf.writeAt(&text_encoding, format_pos); + if (values.len == 0) { + // empty array, with length prefix + return buf.write(&.{ 0, 0, 0, 2, '{', '}' }); + } + + // We're relying one our caller to give us an accurate values_len + // The total value length will be: + // 2 + values_len + values.len + // {} delimiter + given to us + ',' delimiter between values + const max_len = 2 + values_len + values.len; + try buf.ensureUnusedCapacity(max_len); + + // our max_len is just an estimate, we'll get the actual length and fill + // it in later, for now, we skip the length + var view = try buf.skip(4); + const start = buf.len(); + buf.writeByteAssumeCapacity('{'); + for (values) |value| { + writeFn(value, buf); + buf.writeByteAssumeCapacity(','); + } + + // strip out last comma + buf.truncate(1); + buf.writeByteAssumeCapacity('}'); + // -6 since the oid and the + view.writeIntBig(i32, @intCast(buf.len() - start)); + } + + // Fairly special case for text-encoded arrays where we _always_ want to quote the value + // but don't need to escape. 
This idea is taken from Java's PostgreSQL JDBC driver + // specifically for dealing with possible scientific notation in float/numeric text values + pub fn writeTextEncodedEscapedArray(values: []const []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + var l: usize = 0; + for (values) |v| { + // +2 for the quotes around the value we'll need + l += v.len + 2; + } + return Encode.writeTextEncodedArray(values, l, buf, format_pos, writeQuotedValue); + } + + fn writeQuotedValue(value: []const u8, buf: *buffer.Buffer) void { + buf.writeByteAssumeCapacity('"'); + buf.writeAssumeCapacity(value); + buf.writeByteAssumeCapacity('"'); + } + + // Text-encoded arrays where each value is written as-is: no quoting and no escaping. + // We rely on PostgreSQL's own text->type conversion here, so each value must + // already be a valid, unquoted array element. + pub fn writeTextEncodedRawArray(values: []const []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + var l: usize = 0; + for (values) |v| { + l += v.len; + } + return Encode.writeTextEncodedArray(values, l, buf, format_pos, writeRawValue); + } + + fn writeRawValue(value: []const u8, buf: *buffer.Buffer) void { + buf.writeAssumeCapacity(value); + } + + pub fn writeTextEncodedCharArray(values: []const u8, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(&text_encoding, format_pos); + if (values.len == 0) { + // empty array, with length prefix + return buf.write(&.{ 0, 0, 0, 2, '{', '}' }); + } + + // 6 = 4-byte length + opening brace + closing brace + // values.len * 5 is the max guess about how much room we'll need. 
1 byte + // per character, delimiter + double quotes + escape + const estimated_len: usize = 6 + values.len * 5; + try buf.ensureUnusedCapacity(estimated_len); + + // skip the length, which we'll fill later + var view = try buf.skip(4); + const start = buf.len(); + + // https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO + buf.writeByteAssumeCapacity('{'); + for (values) |c| { + if (c == '"' or c == '\\') { + buf.writeAssumeCapacity("\"\\"); + buf.writeByteAssumeCapacity(c); + buf.writeByteAssumeCapacity('"'); + } else if (std.ascii.isWhitespace(c) or c == ',' or c == '{' or c == '}' or c == '\\') { + buf.writeByteAssumeCapacity('"'); + buf.writeByteAssumeCapacity(c); + buf.writeByteAssumeCapacity('"'); + } else { + buf.writeByteAssumeCapacity(c); + } + buf.writeByteAssumeCapacity(','); + } + + // strip out last comma + buf.truncate(1); + buf.writeByteAssumeCapacity('}'); + view.writeIntBig(i32, @intCast(buf.len() - start)); + } +}; + +pub fn oidToString(oid: i32) []const u8 { + switch (oid) { + 16 => return "T_bool", + 17 => return "T_bytea", + 18 => return "T_char", + 19 => return "T_name", + 20 => return "T_int8", + 21 => return "T_int2", + 22 => return "T_int2vector", + 23 => return "T_int4", + 24 => return "T_regproc", + 25 => return "T_text", + 26 => return "T_oid", + 27 => return "T_tid", + 28 => return "T_xid", + 29 => return "T_cid", + 30 => return "T_oidvector", + 32 => return "T_pg_ddl_command", + 71 => return "T_pg_type", + 75 => return "T_pg_attribute", + 81 => return "T_pg_proc", + 83 => return "T_pg_class", + 114 => return "T_json", + 142 => return "T_xml", + 143 => return "T__xml", + 194 => return "T_pg_node_tree", + 199 => return "T__json", + 210 => return "T_smgr", + 325 => return "T_index_am_handler", + 600 => return "T_point", + 601 => return "T_lseg", + 602 => return "T_path", + 603 => return "T_box", + 604 => return "T_polygon", + 628 => return "T_line", + 629 => return "T__line", + 650 => return "T_cidr", + 651 => return "T__cidr", 
+ 700 => return "T_float4", + 701 => return "T_float8", + 702 => return "T_abstime", + 703 => return "T_reltime", + 704 => return "T_tinterval", + 705 => return "T_unknown", + 718 => return "T_circle", + 719 => return "T__circle", + 790 => return "T_money", + 791 => return "T__money", + 829 => return "T_macaddr", + 869 => return "T_inet", + 1000 => return "T__bool", + 1001 => return "T__bytea", + 1002 => return "T__char", + 1003 => return "T__name", + 1005 => return "T__int2", + 1006 => return "T__int2vector", + 1007 => return "T__int4", + 1008 => return "T__regproc", + 1009 => return "T__text", + 1010 => return "T__tid", + 1011 => return "T__xid", + 1012 => return "T__cid", + 1013 => return "T__oidvector", + 1014 => return "T__bpchar", + 1015 => return "T__varchar", + 1016 => return "T__int8", + 1017 => return "T__point", + 1018 => return "T__lseg", + 1019 => return "T__path", + 1020 => return "T__box", + 1021 => return "T__float4", + 1022 => return "T__float8", + 1023 => return "T__abstime", + 1024 => return "T__reltime", + 1025 => return "T__tinterval", + 1027 => return "T__polygon", + 1028 => return "T__oid", + 1033 => return "T_aclitem", + 1034 => return "T__aclitem", + 1040 => return "T__macaddr", + 1041 => return "T__inet", + 1042 => return "T_bpchar", + 1043 => return "T_varchar", + 1082 => return "T_date", + 1083 => return "T_time", + 1114 => return "T_timestamp", + 1115 => return "T__timestamp", + 1182 => return "T__date", + 1183 => return "T__time", + 1184 => return "T_timestamptz", + 1185 => return "T__timestamptz", + 1186 => return "T_interval", + 1187 => return "T__interval", + 1231 => return "T__numeric", + 1248 => return "T_pg_database", + 1263 => return "T__cstring", + 1266 => return "T_timetz", + 1270 => return "T__timetz", + 1560 => return "T_bit", + 1561 => return "T__bit", + 1562 => return "T_varbit", + 1563 => return "T__varbit", + 1700 => return "T_numeric", + 1790 => return "T_refcursor", + 2201 => return "T__refcursor", + 2202 => return 
"T_regprocedure", + 2203 => return "T_regoper", + 2204 => return "T_regoperator", + 2205 => return "T_regclass", + 2206 => return "T_regtype", + 2207 => return "T__regprocedure", + 2208 => return "T__regoper", + 2209 => return "T__regoperator", + 2210 => return "T__regclass", + 2211 => return "T__regtype", + 2249 => return "T_record", + 2275 => return "T_cstring", + 2276 => return "T_any", + 2277 => return "T_anyarray", + 2278 => return "T_void", + 2279 => return "T_trigger", + 2280 => return "T_language_handler", + 2281 => return "T_internal", + 2282 => return "T_opaque", + 2283 => return "T_anyelement", + 2287 => return "T__record", + 2776 => return "T_anynonarray", + 2842 => return "T_pg_authid", + 2843 => return "T_pg_auth_members", + 2949 => return "T__txid_snapshot", + 2950 => return "T_uuid", + 2951 => return "T__uuid", + 2970 => return "T_txid_snapshot", + 3115 => return "T_fdw_handler", + 3220 => return "T_pg_lsn", + 3221 => return "T__pg_lsn", + 3310 => return "T_tsm_handler", + 3500 => return "T_anyenum", + 3614 => return "T_tsvector", + 3615 => return "T_tsquery", + 3642 => return "T_gtsvector", + 3643 => return "T__tsvector", + 3644 => return "T__gtsvector", + 3645 => return "T__tsquery", + 3734 => return "T_regconfig", + 3735 => return "T__regconfig", + 3769 => return "T_regdictionary", + 3770 => return "T__regdictionary", + 3802 => return "T_jsonb", + 3807 => return "T__jsonb", + 3831 => return "T_anyrange", + 3838 => return "T_event_trigger", + 3904 => return "T_int4range", + 3905 => return "T__int4range", + 3906 => return "T_numrange", + 3907 => return "T__numrange", + 3908 => return "T_tsrange", + 3909 => return "T__tsrange", + 3910 => return "T_tstzrange", + 3911 => return "T__tstzrange", + 3912 => return "T_daterange", + 3913 => return "T__daterange", + 3926 => return "T_int8range", + 3927 => return "T__int8range", + 4066 => return "T_pg_shseclabel", + 4089 => return "T_regnamespace", + 4090 => return "T__regnamespace", + 4096 => return 
"T_regrole", + 4097 => return "T__regrole", + else => return "unknown", + } +} + +// The oid is what PG is expecting. In some cases, we'll use that to figure +// out what to do. +pub fn bindValue(comptime T: type, oid: i32, value: anytype, buf: *buffer.Buffer, format_pos: usize) !void { + switch (@typeInfo(T)) { + .null => { + // type can stay 0 (text) + // special length of -1 indicates null, no other data for this value + return buf.write(&.{ 255, 255, 255, 255 }); + }, + .comptime_int => switch (oid) { + Int16.oid.decimal => { + if (value > 32767 or value < -32768) return error.IntWontFit; + return Int16.encode(@intCast(value), buf, format_pos); + }, + Int32.oid.decimal => { + if (value > 2147483647 or value < -2147483648) return error.IntWontFit; + return Int32.encode(@intCast(value), buf, format_pos); + }, + Timestamp.oid.decimal, TimestampTz.oid.decimal => return Timestamp.encode(@intCast(value), buf, format_pos), + Numeric.oid.decimal => return Numeric.encode(@as(f64, @floatFromInt(value)), buf, format_pos), + Char.oid.decimal => { + if (value > 255 or value < 0) return error.IntWontFit; + return Char.encode(@intCast(value), buf, format_pos); + }, + Int64.oid.decimal, PgLSN.oid.decimal, Xid8.oid.decimal => return Int64.encode(@intCast(value), buf, format_pos), + else => return error.BindWrongType, + }, + .int => switch (oid) { + Int16.oid.decimal => { + if (value > 32767 or value < -32768) return error.IntWontFit; + return Int16.encode(@intCast(value), buf, format_pos); + }, + Int32.oid.decimal, Xid.oid.decimal => { + if (value > 2147483647 or value < -2147483648) return error.IntWontFit; + return Int32.encode(@intCast(value), buf, format_pos); + }, + Timestamp.oid.decimal, TimestampTz.oid.decimal => return Timestamp.encode(@intCast(value), buf, format_pos), + Numeric.oid.decimal => return Numeric.encode(@as(f64, @floatFromInt(value)), buf, format_pos), + Char.oid.decimal => { + if (value > 255 or value < 0) return error.IntWontFit; + return 
Char.encode(@intCast(value), buf, format_pos); + }, + Int64.oid.decimal, PgLSN.oid.decimal, Xid8.oid.decimal => { + if (value > 9223372036854775807 or value < -9223372036854775808) { + return error.IntWontFit; + } + return Int64.encode(@intCast(value), buf, format_pos); + }, + else => return error.BindWrongType, + }, + .comptime_float => switch (oid) { + Float64.oid.decimal => return Float64.encode(@floatCast(value), buf, format_pos), + Float32.oid.decimal => return Float32.encode(@floatCast(value), buf, format_pos), + Numeric.oid.decimal => return Numeric.encode(value, buf, format_pos), + else => return error.BindWrongType, + }, + .float => switch (oid) { + Float64.oid.decimal => return Float64.encode(@floatCast(value), buf, format_pos), + Float32.oid.decimal => return Float32.encode(@floatCast(value), buf, format_pos), + Numeric.oid.decimal => return Numeric.encode(value, buf, format_pos), + else => return error.BindWrongType, + }, + .bool => switch (oid) { + Bool.oid.decimal => return Bool.encode(value, buf, format_pos), + else => return error.BindWrongType, + }, + .pointer => |ptr| switch (ptr.size) { + .slice => { + if (ptr.is_const) { + return bindSlice(oid, @as([]const ptr.child, value), buf, format_pos); + } else { + return bindSlice(oid, @as([]ptr.child, value), buf, format_pos); + } + }, + .one => switch (@typeInfo(ptr.child)) { + .array => { + const Slice = []const std.meta.Elem(ptr.child); + return bindSlice(oid, @as(Slice, value), buf, format_pos); + }, + .@"struct" => switch (oid) { + JSON.oid.decimal => return JSON.encode(value, buf, format_pos), + JSONB.oid.decimal => return JSONB.encode(value, buf, format_pos), + else => { + if (ptr.child != lib.Binary) { + return error.CannotBindStruct; + } + buf.writeAt(&binary_encoding, format_pos); + try buf.writeIntBig(i32, @intCast(value.data.len)); + return buf.write(value.data); + }, + }, + else => compileHaltBindError(T), + }, + else => compileHaltBindError(T), + }, + .array => return 
bindValue(@TypeOf(&value), oid, &value, buf, format_pos), + .@"struct" => return bindValue(@TypeOf(&value), oid, &value, buf, format_pos), + .optional => |opt| { + if (value) |v| { + return bindValue(opt.child, oid, v, buf, format_pos); + } + // null + return buf.write(&.{ 255, 255, 255, 255 }); + }, + .@"enum", .enum_literal => return String.encode(@tagName(value), buf, format_pos), + else => compileHaltBindError(T), + } +} + +pub fn bindDynamicValue(oid: i32, value: DynamicValue, buf: *buffer.Buffer, format_pos: usize) !void { + switch (value) { + .null => return buf.write(&.{ 255, 255, 255, 255 }), + .bool => |v| switch (oid) { + Bool.oid.decimal => return Bool.encode(v, buf, format_pos), + else => return bindSlice(oid, if (v) "true" else "false", buf, format_pos), + }, + .i64 => |v| switch (oid) { + Int16.oid.decimal => { + if (v > 32767 or v < -32768) return error.IntWontFit; + return Int16.encode(@intCast(v), buf, format_pos); + }, + Int32.oid.decimal, Xid.oid.decimal => { + if (v > 2147483647 or v < -2147483648) return error.IntWontFit; + return Int32.encode(@intCast(v), buf, format_pos); + }, + Timestamp.oid.decimal, TimestampTz.oid.decimal => return Timestamp.encode(v, buf, format_pos), + Numeric.oid.decimal => return Numeric.encode(@as(f64, @floatFromInt(v)), buf, format_pos), + Char.oid.decimal => { + if (v > 255 or v < 0) return error.IntWontFit; + return Char.encode(@intCast(v), buf, format_pos); + }, + Int64.oid.decimal, PgLSN.oid.decimal, Xid8.oid.decimal => return Int64.encode(v, buf, format_pos), + else => { + var tmp: [32]u8 = undefined; + const s = try std.fmt.bufPrint(&tmp, "{d}", .{v}); + return bindSlice(oid, s, buf, format_pos); + }, + }, + .f64 => |v| switch (oid) { + Float64.oid.decimal => return Float64.encode(v, buf, format_pos), + Float32.oid.decimal => return Float32.encode(@floatCast(v), buf, format_pos), + Numeric.oid.decimal => return Numeric.encode(v, buf, format_pos), + else => { + var tmp: [64]u8 = undefined; + const s = try 
std.fmt.bufPrint(&tmp, "{d}", .{v}); + return bindSlice(oid, s, buf, format_pos); + }, + }, + .text => |v| return bindSlice(oid, v, buf, format_pos), + } +} + +fn bindSlice(oid: i32, value: anytype, buf: *buffer.Buffer, format_pos: usize) !void { + const T = @TypeOf(value); + if (T == []u8 or T == []const u8) { + switch (oid) { + Bytea.oid.decimal => return Bytea.encode(value, buf, format_pos), + UUID.oid.decimal => return UUID.encode(value, buf, format_pos), + JSONB.oid.decimal => return JSONB.encodeBytes(value, buf, format_pos), + JSON.oid.decimal => return JSON.encodeBytes(value, buf, format_pos), + MacAddr.oid.decimal => return MacAddr.encode(value, buf, format_pos), + MacAddr8.oid.decimal => return MacAddr8.encode(value, buf, format_pos), + Bool.oid.decimal => { + const b = std.mem.eql(u8, value, "true") or std.mem.eql(u8, value, "t") or + std.mem.eql(u8, value, "yes") or std.mem.eql(u8, value, "on") or + std.mem.eql(u8, value, "1"); + return Bool.encode(b, buf, format_pos); + }, + CharArray.oid.decimal => { + // This is actually an array, and in theory we could let it fallthrough + // to the binary-array handling. BUT, if we do that, the code won't compile + // because it would mean T can be []u8 or []const u8, and that makes parts + // of the code invalid. Also, encoding a char array using the text protocol + // is going to be more efficient than encoding it using the binary protocol. + return Encode.writeTextEncodedCharArray(value, buf, format_pos); + }, + else => return String.encode(value, buf, format_pos), + } + } + + // For now, a few types are text-encoded. This largely has to do with the fact + // that there's no native Zig type, so a text representation lets us use PG's + // own text->type conversion. 
+ if (comptime isStringArray(T)) { + switch (oid) { + TimestampArray.oid.decimal, NumericArray.oid.decimal => return Encode.writeTextEncodedEscapedArray(value, buf, format_pos), + TimestampTzArray.oid.decimal, CidrArray.oid.decimal, CidrArray.inet_oid.decimal => return Encode.writeTextEncodedRawArray(value, buf, format_pos), + MacAddrArray.oid.decimal => return MacAddrArray.encode(value, buf, format_pos), + MacAddr8Array.oid.decimal => return MacAddr8Array.encode(value, buf, format_pos), + else => {}, // fallthrough to binary encoding + } + } + + // We have an array. All arrays have the same header. We'll write this into + // buf now. It's possible we don't support the array type, so this can still + // fail. + + // arrays are always binary encoded (for now...) + + buf.writeAt(&binary_encoding, format_pos); + + const start_pos = buf.len(); + + try buf.write(&.{ + 0, 0, 0, 0, // placeholder for the length of this parameter + 0, 0, 0, 1, // number of dimensions, for now, we only support one + 0, 0, 0, 0, // bitmask of null, currently, with a single dimension, we don't have null arrays + 0, 0, 0, 0, // placeholder for the oid of each value + }); + + // where in buf, to write the OID of the values + const oid_pos = buf.len() - 4; + + // number of values in our first (and currently only) dimension + try buf.writeIntBig(i32, @intCast(value.len)); + try buf.write(&.{ 0, 0, 0, 1 }); // lower bound of this dimension + + const ElemT = @typeInfo(T).pointer.child; + const ElemTT = switch (@typeInfo(ElemT)) { + .optional => |opt| opt.child, + else => ElemT, + }; + switch (@typeInfo(ElemTT)) { + .int => |int| { + if (int.signedness == .signed) { + switch (int.bits) { + 16 => try Int16Array.encode(value, buf, oid_pos), + 32 => try Int32Array.encode(value, buf, oid_pos), + 64 => { + switch (oid) { + TimestampArray.oid.decimal => try TimestampArray.encode(value, buf, oid_pos), + TimestampTzArray.oid.decimal => try TimestampTzArray.encode(value, buf, oid_pos), + else => try
Int64Array.encode(value, buf, oid_pos), + } + }, + else => compileHaltBindError(T), + } + } else { + switch (int.bits) { + 8 => try CharArray.encodeOne(value, buf, oid_pos), + 16 => try Int16Array.encodeUnsigned(value, buf, oid_pos), + 32 => try Int32Array.encodeUnsigned(value, buf, oid_pos), + 64 => try Int64Array.encodeUnsigned(value, buf, oid_pos), + else => compileHaltBindError(T), + } + } + }, + .float => |float| { + if (oid == NumericArray.oid.decimal) { + try NumericArray.encode(value, buf, oid_pos); + } else switch (float.bits) { + 32 => try Float32Array.encode(value, buf, oid_pos), + 64 => try Float64Array.encode(value, buf, oid_pos), + else => compileHaltBindError(T), + } + }, + .bool => try BoolArray.encode(value, buf, oid_pos), + .pointer => |ptr| switch (ptr.size) { + .slice => switch (ptr.child) { + u8 => switch (oid) { + StringArray.oid.decimal => try StringArray.encode(value, buf, oid_pos), + UUIDArray.oid.decimal => try UUIDArray.encode(value, buf, oid_pos), + JSONBArray.oid.decimal => try JSONBArray.encode(value, buf, oid_pos), + JSONArray.oid.decimal => try JSONArray.encode(value, buf, oid_pos), + CharArray.oid.decimal => try CharArray.encode(value, buf, oid_pos), + // we try this as a default to support user defined types with unknown oids + // (like an array of enums) + else => try ByteaArray.encode(value, buf, oid_pos), + }, + else => compileHaltBindError(T), + }, + else => compileHaltBindError(T), + }, + .@"enum", .enum_literal => try StringArray.encodeEnum(&value, buf, oid_pos), + .array => try bindSlice(oid, &value, buf, format_pos), + else => compileHaltBindError(T), + } + + var param_len: [4]u8 = undefined; + // write the length of the parameter, -4 because for parameters, the length + // prefix itself isn't included.
+ std.mem.writeInt(i32, &param_len, @intCast(buf.len() - start_pos - 4), .big); + buf.writeAt(&param_len, start_pos); +} + +fn isStringArray(comptime T: type) bool { + switch (@typeInfo(T)) { + .pointer => |ptr| switch (ptr.size) { + .slice => switch (ptr.child) { + []u8, []const u8 => return true, + else => return false, + }, + else => return false, + }, + else => return false, + } +} + +// Write the last part of the Bind message: telling postgresql how it should +// encode each column of the response +pub fn resultEncoding(oids: []i32, buf: *buffer.Buffer) !void { + if (oids.len == 0) { + return buf.write(&.{ 0, 0 }); // we are specifying 0 return types + } + + // 2 bytes for the # of columns we're specifying + 2 bytes per column + const space_needed = 2 + oids.len * 2; + var view = try buf.skip(space_needed); + + view.writeIntBig(u16, @intCast(oids.len)); + for (oids) |oid| { + view.write(resultEncodingFor(oid)); + } +} + +fn compileHaltBindError(comptime T: type) noreturn { + @compileError("cannot bind value of type " ++ @typeName(T)); +} + +const t = lib.testing; +test "UUID: toString" { + try t.expectError(error.InvalidUUID, UUID.toString(&.{ 73, 190, 142, 9, 170, 250, 176, 16, 73, 21 })); + + const s = try UUID.toString(&.{ 183, 204, 40, 47, 236, 67, 73, 190, 142, 9, 170, 250, 176, 16, 73, 21 }); + try t.expectString("b7cc282f-ec43-49be-8e09-aafab0104915", &s); +} + +test "UUID: toBytes" { + try t.expectError(error.InvalidUUID, UUID.toBytes("")); + + { + const s = try UUID.toBytes("166B4751-D702-4FB9-9A2A-CD6B69ED18D6"); + try t.expectSlice(u8, &.{ 22, 107, 71, 81, 215, 2, 79, 185, 154, 42, 205, 107, 105, 237, 24, 214 }, &s); + } + + { + const s = try UUID.toBytes("166b4751-d702-4fb9-9a2a-cd6b69ed18d7"); + try t.expectSlice(u8, &.{ 22, 107, 71, 81, 215, 2, 79, 185, 154, 42, 205, 107, 105, 237, 24, 215 }, &s); + } +} diff --git a/zig/pg/src/types/cidr.zig b/zig/pg/src/types/cidr.zig new file mode 100644 index 0000000..d6ce2fc --- /dev/null +++
b/zig/pg/src/types/cidr.zig @@ -0,0 +1,41 @@ +const std = @import("std"); +const buffer = @import("buffer"); +const lib = @import("../lib.zig"); +const types = @import("../types.zig"); + +pub const Cidr = struct { + pub const encoding = &types.binary_encoding; + pub const oid = types.OID.make(650); + pub const inet_oid = types.OID.make(869); + + address: []const u8, + netmask: u8, + family: Family, + + pub const Family = enum { + v4, + v6, + }; + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) Cidr else lib.TypeError!Cidr { + lib.verifyDecodeType(fail_mode, Cidr, &.{ Cidr.oid.decimal, Cidr.inet_oid.decimal }, data_oid) catch |err| { + if (fail_mode == .unsafe) unreachable; + return err; + }; + + lib.assert(data.len == 8 or data.len == 20); + return decodeKnown(data); + } + + pub fn decodeKnown(data: []const u8) Cidr { + // data[0] is 2 for v4 and 3 for v6, but we can infer this from the length + // data[1] is the netmask + // data[2] is an is_cidr flag, don't think we need to care about that? + // data[3] is the length of the address, which we can ignore since the rest of the payload is the address + return .{ + .address = data[4..], + .netmask = data[1], + .family = if (data.len == 20) .v6 else .v4, + }; + } +}; diff --git a/zig/pg/src/types/numeric.zig b/zig/pg/src/types/numeric.zig new file mode 100644 index 0000000..de2d2b5 --- /dev/null +++ b/zig/pg/src/types/numeric.zig @@ -0,0 +1,423 @@ +const std = @import("std"); +const buffer = @import("buffer"); +const lib = @import("../lib.zig"); +const types = @import("../types.zig"); + +const Encode = types.Encode; + +const math = std.math; + +// Until Zig has a native decimal type, or a third party library becomes de +// facto standard, this library is going to have half-baked numeric support. +// Specifically, we capture the PG wire-format for the numeric as-is. 
If you +// need numeric-precision, then you'll have to make due with this and interpret +// the data yourself. However, we do provide convience functions to get an f64 +// or a string value from the numeric. +// +// PG's format ends with a group of base 10_000 digits. So to represent the +// number 3950.123456, `numeric.digits` will be: +// {0x0F, 0x6E, 0x04, 0xD2, 0x15, 0xE0} +// 3950 1234 5600 +// +// In this case `number_of_digits` will be 3. Take note: both `digits` and +// `number_of_digits` are base 10_000. +// +// `weight` is the weight of the first digit. For 3950.123456, the weight will +// be 0, so we end up with 3950 * (10_000 ^ 0). +// +// `scale` is the display scale, and represents the # of BASE 10 digits to print +// after the decimal. At first glance, this might seem redundant, but there's +// a difference between 3950.123456 and 3950.12345600. The latter indicates a +// greater degree of precision. `scale` can also be larger than the number of +// digits we have, which is used to indicate additional 0s. + +pub const Numeric = struct { + pub const encoding = &types.binary_encoding; + pub const oid = types.OID.make(1700); + + number_of_digits: u16, + weight: i16, + sign: Sign, + scale: u16, + // this is tied to the current row and is only valid while the row is valid + // calling `next()`, `deinit()` `drain()` on there result, or `deinit()` on + // a QueryRow will invalidate this. 
+ digits: []const u8, + + const Sign = enum { + nan, + inf, + negative, + positive, + negativeInf, + }; + + pub fn encode(value: anytype, buf: *buffer.Buffer, format_pos: usize) !void { + buf.writeAt(encoding, format_pos); + return encodeBuf(value, buf); + } + + pub fn encodeBuf(value: anytype, buf: *buffer.Buffer) !void { + const T = @TypeOf(value); + if (T == comptime_float) { + return encodeValue(value, buf); + } + + const TT = switch (@typeInfo(T)) { + .optional => |opt| opt.child, + else => T, + }; + + var v: TT = undefined; + if (comptime @typeInfo(T) == .optional) { + v = value orelse return encodeValidString("null", buf); + } else { + v = value; + } + + if (math.isNan(v)) { + return encodeNaN(buf); + } + if (math.isNegativeInf(v)) { + return encodeNegativeInf(buf); + } + if (math.isInf(v)) { + return encodeInf(buf); + } + + return encodeValue(v, buf); + } + + fn encodeValue(value: anytype, buf: *buffer.Buffer) !void { + // turn our float into a string + var str_buf: [512]u8 = undefined; + const slice = try std.fmt.bufPrint(&str_buf, "{d}", .{value}); + return encodeValidString(slice, buf); + } + + pub fn decode(comptime fail_mode: lib.FailMode, data: []const u8, data_oid: i32) if (fail_mode == .unsafe) Numeric else lib.TypeError!Numeric { + lib.verifyDecodeType(fail_mode, Numeric, &.{Numeric.oid.decimal}, data_oid) catch |err| { + if (fail_mode == .unsafe) unreachable; + return err; + }; + + lib.assert(data.len >= 8); + return decodeKnown(data); + } + + pub fn decodeKnown(data: []const u8) Numeric { + return .{ + .number_of_digits = std.mem.readInt(u16, data[0..2], .big), + .weight = std.mem.readInt(i16, data[2..4], .big), + .sign = switch (std.mem.readInt(u16, data[4..6], .big)) { + 0x0000 => .positive, + 0x4000 => .negative, + 0xd000 => .inf, + 0xf000 => .negativeInf, + else => .nan, // 0xc000 + }, + .scale = std.mem.readInt(u16, data[6..8], .big), + .digits = data[8..], + }; + } + + pub fn decodeKnownToFloat(data: []const u8) f64 { + return 
decodeKnown(data).toFloat(); + } + + pub fn toFloat(self: Numeric) f64 { + switch (self.sign) { + .nan => return math.nan(f64), + .inf => return math.inf(f64), + .negativeInf => return -math.inf(f64), + else => {}, + } + + var value: f64 = 0; + var weight = self.weight; + var digits: []const u8 = self.digits; + for (0..self.number_of_digits) |_| { + const t = std.mem.readInt(i16, digits[0..2], .big); + value += @as(f64, @floatFromInt(t)) * math.pow(f64, 10_000, @floatFromInt(weight)); + digits = digits[2..]; + weight -= 1; + } + + return if (self.sign == .negative) -value else value; + } + + pub fn estimatedStringLen(self: Numeric) usize { + // for the decimal point + var l: usize = 1; + switch (self.sign) { + .nan => return 3, + .inf => return 3, + .negativeInf => return 4, + .negative => l += 1, + .positive => {}, + } + + // max size per base-10000 digit + if (self.number_of_digits == 0) { + return l + 2; // 0.0 but we already added the decimal place + } + + l += self.number_of_digits * 4; + // there's no integer in the number, but our string output will have + // a leading 0 (so it'll be 0.123 instead of just .123) + if (self.weight < 0) { + l += 1; + } + + return l; + } + + pub fn toString(self: Numeric, buf: []u8) ![]u8 { + switch (self.sign) { + .nan => { + @memcpy(buf[0..3], "nan"); + return buf[0..3]; + }, + .inf => { + @memcpy(buf[0..3], "inf"); + return buf[0..3]; + }, + .negativeInf => { + @memcpy(buf[0..4], "-inf"); + return buf[0..4]; + }, + else => {}, + } + + var pos: usize = 0; + var weight = self.weight; + var digits: []const u8 = self.digits; + const number_of_digits = self.number_of_digits; + + if (self.sign == .negative) { + buf[0] = '-'; + pos += 1; + } + + if (number_of_digits == 0) { + const end = pos + 3; + @memcpy(buf[pos..end], "0.0"); + return buf[0..end]; + } + + // do the integer part first + if (weight < 0) { + buf[pos] = '0'; + pos += 1; + } else { + while (weight >= 0) { + if (digits.len == 0) { + const end = pos + 4; + 
@memcpy(buf[pos..end], "0000"); + pos = end; + } else { + const t = std.mem.readInt(i16, digits[0..2], .big); + pos += (try std.fmt.bufPrint(buf[pos..], "{d}", .{t})).len; + digits = digits[2..]; + } + weight -= 1; + } + } + + buf[pos] = '.'; + pos += 1; + + // now the fraction + if (digits.len == 0) { + buf[pos] = '0'; + pos += 1; + } else { + while (digits.len > 0) { + const t = std.mem.readInt(i16, digits[0..2], .big); + if (t < 10) { + buf[pos + 2] = '0'; + buf[pos + 1] = '0'; + buf[pos] = '0'; + pos += 3; + } else if (t < 100) { + buf[pos + 1] = '0'; + buf[pos] = '0'; + pos += 2; + } else if (t < 1000) { + buf[pos] = '0'; + pos += 1; + } + pos += (try std.fmt.bufPrint(buf[pos..], "{d}", .{t})).len; + digits = digits[2..]; + } + } + + // we wrote the fraction in 4-digit groups, but our scale (aka display scale) + // might indicate that we should have less precision. For example, we might + // have written 0.1230, but the scale might be 3, in which case we should + // have written 0.123. + const display_scale = @mod(self.scale, 4); + if (display_scale > 0) { + pos -= 4 - display_scale; + } + return buf[0..pos]; + } +}; + +// encode a string that we know isn't NaN, Inf or -Inf. +fn encodeValidString(str: []const u8, buf: *buffer.Buffer) !void { + // the length of our parameter is dynamic, we reserve 4 bytes to fill in once + // we know what our length is. + const length_state = try Encode.variableLengthStart(buf); + + // we have 8 bytes of meta (# of digits, weight, sign and scale) that we + // don't yet know how to fill up, reserve the space. + var meta_view = try buf.skip(8); + + // buf now points to 12 bytes ahead of where we started. 4 bytes for the + // length and 8 bytes for the meta. This is the position where we fill in + // our base 10000 digits. 
+ + var pos: usize = 0; + var positive = true; + + if (str[0] == '-') { + // check this here, so we don't need to check it on each iteration + pos = 1; + positive = false; + } + + // if no decimal, assume this is a whole number + const decimal_pos = std.mem.indexOfScalarPos(u8, str, pos, '.') orelse str.len; + + // The number of 4-digit groups in our integer + const integer_groups = blk: { + // We're going to write the digits of the integer portion of our float. This + // is base 10_000, so we're going to group them in 4s. Given a number like + // 12345, there are two ways to group these. The correct way is (1) (2345), the + // incorrect way is (1234) (5). + // The correct way will let us recombine the value as + // (1 * 10000) + (2345) = 12345 + // The incorrect way would result in + // (1234 * 10000) + 5 = 12340005 + const integer_digits: u16 = @intCast(decimal_pos - pos); + + // our first group can be 1-4 digits + const first_group = @mod(integer_digits, 4); + if (first_group > 0) { + // if first_group == 0, then it's a full 4-digit group and can be handled + // by the more general case that follows + const end = pos + first_group; + try buf.writeIntBig(u16, generateGroup(str[pos..end])); + pos = end; + } + + // At this point, we know that decimal_pos - pos is a multiple of 4 (possibly + // 0) and we can handle it 4 digits at a time. + while (pos < decimal_pos) { + const end = pos + 4; + try buf.writeIntBig(u16, generateGroup(str[pos..end])); + pos = end; + } + + break :blk try std.math.divCeil(u16, integer_digits, 4); + }; + + // skip decimal point + pos += 1; + + { + // Now we do the fraction. This is similar to above, with a different little + // concern. Given 0.12345, you might think we need to group as (1234) (5) + // but actually we need to group as (1234) (5000). Just (5) would mean + // 0.12340005 (we'd pass {0, 5} as the big-16 encoded base-10000 for that 2nd + // group). + // This causes a new problem.
There's a difference between 0.12345 and + // 0.12345000, the latter is indicative of greater precision. This is what + // the dscale (or just scale) meta parameter resolves. Our dscale will be + // 5, indicating that the fraction is "12345" and not "12345000".. + + if (str.len > decimal_pos + 4) { + const loop_end = str.len - 4; + while (pos < loop_end) { + const end = pos + 4; + try buf.writeIntBig(u16, generateGroup(str[pos..end])); + pos = end; + } + } + + if (pos < str.len) { + const leftover = str.len - pos; + if (leftover > 0) { + // we have an incomplete group left over, read comment above for why + // we're multiplying this + var group_value = generateGroup(str[pos..]); + group_value *= switch (leftover) { + 3 => 10, + 2 => 100, + 1 => 1000, + else => 1, + }; + try buf.writeIntBig(u16, group_value); + } + } + } + + { + // -1 to exclude the decimal point itself + const display_scale: u16 = if (decimal_pos == str.len) 0 else @intCast(str.len - decimal_pos - 1); + + // Fill in our meta + // Number of base-10000 digits that we wrote. 
+ meta_view.writeIntBig(u16, @intCast(integer_groups + try std.math.divCeil(u16, display_scale, 4))); + + // weight is the number of integer groups - 1; + if (integer_groups == 0 or integer_groups == 1) { + meta_view.write(&.{ 0, 0 }); + } else { + meta_view.writeIntBig(u16, integer_groups - 1); + } + + if (positive) { + meta_view.write(&.{ 0, 0 }); + } else { + meta_view.write(&.{ 64, 0 }); + } + meta_view.writeIntBig(u16, display_scale); + } + + // fill in our length + Encode.variableLengthFill(buf, length_state); +} + +fn encodeNaN(buf: *buffer.Buffer) !void { + // 8 length, 0 digits, 0 weight, nan sign, 0 dscale + return buf.write(&.{ 0, 0, 0, 8, 0, 0, 0, 0, 192, 0, 0, 0 }); +} + +fn encodeInf(buf: *buffer.Buffer) !void { + // 8 length, 0 digits, 0 weight, inf sign, 0 dscale + return buf.write(&.{ 0, 0, 0, 8, 0, 0, 0, 0, 208, 0, 0, 0 }); +} + +fn encodeNegativeInf(buf: *buffer.Buffer) !void { + // 8 length, 0 digits, 0 weight, -inf sign, 0 dscale + return buf.write(&.{ 0, 0, 0, 8, 0, 0, 0, 0, 240, 0, 0, 0 }); +} + +fn generateGroup(str: []const u8) u16 { + const number_of_digits = str.len; + + var group_value: u16 = 0; + for (str, 0..) |c, i| { + const d = @as(u16, c - '0'); + group_value += switch (number_of_digits - i) { + 4 => d * 1_000, + 3 => d * 100, + 2 => d * 10, + 1 => d, + else => unreachable, + }; + } + return group_value; +} diff --git a/zig/pg/src/types/vector.zig b/zig/pg/src/types/vector.zig new file mode 100644 index 0000000..9410a13 --- /dev/null +++ b/zig/pg/src/types/vector.zig @@ -0,0 +1,205 @@ +// pgvector support — decodes vector binary format and serializes to JSON. +// Binary format: int16 dim, int16 unused, float32[dim] values (big-endian). +// SIMD-accelerated: uses @Vector for batch float32 endian conversion. + +const std = @import("std"); + +pub const Vector = struct { + // pgvector registers dynamically — OID varies per database. + // Use configureVectorOid() at startup to set it.
+ pub var oid_decimal: i32 = 0; + + dim: u16, + values: []const f32, + // Raw data pointer — values are decoded on demand via SIMD + raw: []const u8, + + pub fn decode(data: []const u8) Vector { + if (data.len < 4) return .{ .dim = 0, .values = &.{}, .raw = data }; + const dim = std.mem.readInt(u16, data[0..2], .big); + // skip unused (bytes 2-3) + return .{ + .dim = dim, + .values = &.{}, // lazy — use toFloats() or writeJson() + .raw = data, + }; + } + + /// Decode all float32 values using SIMD batch conversion. + /// Caller must free the returned slice. + pub fn toFloats(self: Vector, alloc: std.mem.Allocator) ![]f32 { + const dim = self.dim; + if (dim == 0) return &.{}; + const float_data = self.raw[4..]; // skip dim + unused + const result = try alloc.alloc(f32, dim); + + // SIMD path: process 4 floats at a time using @Vector + const simd_width = 4; + const simd_batches = dim / simd_width; + + var i: usize = 0; + while (i < simd_batches) : (i += 1) { + const base = i * simd_width * 4; + // Load 4 big-endian i32s and reinterpret as f32 + var ints: @Vector(simd_width, i32) = undefined; + inline for (0..simd_width) |j| { + const offset = base + j * 4; + ints[j] = std.mem.readInt(i32, float_data[offset..][0..4], .big); + } + // Bitcast i32 vector to f32 vector + const floats: @Vector(simd_width, f32) = @bitCast(ints); + inline for (0..simd_width) |j| { + result[i * simd_width + j] = floats[j]; + } + } + + // Scalar remainder + var k: usize = simd_batches * simd_width; + while (k < dim) : (k += 1) { + const offset = k * 4; + const n = std.mem.readInt(i32, float_data[offset..][0..4], .big); + result[k] = @bitCast(n); + } + + return result; + } + + /// Write vector as JSON array: [0.1, 0.2, 0.3, ...] + /// SIMD-accelerated endian conversion. 
+ pub fn writeJson(self: Vector, buf: []u8) usize { + const dim = self.dim; + if (dim == 0) { + @memcpy(buf[0..2], "[]"); + return 2; + } + + const float_data = self.raw[4..]; + var pos: usize = 0; + buf[pos] = '['; + pos += 1; + + // SIMD batch decode + format + const simd_width = 4; + const simd_batches = dim / simd_width; + + var batch: usize = 0; + while (batch < simd_batches) : (batch += 1) { + const base = batch * simd_width * 4; + var ints: @Vector(simd_width, i32) = undefined; + inline for (0..simd_width) |j| { + const offset = base + j * 4; + ints[j] = std.mem.readInt(i32, float_data[offset..][0..4], .big); + } + const floats: @Vector(simd_width, f32) = @bitCast(ints); + + inline for (0..simd_width) |j| { + if (batch > 0 or j > 0) { + buf[pos] = ','; + pos += 1; + } + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{floats[j]}) catch break; + pos += s.len; + } + } + + // Scalar remainder + var k: usize = simd_batches * simd_width; + while (k < dim) : (k += 1) { + if (k > 0) { + buf[pos] = ','; + pos += 1; + } + const offset = k * 4; + const n = std.mem.readInt(i32, float_data[offset..][0..4], .big); + const v: f32 = @bitCast(n); + const s = std.fmt.bufPrint(buf[pos..], "{d}", .{v}) catch break; + pos += s.len; + } + + buf[pos] = ']'; + pos += 1; + return pos; + } + + /// Compute L2 distance between two vectors using SIMD. 
+ pub fn l2Distance(self: Vector, other: Vector, alloc: std.mem.Allocator) !f32 { + if (self.dim != other.dim) return error.DimensionMismatch; + const a = try self.toFloats(alloc); + defer alloc.free(a); + const b = try other.toFloats(alloc); + defer alloc.free(b); + + const dim = self.dim; + const simd_width = 4; + const simd_batches = dim / simd_width; + var sum: @Vector(simd_width, f32) = @splat(0); + + var i: usize = 0; + while (i < simd_batches) : (i += 1) { + var va: @Vector(simd_width, f32) = undefined; + var vb: @Vector(simd_width, f32) = undefined; + inline for (0..simd_width) |j| { + va[j] = a[i * simd_width + j]; + vb[j] = b[i * simd_width + j]; + } + const diff = va - vb; + sum += diff * diff; + } + + var total: f32 = @reduce(.Add, sum); + + // Scalar remainder + var k: usize = simd_batches * simd_width; + while (k < dim) : (k += 1) { + const diff = a[k] - b[k]; + total += diff * diff; + } + + return @sqrt(total); + } + + /// Compute cosine similarity between two vectors using SIMD. 
+ pub fn cosineSimilarity(self: Vector, other: Vector, alloc: std.mem.Allocator) !f32 { + if (self.dim != other.dim) return error.DimensionMismatch; + const a = try self.toFloats(alloc); + defer alloc.free(a); + const b = try other.toFloats(alloc); + defer alloc.free(b); + + const dim = self.dim; + const simd_width = 4; + const simd_batches = dim / simd_width; + var dot_sum: @Vector(simd_width, f32) = @splat(0); + var a_sq_sum: @Vector(simd_width, f32) = @splat(0); + var b_sq_sum: @Vector(simd_width, f32) = @splat(0); + + var i: usize = 0; + while (i < simd_batches) : (i += 1) { + var va: @Vector(simd_width, f32) = undefined; + var vb: @Vector(simd_width, f32) = undefined; + inline for (0..simd_width) |j| { + va[j] = a[i * simd_width + j]; + vb[j] = b[i * simd_width + j]; + } + dot_sum += va * vb; + a_sq_sum += va * va; + b_sq_sum += vb * vb; + } + + var dot: f32 = @reduce(.Add, dot_sum); + var a_sq: f32 = @reduce(.Add, a_sq_sum); + var b_sq: f32 = @reduce(.Add, b_sq_sum); + + // Scalar remainder + var k: usize = simd_batches * simd_width; + while (k < dim) : (k += 1) { + dot += a[k] * b[k]; + a_sq += a[k] * a[k]; + b_sq += b[k] * b[k]; + } + + const denom = @sqrt(a_sq) * @sqrt(b_sq); + if (denom == 0) return 0; + return dot / denom; + } +}; diff --git a/zig/pg/test_runner.zig b/zig/pg/test_runner.zig new file mode 100644 index 0000000..00f457f --- /dev/null +++ b/zig/pg/test_runner.zig @@ -0,0 +1,294 @@ +// in your build.zig, you can specify a custom test runner: +// const tests = b.addTest(.{ +// .root_module = $MODULE_BEING_TESTED, +// .test_runner = .{ .path = b.path("test_runner.zig"), .mode = .simple }, +// }); + +const std = @import("std"); +const builtin = @import("builtin"); + +const Allocator = std.mem.Allocator; + +const BORDER = "=" ** 80; + +// use in custom panic handler +var current_test: ?[]const u8 = null; + +pub fn main() !void { + var mem: [8192]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&mem); + + const allocator = 
fba.allocator(); + + const env = Env.init(allocator); + defer env.deinit(allocator); + + var slowest = SlowTracker.init(allocator, 5); + defer slowest.deinit(); + + var pass: usize = 0; + var fail: usize = 0; + var skip: usize = 0; + var leak: usize = 0; + + Printer.fmt("\r\x1b[0K", .{}); // beginning of line and clear to end of line + + for (builtin.test_functions) |t| { + if (isSetup(t)) { + t.func() catch |err| { + Printer.status(.fail, "\nsetup \"{s}\" failed: {}\n", .{ t.name, err }); + return err; + }; + } + } + + for (builtin.test_functions) |t| { + if (isSetup(t) or isTeardown(t)) { + continue; + } + + var status = Status.pass; + slowest.startTiming(); + + const is_unnamed_test = isUnnamed(t); + if (env.filter) |f| { + if (!is_unnamed_test and std.mem.indexOf(u8, t.name, f) == null) { + continue; + } + } + + const friendly_name = blk: { + const name = t.name; + var it = std.mem.splitScalar(u8, name, '.'); + while (it.next()) |value| { + if (std.mem.eql(u8, value, "test")) { + const rest = it.rest(); + break :blk if (rest.len > 0) rest else name; + } + } + break :blk name; + }; + + current_test = friendly_name; + std.testing.allocator_instance = .{}; + const result = t.func(); + current_test = null; + + const ns_taken = slowest.endTiming(friendly_name); + + if (std.testing.allocator_instance.deinit() == .leak) { + leak += 1; + Printer.status(.fail, "\n{s}\n\"{s}\" - Memory Leak\n{s}\n", .{ BORDER, friendly_name, BORDER }); + } + + if (result) |_| { + pass += 1; + } else |err| switch (err) { + error.SkipZigTest => { + skip += 1; + status = .skip; + }, + else => { + status = .fail; + fail += 1; + Printer.status(.fail, "\n{s}\n\"{s}\" - {s}\n{s}\n", .{ BORDER, friendly_name, @errorName(err), BORDER }); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + if (env.fail_first) { + break; + } + }, + } + + if (env.verbose) { + const ms = @as(f64, @floatFromInt(ns_taken)) / 1_000_000.0; + Printer.status(status, "{s} ({d:.2}ms)\n", .{ 
friendly_name, ms }); + } else { + Printer.status(status, ".", .{}); + } + } + + for (builtin.test_functions) |t| { + if (isTeardown(t)) { + t.func() catch |err| { + Printer.status(.fail, "\nteardown \"{s}\" failed: {}\n", .{ t.name, err }); + return err; + }; + } + } + + const total_tests = pass + fail; + const status = if (fail == 0) Status.pass else Status.fail; + Printer.status(status, "\n{d} of {d} test{s} passed\n", .{ pass, total_tests, if (total_tests != 1) "s" else "" }); + if (skip > 0) { + Printer.status(.skip, "{d} test{s} skipped\n", .{ skip, if (skip != 1) "s" else "" }); + } + if (leak > 0) { + Printer.status(.fail, "{d} test{s} leaked\n", .{ leak, if (leak != 1) "s" else "" }); + } + Printer.fmt("\n", .{}); + try slowest.display(); + Printer.fmt("\n", .{}); + std.posix.exit(if (fail == 0) 0 else 1); +} + +const Printer = struct { + fn fmt(comptime format: []const u8, args: anytype) void { + std.debug.print(format, args); + } + + fn status(s: Status, comptime format: []const u8, args: anytype) void { + switch (s) { + .pass => std.debug.print("\x1b[32m", .{}), + .fail => std.debug.print("\x1b[31m", .{}), + .skip => std.debug.print("\x1b[33m", .{}), + else => {}, + } + std.debug.print(format ++ "\x1b[0m", args); + } +}; + +const Status = enum { + pass, + fail, + skip, + text, +}; + +const SlowTracker = struct { + const SlowestQueue = std.PriorityDequeue(TestInfo, void, compareTiming); + max: usize, + slowest: SlowestQueue, + timer: std.time.Timer, + + fn init(allocator: Allocator, count: u32) SlowTracker { + const timer = std.time.Timer.start() catch @panic("failed to start timer"); + var slowest = SlowestQueue.init(allocator, {}); + slowest.ensureTotalCapacity(count) catch @panic("OOM"); + return .{ + .max = count, + .timer = timer, + .slowest = slowest, + }; + } + + const TestInfo = struct { + ns: u64, + name: []const u8, + }; + + fn deinit(self: SlowTracker) void { + self.slowest.deinit(); + } + + fn startTiming(self: *SlowTracker) void { + 
self.timer.reset(); + } + + fn endTiming(self: *SlowTracker, test_name: []const u8) u64 { + var timer = self.timer; + const ns = timer.lap(); + + var slowest = &self.slowest; + + if (slowest.count() < self.max) { + // Capacity is fixed to the # of slow tests we want to track + // If we've tracked fewer tests than this capacity, then always add + slowest.add(TestInfo{ .ns = ns, .name = test_name }) catch @panic("failed to track test timing"); + return ns; + } + + { + // Optimization to avoid shifting the dequeue for the common case + // where the test isn't one of our slowest. + const fastest_of_the_slow = slowest.peekMin() orelse unreachable; + if (fastest_of_the_slow.ns > ns) { + // the test was faster than our fastest slow test, don't add + return ns; + } + } + + // the previous fastest of our slow tests, has been pushed off. + _ = slowest.removeMin(); + slowest.add(TestInfo{ .ns = ns, .name = test_name }) catch @panic("failed to track test timing"); + return ns; + } + + fn display(self: *SlowTracker) !void { + var slowest = self.slowest; + const count = slowest.count(); + Printer.fmt("Slowest {d} test{s}: \n", .{ count, if (count != 1) "s" else "" }); + while (slowest.removeMinOrNull()) |info| { + const ms = @as(f64, @floatFromInt(info.ns)) / 1_000_000.0; + Printer.fmt(" {d:.2}ms\t{s}\n", .{ ms, info.name }); + } + } + + fn compareTiming(context: void, a: TestInfo, b: TestInfo) std.math.Order { + _ = context; + return std.math.order(a.ns, b.ns); + } +}; + +const Env = struct { + verbose: bool, + fail_first: bool, + filter: ?[]const u8, + + fn init(allocator: Allocator) Env { + return .{ + .verbose = readEnvBool(allocator, "TEST_VERBOSE", true), + .fail_first = readEnvBool(allocator, "TEST_FAIL_FIRST", false), + .filter = readEnv(allocator, "TEST_FILTER"), + }; + } + + fn deinit(self: Env, allocator: Allocator) void { + if (self.filter) |f| { + allocator.free(f); + } + } + + fn readEnv(allocator: Allocator, key: []const u8) ?[]const u8 { + const v =
std.process.getEnvVarOwned(allocator, key) catch |err| { + if (err == error.EnvironmentVariableNotFound) { + return null; + } + std.log.warn("failed to get env var {s} due to err {}", .{ key, err }); + return null; + }; + return v; + } + + fn readEnvBool(allocator: Allocator, key: []const u8, deflt: bool) bool { + const value = readEnv(allocator, key) orelse return deflt; + defer allocator.free(value); + return std.ascii.eqlIgnoreCase(value, "true"); + } +}; + +pub const panic = std.debug.FullPanic(struct { + pub fn panicFn(msg: []const u8, first_trace_addr: ?usize) noreturn { + if (current_test) |ct| { + std.debug.print("\x1b[31m{s}\npanic running \"{s}\"\n{s}\x1b[0m\n", .{ BORDER, ct, BORDER }); + } + std.debug.defaultPanic(msg, first_trace_addr); + } +}.panicFn); + +fn isUnnamed(t: std.builtin.TestFn) bool { + const marker = ".test_"; + const test_name = t.name; + const index = std.mem.indexOf(u8, test_name, marker) orelse return false; + _ = std.fmt.parseInt(u32, test_name[index + marker.len ..], 10) catch return false; + return true; +} + +fn isSetup(t: std.builtin.TestFn) bool { + return std.mem.endsWith(u8, t.name, "tests:beforeAll"); +} + +fn isTeardown(t: std.builtin.TestFn) bool { + return std.mem.endsWith(u8, t.name, "tests:afterAll"); +} diff --git a/zig/pg/tests/client.crt b/zig/pg/tests/client.crt new file mode 100644 index 0000000..c7eeb92 --- /dev/null +++ b/zig/pg/tests/client.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDqDCCApCgAwIBAgIUbR2RMBCPcwkULT7q9WyFvoKL4HcwDQYJKoZIhvcNAQEL +BQAwYTELMAkGA1UEBhMCU0cxCzAJBgNVBAgMAlNHMQswCQYDVQQHDAJTRzERMA8G +A1UECgwIUGVyc29uYWwxETAPBgNVBAsMCFBlcnNvbmFsMRIwEAYDVQQDDAlsb2Nh +bGhvc3QwHhcNMjUwMzAyMDM1NjIxWhcNMzUwMjI4MDM1NjIxWjB3MQswCQYDVQQG +EwJTRzELMAkGA1UECAwCU0cxCzAJBgNVBAcMAlNHMREwDwYDVQQKDAhQZXJzb25h +bDERMA8GA1UECwwIUGVyc29uYWwxEjAQBgNVBAMMCWxvY2FsaG9zdDEUMBIGA1UE +AwwLdGVzdGNsaWVudDEwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDQ +S9LrrcH+BM7oVOWWNtAjuo9LDyTLAE9wznVj9mQnV4EuAzigNXVwSnJiPAT8swVf 
+RB7iJw4syPmhtin3U+kTH8lHrLw0trT6FvKpvamP4O/7pAWCm7HPuuR8cNqhvgad +kH92xu12lzBR63fligs5divveexjoK5vtfH3SLUL4nB7/wInlj+q/Y/gde96nc3o +PX/j1GW6wj8GE/90mnsG9o1azjS2+ORU99qsPndoKWuS9fzydzWT9TsBGLQ9NmHs +UWxkxo6y4m7vt73dpGUPdZFqsRFCaHwKRGhVmWzk7s7fmpndF2EH4WFSVLJo3ovS +gkvxussIma+VfKBf0wI7AgMBAAGjQjBAMB0GA1UdDgQWBBR3k83lO7Y0fp0Xw27R +naQ0n/LwUzAfBgNVHSMEGDAWgBRzXFPvmXF7H7h9EWUjeYY7XkRPEzANBgkqhkiG +9w0BAQsFAAOCAQEAoMd6s6xNF3Oqgt1yP3ixN+bGpW4F7etO0+A5FWcO3/ttIMjs +dYVDKIrcnHqMvYIKRM4+VR5Wxqg2hLu9kxDSVNxMjZtM+Wt1hnYuso5/L3hqMUft +EYg9nTSoizhd8ZS6rgNoei7gHt7ZohZ2l2io3VAqUczvQ68zxTyuhzlKR5gRESf4 +MxVGYlCakUddHFud5PiMv+wOEO5vmZ2q9+F2UHSM03X4Cs7ZTBi10UmitP4f/nFo +7Wj76tJIn8jiiewU7Fy7alV4s4DHYcmSmUXHMiX77aiEKb+kuw36R1EXx8WOqiYw +Rhdzbe8K9zlN4JZ6i8660SHeTbJ4rGpiLY00gg== +-----END CERTIFICATE----- diff --git a/zig/pg/tests/client.key b/zig/pg/tests/client.key new file mode 100644 index 0000000..3dc61e6 --- /dev/null +++ b/zig/pg/tests/client.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDQS9LrrcH+BM7o +VOWWNtAjuo9LDyTLAE9wznVj9mQnV4EuAzigNXVwSnJiPAT8swVfRB7iJw4syPmh +tin3U+kTH8lHrLw0trT6FvKpvamP4O/7pAWCm7HPuuR8cNqhvgadkH92xu12lzBR +63fligs5divveexjoK5vtfH3SLUL4nB7/wInlj+q/Y/gde96nc3oPX/j1GW6wj8G +E/90mnsG9o1azjS2+ORU99qsPndoKWuS9fzydzWT9TsBGLQ9NmHsUWxkxo6y4m7v +t73dpGUPdZFqsRFCaHwKRGhVmWzk7s7fmpndF2EH4WFSVLJo3ovSgkvxussIma+V +fKBf0wI7AgMBAAECggEAHGeKnrr2LlanhIU3PbHB1m11Zu3svYYQTzjIR1ZtN/Q1 +2Hl1+lrv5d0xFfZLU85x2BjpATwEMdVCPWwi8uhNa181SoHitGmJ6mTAuKb1fXpW +H4GxgqsKp2I8EEAvgEjyjAANcbwU28woObObiQC3ISHdQe3lb3yU7QTptygCEFmU +ndSDtW/gY3KznqiWDdWIiWDcgotuWHO5VDaYZIJI4tZaymFQGTwwTrBVL/kOtLnA +ZavETG/HBlXPEAWciKJe+ut+JHC1YwVjMuALZBOBK6VX1kY1YEpNSQSWgXwzgWXb +JiSY9+7yGvx+r4eEqadB8jQJTMqaKWrKOcwob4GyCQKBgQDzdq9+kg7QX0nJozgF +3UIZTnjeemdvJkFOcEJBodE//1cMhZWuFnP8qoGJsyxJzMLsBFQhCdAUdISJIewa +VUZS0buPXsYmrWVnZIle/gJx2N5Qopy9Urihpj4RZJ9b3Z6g2m7qq4eCtiVqb8pj +Vt/pz2pcaLMPT6l/6PNQe4YXZwKBgQDbBZCNH89lWLutriAv9GS3z7lKUcuUpXWh 
+nOWTObAgVEUjmeTlxRct4M/dLmah+BUV9bEvWIJdXCTGNuxOVWSRsBk6NRkhw1qI +3j6ay5zfcaRbzgHypMoh9x7Jk4ucoH7/pN+g+HJG7r+rQvQPgJEFI5sk0eCMXyqT +K3OBr7ReDQKBgQC7hVjalkkOubYttqe57Jeywjxar9DnTYHTlqeRwb9YGaXEoUeO +lQC1RecMVpLwLOSdwR/Darl4Z96FeTlPdwr5U02xuf/JXpjSMB+WqPLdGXryhK1R +LVvENjVsVCJiMaqynkv8OC3hwcXD22L5bLp+biGwF3yDeIpHWPe/r8SyVQKBgG/n +xvSkNSZWEQZrelymJTPZeZWkdz0K0TBy5sWzau8Jv42yGsbfTbmOLQaYp63IAJYI +w7AqK+mho9R2yYQ4kzrw2+LmsGGU29QkoZ4bvJpaCR0zA8HDOtfh2KQrs+CiDGF4 +Dx7C8jiV0e6iNesZyH70s1c7uNxf33P0dn7jlGedAoGAQMSOvZPXbFb9BNekvCwc +A+xl5QcTIhiQ8t6EcZ6E1PfwmR5mDNNktbyugzakzqvDhwxU7rt7E60vI+uUqL31 +zNTXHB02uMy//8xkScW8G2321ocyX7SlCH5jobR25Cgdn/hG1yJgTw99Tbjf9quw +VhAjyNx2AVyko1kdlaY8FrI= +-----END PRIVATE KEY----- diff --git a/zig/pg/tests/compose.yml b/zig/pg/tests/compose.yml new file mode 100644 index 0000000..1fc359c --- /dev/null +++ b/zig/pg/tests/compose.yml @@ -0,0 +1,25 @@ +services: + postgres: + image: postgis/postgis + platform: linux/amd64 + environment: + POSTGRES_USER: "postgres" + POSTGRES_PASSWORD: "postgres" + PGDATA: "/var/lib/postgresql/pgdata/" + LANG: "en_US.utf8" + ports: + - 5432:5432 + volumes: + - "./root.crt:/etc/postgresql/root.crt:ro" + - "./server.crt:/etc/postgresql/server.crt:ro" + - "./server.key:/etc/postgresql/server.key:ro" + - "./pg_hba.conf:/etc/postgresql/pg_hba.conf:ro" + - "./postgresql.conf:/etc/postgresql/postgresql.conf:ro" + - "./init_ssl.sql:/docker-entrypoint-initdb.d/init_ssl.sql:ro" + command: + - "postgres" + - "-c" + - "config_file=/etc/postgresql/postgresql.conf" + - "-c" + - "hba_file=/etc/postgresql/pg_hba.conf" + diff --git a/zig/pg/tests/init_ssl.sql b/zig/pg/tests/init_ssl.sql new file mode 100644 index 0000000..8a5d5a4 --- /dev/null +++ b/zig/pg/tests/init_ssl.sql @@ -0,0 +1,4 @@ +ALTER SYSTEM SET ssl_ca_file TO '/etc/postgresql/root.crt'; +ALTER SYSTEM SET ssl_key_file TO '/etc/postgresql/server.key'; +ALTER SYSTEM SET ssl_cert_file TO '/etc/postgresql/server.crt'; +ALTER SYSTEM SET ssl TO 'ON'; diff --git 
a/zig/pg/tests/pg_hba.conf b/zig/pg/tests/pg_hba.conf new file mode 100644 index 0000000..a26b19b --- /dev/null +++ b/zig/pg/tests/pg_hba.conf @@ -0,0 +1,6 @@ +local all postgres trust +host all postgres all trust +host all pgz_user_clear all password +host all pgz_user_nopass all trust +host all pgz_user_scram_sha256 all scram-sha-256 +hostssl all pgz_user_ssl all password diff --git a/zig/pg/tests/postgresql.conf b/zig/pg/tests/postgresql.conf new file mode 100644 index 0000000..cf98a71 --- /dev/null +++ b/zig/pg/tests/postgresql.conf @@ -0,0 +1,13 @@ +max_connections = 30 +shared_buffers = 512MB +listen_addresses = '0.0.0.0' +work_mem = 4MB +effective_cache_size = 512MB +log_timezone = 'UTC' +datestyle = 'iso, mdy' +timezone = 'UTC' +lc_messages = 'en_US.UTF-8' +lc_monetary = 'en_US.UTF-8' +lc_numeric = 'en_US.UTF-8' +lc_time = 'en_US.UTF-8' +fsync = false diff --git a/zig/pg/tests/root.crt b/zig/pg/tests/root.crt new file mode 100644 index 0000000..c5bc6cd --- /dev/null +++ b/zig/pg/tests/root.crt @@ -0,0 +1,80 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + 12:2c:3f:3d:ca:26:dc:f3:ed:b5:1a:7a:55:d1:50:f6:6d:b5:44:e1 + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=SG, ST=SG, L=SG, O=Personal, OU=Personal, CN=localhost + Validity + Not Before: Mar 2 03:56:21 2025 GMT + Not After : Feb 28 03:56:21 2035 GMT + Subject: C=SG, ST=SG, L=SG, O=Personal, OU=Personal, CN=localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:b0:7b:21:98:82:94:d3:24:bc:ff:9d:ac:2d:49: + 79:b4:30:69:99:fc:e4:45:2e:ff:d5:1d:a3:02:9a: + 85:52:de:d8:89:ad:89:2a:37:90:98:31:8d:62:3a: + 6b:07:32:fc:e1:77:b6:eb:91:6f:4a:e1:4a:cb:4c: + 79:6a:eb:06:60:9c:76:7c:99:ec:7b:9b:29:c0:01: + fe:fb:97:28:82:19:3b:55:73:a8:ab:57:94:27:b6: + c2:d9:f6:1a:de:1b:60:50:5e:1a:5e:96:47:c9:01: + 89:3b:67:1c:8b:91:0e:2b:7c:c5:97:72:36:e4:e6: + 8d:a3:d4:d6:5c:83:17:7b:24:7c:a4:95:dd:7e:d3: + 
d2:10:5a:19:3c:63:48:79:4f:00:c0:b2:28:02:b0: + 83:2b:5f:6e:14:a0:9c:4e:0a:90:76:b4:db:af:b9: + 92:c0:a8:7a:70:59:c3:1b:b3:3a:63:73:0d:42:d6: + a3:bb:c4:d9:ee:c9:0e:f1:e9:88:fa:a9:83:9f:5b: + 10:d8:e2:37:16:f0:43:6f:b9:a4:b9:c0:18:00:29: + e4:3c:24:5e:07:9a:35:f7:4a:d6:d4:4a:17:a1:85: + 8a:4f:d5:22:7f:65:28:04:ea:33:46:cf:93:89:7d: + 56:aa:bc:7f:bb:d1:82:16:df:e2:0b:12:13:19:40: + d5:05 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + 73:5C:53:EF:99:71:7B:1F:B8:7D:11:65:23:79:86:3B:5E:44:4F:13 + X509v3 Authority Key Identifier: + 73:5C:53:EF:99:71:7B:1F:B8:7D:11:65:23:79:86:3B:5E:44:4F:13 + X509v3 Basic Constraints: critical + CA:TRUE + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + 05:a7:9b:17:2b:0f:99:aa:37:04:1f:cc:8c:ec:16:16:32:4e: + b4:95:97:4f:b2:fc:09:c8:2b:27:06:8c:d0:ea:1f:49:17:1c: + 70:05:80:d0:5f:f1:bc:e1:d3:d9:30:24:59:53:d0:e4:7f:72: + 9f:8f:7f:15:4f:59:a1:89:19:50:91:a1:de:dd:99:9e:db:6c: + 0b:87:44:d8:9e:0e:26:7d:7e:6d:24:c1:b6:a8:4f:b3:40:75: + 01:9b:a5:70:a7:af:b5:e1:5d:44:8e:ae:8f:24:1b:79:d3:4c: + 05:4c:be:49:f2:00:d4:80:e0:ac:3d:be:7c:2e:fb:ca:d2:8a: + 95:7a:88:ee:9f:09:d9:17:10:46:0e:e2:0e:07:24:bd:66:36: + 53:25:36:e0:76:47:7b:dc:69:8d:a4:90:84:da:1e:a1:f1:8f: + a4:c9:a4:bf:3c:f2:25:2b:d6:82:d6:98:55:42:68:d8:8f:cf: + ac:f8:41:fa:92:0d:7d:9c:58:52:42:c1:e2:6d:5a:fb:d7:6b: + d3:f1:b4:f5:50:e6:12:7a:02:78:e1:ff:29:80:4b:b4:10:7f: + 3a:14:fd:f5:05:55:57:86:07:38:84:20:86:e3:35:80:eb:83: + 2b:20:b5:b5:50:de:39:6d:09:2e:5f:e5:3e:e9:1a:53:95:59: + b8:f9:9e:77 +-----BEGIN CERTIFICATE----- +MIIDozCCAougAwIBAgIUEiw/Pcom3PPttRp6VdFQ9m21ROEwDQYJKoZIhvcNAQEL +BQAwYTELMAkGA1UEBhMCU0cxCzAJBgNVBAgMAlNHMQswCQYDVQQHDAJTRzERMA8G +A1UECgwIUGVyc29uYWwxETAPBgNVBAsMCFBlcnNvbmFsMRIwEAYDVQQDDAlsb2Nh +bGhvc3QwHhcNMjUwMzAyMDM1NjIxWhcNMzUwMjI4MDM1NjIxWjBhMQswCQYDVQQG +EwJTRzELMAkGA1UECAwCU0cxCzAJBgNVBAcMAlNHMREwDwYDVQQKDAhQZXJzb25h +bDERMA8GA1UECwwIUGVyc29uYWwxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJ 
+KoZIhvcNAQEBBQADggEPADCCAQoCggEBALB7IZiClNMkvP+drC1JebQwaZn85EUu +/9UdowKahVLe2ImtiSo3kJgxjWI6awcy/OF3tuuRb0rhSstMeWrrBmCcdnyZ7Hub +KcAB/vuXKIIZO1VzqKtXlCe2wtn2Gt4bYFBeGl6WR8kBiTtnHIuRDit8xZdyNuTm +jaPU1lyDF3skfKSV3X7T0hBaGTxjSHlPAMCyKAKwgytfbhSgnE4KkHa026+5ksCo +enBZwxuzOmNzDULWo7vE2e7JDvHpiPqpg59bENjiNxbwQ2+5pLnAGAAp5DwkXgea +NfdK1tRKF6GFik/VIn9lKATqM0bPk4l9Vqq8f7vRghbf4gsSExlA1QUCAwEAAaNT +MFEwHQYDVR0OBBYEFHNcU++ZcXsfuH0RZSN5hjteRE8TMB8GA1UdIwQYMBaAFHNc +U++ZcXsfuH0RZSN5hjteRE8TMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEL +BQADggEBAAWnmxcrD5mqNwQfzIzsFhYyTrSVl0+y/AnIKycGjNDqH0kXHHAFgNBf +8bzh09kwJFlT0OR/cp+PfxVPWaGJGVCRod7dmZ7bbAuHRNieDiZ9fm0kwbaoT7NA +dQGbpXCnr7XhXUSOro8kG3nTTAVMvknyANSA4Kw9vnwu+8rSipV6iO6fCdkXEEYO +4g4HJL1mNlMlNuB2R3vcaY2kkITaHqHxj6TJpL888iUr1oLWmFVCaNiPz6z4QfqS +DX2cWFJCweJtWvvXa9PxtPVQ5hJ6Anjh/ymAS7QQfzoU/fUFVVeGBziEIIbjNYDr +gysgtbVQ3jltCS5f5T7pGlOVWbj5nnc= +-----END CERTIFICATE----- diff --git a/zig/pg/tests/root.srl b/zig/pg/tests/root.srl new file mode 100644 index 0000000..97aa070 --- /dev/null +++ b/zig/pg/tests/root.srl @@ -0,0 +1 @@ +6D1D9130108F7309142D3EEAF56C85BE828BE077