diff --git a/README.md b/README.md index ba92173..47c4d15 100644 --- a/README.md +++ b/README.md @@ -208,12 +208,12 @@ The provider manages the WebSocket lifecycle. When another user inserts a row, y EdgePod helps you stay within Durable Object limits with lightweight, always-on guards: -| Guard | What it does | -| --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| **Result limit** | Queries are capped at **1 000 rows**. If a query returns exactly 1 000 rows, a warning is logged to paginate with `.limit()` and `.offset()`. | -| **WHERE enforcement** | `UPDATE` and `DELETE` without a `.where()` clause are blocked. If you really mean to affect every row, chain `.withoutWhere()` to opt out per-query. | -| **Raw SQL guard** | Dangerous raw methods like `db.run()` and `db.get()` are blocked on the tracked database instance. Use `ctx.unsafeRawDb` explicitly if you need raw access, and call `ctx.invalidate()` manually. | -| **Bulk insert limit** | `insert().values()` arrays are capped at 1 000 rows to avoid oversized writes. | +| Guard | What it does | +| --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **Result limit** | Queries are capped at **1 000 rows**. If a query returns exactly 1 000 rows, a warning is logged to paginate with `.limit()` and `.offset()`. | +| **WHERE enforcement** | `UPDATE` and `DELETE` without a `.where()` clause are blocked. If you really mean to affect every row, chain `.withoutWhere()` to opt out per-query. | +| **Raw SQL guard** | Dangerous raw methods like `db.run()` and `db.get()` are blocked on the tracked database instance. Use `ctx.unsafeRawDb` explicitly if you need raw access — raw SQL is automatically tracked via SQL parsing. | +| **Bulk insert limit** | `insert().values()` arrays are capped at 1 000 rows to avoid oversized writes. | These are not configuration options — they are designed to catch accidental misuse early, while giving you explicit escape hatches when you need them. diff --git a/packages/server/package.json b/packages/server/package.json index 2f432b0..e2649d4 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -48,7 +48,8 @@ "@logtape/logtape": "^2.0.5", "drizzle-orm": "^0.45.2", "jose": "^6.2.3", - "neverthrow": "^8.2.0" + "neverthrow": "^8.2.0", + "sqlite3-parser": "^0.7.1" }, "devDependencies": { "@types/better-sqlite3": "^7.6.13", diff --git a/packages/server/src/server/do.test.ts b/packages/server/src/server/do.test.ts index 6fc95be..887e6b2 100644 --- a/packages/server/src/server/do.test.ts +++ b/packages/server/src/server/do.test.ts @@ -24,14 +24,18 @@ vi.mock("drizzle-orm/durable-sqlite/migrator", () => ({ migrate: vi.fn(), })); -vi.mock("../tools/createTrackedDb", () => ({ - createTrackedDb: vi.fn(), +vi.mock("../tools/createSafetyProxy", () => ({ + createSafetyProxy: vi.fn(), })); vi.mock("../tools/buildCascadeGraph", () => ({ buildCascadeGraph: vi.fn(() => new Map()), })); +vi.mock("../tools/buildPkMap", () => ({ + buildPkMap: vi.fn(() => new Map()), +})); + vi.mock("./auth", () => ({ initJwtSigner: vi.fn(async () => ({ match: vi.fn() })), getJwtSigner: vi.fn(() => null), diff --git a/packages/server/src/server/do.ts b/packages/server/src/server/do.ts index b6baf7a..a82370d 100644 --- a/packages/server/src/server/do.ts +++ b/packages/server/src/server/do.ts @@ -1,8 +1,10 @@ import { DurableObject } from "cloudflare:workers"; import { drizzle } from "drizzle-orm/durable-sqlite"; import { migrate } from "drizzle-orm/durable-sqlite/migrator"; -import { createTrackedDb } from "../tools/createTrackedDb"; +import { createSafetyProxy } from "../tools/createSafetyProxy"; +import { createTrackedRawDb } from "../tools/createTrackedRawDb"; import { buildCascadeGraph } from "../tools/buildCascadeGraph"; +import { buildPkMap } from "../tools/buildPkMap"; import { initJwtSigner, getJwtSigner } from "./auth"; import { initLogger, createLogger } from "./logger"; import type { EdgePodSessionMap, EdgePodContext, RpcRequest, RpcMeta, JsonValue } from "../types"; @@ -19,6 +21,7 @@ export class BaseEdgePodEngine extends DurableObject { private rawDb: ReturnType; private activeSessions: EdgePodSessionMap = new Map(); private cascadeGraph: Map> = new Map(); + private pkMap: Map = new Map(); protected userFunctions: Record Promise | JsonValue> = {}; protected schema: Record = {}; protected migrations: { @@ -43,6 +46,7 @@ export class BaseEdgePodEngine extends DurableObject { this.restoreActiveSessions(); this.cascadeGraph = buildCascadeGraph(this.schema); + this.pkMap = buildPkMap(this.schema); await initLogger(); await initJwtSigner(this.env as any).match( @@ -125,26 +129,31 @@ export class BaseEdgePodEngine extends DurableObject { // Prepare the read/write trackers for this specific run const tablesRead = new Set(); const tablesWritten = new Set(); + const rowIds = new Map>(); const warnings: string[] = []; // When reactive is false, pass an empty session map so no table subscriptions are registered - const sessionMap = reactive ? this.activeSessions : new Map(); + const activeSessions = reactive ? this.activeSessions : new Map(); - // Instantiate the Proxy - const dbProxy = createTrackedDb( - this.rawDb, + // Shared tracking context for this RPC invocation + const trackCtx = { sessionId, - sessionMap, + activeSessions, tablesRead, tablesWritten, - this.cascadeGraph, + cascadeGraph: this.cascadeGraph, warnings, - ); + rowIds, + pkMap: this.pkMap, + }; + + // Instantiate the Proxy with SQL parser-based tracking + const dbProxy = createSafetyProxy(this.rawDb, trackCtx); // Build the Context const edgepodCtx: EdgePodContext> = { db: dbProxy as any, - unsafeRawDb: this.rawDb, + unsafeRawDb: createTrackedRawDb(this.rawDb, trackCtx), user, env: this.env, headers, @@ -161,7 +170,6 @@ export class BaseEdgePodEngine extends DurableObject { }); } }, - invalidate: (tables: string[]) => tables.forEach((t: string) => tablesWritten.add(t)), set: (key: string, value: any) => variableStore.set(key, value), get: (key: string) => variableStore.get(key) as any, }; @@ -175,10 +183,18 @@ export class BaseEdgePodEngine extends DurableObject { this.broadcastInvalidations(hashedTableNames); } + const rowsMeta: Record = {}; + for (const [table, ids] of rowIds) { + rowsMeta[table] = [...ids]; + } return { success: true, data, - meta: { read: [...tablesRead], changed: [...tablesWritten] }, + meta: { + read: [...tablesRead], + changed: [...tablesWritten], + ...(Object.keys(rowsMeta).length > 0 ? { rows: rowsMeta } : {}), + }, warnings, }; } catch (e) { diff --git a/packages/server/src/server/index.ts b/packages/server/src/server/index.ts index 44f2340..105ec73 100644 --- a/packages/server/src/server/index.ts +++ b/packages/server/src/server/index.ts @@ -2,7 +2,7 @@ import pkg from "../../package.json" with { type: "json" }; import type { BaseEdgePodEngine } from "./do"; import type { RpcRequest } from "../types"; import { verifyJwt } from "./auth"; -import { hashMetaTableNames } from "../tools/hashTableName"; +import { hashMetaTableNames, hashTableName } from "../tools/hashTableName"; import { ResultAsync } from "neverthrow"; // EdgePod is origin-agnostic by design. Every request is authenticated via the @@ -140,11 +140,19 @@ export const edgePodFetch = async ( }); if (result.success) { + const rowsMeta = result.meta.rows + ? Object.fromEntries( + Object.entries(result.meta.rows).map(([t, ids]) => [hashTableName(t), ids]), + ) + : undefined; return Response.json( { success: true, data: result.data, - _meta: { t: hashMetaTableNames(result.meta.read) }, + _meta: { + t: hashMetaTableNames(result.meta.read), + ...(rowsMeta ? { r: rowsMeta } : {}), + }, ...(result.warnings.length > 0 ? { warnings: result.warnings } : {}), }, { headers: serverHeader }, diff --git a/packages/server/src/tools/buildPkMap.ts b/packages/server/src/tools/buildPkMap.ts new file mode 100644 index 0000000..261e267 --- /dev/null +++ b/packages/server/src/tools/buildPkMap.ts @@ -0,0 +1,13 @@ +import { getTableConfig } from "drizzle-orm/sqlite-core"; + +export function buildPkMap(schema: Record): Map { + const map = new Map(); + for (const key in schema) { + const table = schema[key]; + if (!table || !(table as any)[Symbol.for("drizzle:Name")]) continue; + const config = getTableConfig(table as any); + const pkCols = config.columns.filter((c: any) => c.primary).map((c: any) => c.name); + map.set(config.name, pkCols); + } + return map; +} diff --git a/packages/server/src/tools/checkResultWarnings.test.ts b/packages/server/src/tools/checkResultWarnings.test.ts deleted file mode 100644 index 213c389..0000000 --- a/packages/server/src/tools/checkResultWarnings.test.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { describe, it, expect } from "vitest"; -import { checkResultWarnings } from "./checkResultWarnings"; - -describe("checkResultWarnings", () => { - it("adds warning when result hits max limit", () => { - const warnings: string[] = []; - const result = Array(1000).fill({ id: 1 }); - - checkResultWarnings(result, warnings, 1000); - - expect(warnings).toHaveLength(1); - expect(warnings[0]).toContain("1000 rows"); - expect(warnings[0]).toContain("paginate"); - }); - - it("does not add warning when result is under limit", () => { - const warnings: string[] = []; - const result = Array(500).fill({ id: 1 }); - - checkResultWarnings(result, warnings, 1000); - - expect(warnings).toHaveLength(0); - }); - - it("does not add warning for non-array results", () => { - const warnings: string[] = []; - - checkResultWarnings({ id: 1 }, warnings, 1000); - - expect(warnings).toHaveLength(0); - }); - - it("does not add warning for empty array", () => { - const warnings: string[] = []; - - checkResultWarnings([], warnings, 1000); - - expect(warnings).toHaveLength(0); - }); -}); diff --git a/packages/server/src/tools/checkResultWarnings.ts b/packages/server/src/tools/checkResultWarnings.ts deleted file mode 100644 index 8586c91..0000000 --- a/packages/server/src/tools/checkResultWarnings.ts +++ /dev/null @@ -1,7 +0,0 @@ -export function checkResultWarnings(result: unknown, warnings: string[], maxLimit: number) { - if (Array.isArray(result) && result.length === maxLimit) { - warnings.push( - `Query returned exactly ${maxLimit} rows — there may be more results. Use .limit() and .offset() to paginate.`, - ); - } -} diff --git a/packages/server/src/tools/createBuilderProxy.ts b/packages/server/src/tools/createBuilderProxy.ts new file mode 100644 index 0000000..171a700 --- /dev/null +++ b/packages/server/src/tools/createBuilderProxy.ts @@ -0,0 +1,139 @@ +import type { TrackContext } from "./createSafetyProxy"; +import { trackExec, warnRowLimit } from "./tracking"; + +const STATE = Symbol("edgepod_builder_state"); +const EXEC = ["then", "run", "all", "get", "values", "execute"]; +const MAX_LIMIT = 1000; +const MAX_BULK_INSERT = 1000; + +type BuilderConfig = { + type: "select" | "insert" | "update" | "delete"; + tableName?: string; +}; + +export function createBuilderProxy(builder: any, ctx: TrackContext, config: BuilderConfig): any { + if (!builder[STATE]) { + builder[STATE] = { limitSet: false, whereSet: false, withoutWhereSet: false }; + } + + return new Proxy(builder, { + get(target: any, prop: string) { + const state: { limitSet: boolean; whereSet: boolean; withoutWhereSet: boolean } = target[ + STATE + ] || { limitSet: false, whereSet: false, withoutWhereSet: false }; + + // Safety: limit clamping for SELECT + if (prop === "limit" && config.type === "select") + return (n: number) => { + const clamped = Math.max(0, Math.min(n, MAX_LIMIT)); + if (n > MAX_LIMIT) ctx.warnings.push(`Query limit of ${n} overridden to ${MAX_LIMIT}.`); + return wrap(target.limit(clamped), ctx, config, { ...state, limitSet: true }); + }; + + // Safety: WHERE enforcement for UPDATE/DELETE + if (prop === "where" && (config.type === "update" || config.type === "delete")) + return (...args: unknown[]) => + wrap(target.where(...args), ctx, config, { ...state, whereSet: true }); + if (prop === "withoutWhere" && (config.type === "update" || config.type === "delete")) + return () => { + ctx.warnings.push(`[EdgePod] Unfiltered ${config.type} executed via .withoutWhere().`); + return wrap(target, ctx, config, { ...state, withoutWhereSet: true }); + }; + + // Safety: bulk insert limit + if (prop === "values" && config.type === "insert") + return (...args: unknown[]) => { + const rows = args[0]; + if (Array.isArray(rows) && rows.length > MAX_BULK_INSERT) + throw new Error( + `[EdgePod] Bulk insert blocked: ${rows.length} rows > ${MAX_BULK_INSERT}. Split into smaller batches.`, + ); + return wrap(target.values(...args), ctx, config, state); + }; + + // Safety: prepare + if (prop === "prepare" && config.type !== "select") + return () => { + throw new Error(`[EdgePod] .prepare() is not supported for ${config.type}s.`); + }; + if (prop === "prepare") + return (...args: unknown[]) => { + const b = state.limitSet ? target : target.limit(MAX_LIMIT); + return (b as any).prepare(...args); + }; + + // Execution: safety check → track → execute + if (EXEC.includes(prop)) + return (...args: unknown[]) => { + if ( + !state.whereSet && + !state.withoutWhereSet && + (config.type === "update" || config.type === "delete") + ) + throw new Error( + `[EdgePod] ${config.type.toUpperCase()} without WHERE is blocked. If intentional, chain .withoutWhere().`, + ); + + let b = target; + if (config.type === "select" && !state.limitSet) b = target.limit(MAX_LIMIT); + + trackExec(b, ctx, config.tableName, config.type); + + if (prop === "then") { + const [resolve, reject] = args as [(v: unknown) => void, (e: unknown) => void]; + return b.then((res: unknown) => { + warnRowLimit(res, ctx.warnings); + resolve(res); + }, reject); + } + + const method = b[prop] as Function; + const result = method.apply(b, args); + warnRowLimit(result, ctx.warnings); + return result; + }; + + // Generic pass-through: wrap if result is a builder + return passThrough(target, prop, ctx, config, state); + }, + }); +} + +function wrap( + result: any, + ctx: TrackContext, + config: BuilderConfig, + state: { limitSet: boolean; whereSet: boolean; withoutWhereSet: boolean }, +): any { + if (isBuilder(result)) { + result[STATE] = state; + return createBuilderProxy(result, ctx, config); + } + return result; +} + +function passThrough( + target: any, + prop: string, + ctx: TrackContext, + config: BuilderConfig, + state: { limitSet: boolean; whereSet: boolean; withoutWhereSet: boolean }, +): any { + const raw = target[prop]; + if (typeof raw === "function") + return (...args: unknown[]) => { + const result = raw.apply(target, args); + return isBuilder(result) + ? createBuilderProxy(Object.assign(result, { [STATE]: state }), ctx, config) + : result; + }; + if (isBuilder(raw)) { + raw[STATE] = state; + return createBuilderProxy(raw, ctx, config); + } + return raw; +} + +function isBuilder(v: unknown): boolean { + return typeof v === "object" && v !== null && !Array.isArray(v); +} diff --git a/packages/server/src/tools/createMutationProxy.test.ts b/packages/server/src/tools/createMutationProxy.test.ts deleted file mode 100644 index 8e16173..0000000 --- a/packages/server/src/tools/createMutationProxy.test.ts +++ /dev/null @@ -1,207 +0,0 @@ -import { describe, it, expect, vi } from "vitest"; -import { createMutationProxy } from "./createMutationProxy"; - -function createMockBuilder() { - const builder: Record = { - where: vi.fn(function () { - return createMockBuilder(); - }), - withoutWhere: vi.fn(function () { - return createMockBuilder(); - }), - run: vi.fn(function () { - return Promise.resolve({ changes: 1 }); - }), - all: vi.fn(function () { - return Promise.resolve([]); - }), - get: vi.fn(function () { - return Promise.resolve({ id: 1 }); - }), - values: vi.fn(function () { - return Promise.resolve([]); - }), - execute: vi.fn(function () { - return Promise.resolve(); - }), - returning: vi.fn(function () { - const inner = createMockBuilder(); - // oxlint-disable-next-line unicorn/no-thenable - inner.then = vi.fn(function (resolve: (v: unknown) => void, reject: (e: unknown) => void) { - return Promise.resolve({ changes: 1 }).then(resolve, reject); - }); - return inner; - }), - }; - - return builder; -} - -describe("createMutationProxy", () => { - it("blocks update without WHERE clause", () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "update"); - - expect(() => proxy.run()).toThrow("UPDATE without WHERE is blocked"); - }); - - it("blocks delete without WHERE clause", () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "delete"); - - expect(() => proxy.run()).toThrow("DELETE without WHERE is blocked"); - }); - - it("allows update with WHERE clause", async () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "update"); - - const withWhere = proxy.where({ id: 1 }); - const result = await withWhere.run(); - - expect(result).toEqual({ changes: 1 }); - }); - - it("allows delete with WHERE clause", async () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "delete"); - - const withWhere = proxy.where({ id: 1 }); - const result = await withWhere.run(); - - expect(result).toEqual({ changes: 1 }); - }); - - it("allows mutation with withoutWhere", async () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "update"); - - const withoutWhere = proxy.withoutWhere(); - const result = await withoutWhere.run(); - - expect(result).toEqual({ changes: 1 }); - }); - - it("preserves proxy through chained methods", () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "update"); - - const chained = proxy.where({ id: 1 }); - expect(() => chained.run()).not.toThrow(); - }); - - it("blocks via .then() without WHERE", () => { - const builder = createMockBuilder(); - // oxlint-disable-next-line unicorn/no-thenable - builder.then = vi.fn(function () { - return Promise.resolve({ changes: 1 }); - }); - - const proxy = createMutationProxy(builder, [], "update"); - - expect(() => proxy.then(() => {})).toThrow("UPDATE without WHERE is blocked"); - }); - - it("allows via .then() with WHERE", async () => { - const builder = createMockBuilder(); - const whereResult = createMockBuilder(); - builder.where = vi.fn(() => whereResult); - // oxlint-disable-next-line unicorn/no-thenable - whereResult.then = vi.fn(function (resolve: (v: unknown) => void) { - resolve({ changes: 1 }); - return Promise.resolve({ changes: 1 }); - }); - - const proxy = createMutationProxy(builder, [], "update"); - const withWhere = proxy.where({ id: 1 }); - - const result = await new Promise((resolve) => { - withWhere.then(resolve); - }); - - expect(result).toEqual({ changes: 1 }); - }); - - const executionMethods = ["all", "get", "values", "execute"] as const; - - executionMethods.forEach((method) => { - it(`blocks .${method}() without WHERE on update`, () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "update"); - - expect(() => proxy[method]()).toThrow("UPDATE without WHERE is blocked"); - }); - - it(`allows .${method}() with WHERE on update`, async () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "update"); - const withWhere = proxy.where({ id: 1 }); - - await withWhere[method](); - }); - }); - - it("returning proxy preserves WHERE guard", async () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "update"); - const withWhere = proxy.where({ id: 1 }); - const returning = withWhere.returning(); - await expect(returning.run()).resolves.toEqual({ changes: 1 }); - }); - - it("returns non-function values as-is", () => { - const builder = { ...createMockBuilder(), someProperty: 42 }; - const proxy = createMutationProxy(builder, [], "delete"); - - expect(proxy.someProperty).toBe(42); - }); - - it("includes mutation type in error message", () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "delete"); - - expect(() => proxy.run()).toThrow("DELETE without WHERE is blocked"); - }); - - it("wraps the builder returned by .where() — not the original", async () => { - const builder = createMockBuilder(); - const whereResult = createMockBuilder(); - builder.where = vi.fn(() => whereResult); - - const proxy = createMutationProxy(builder, [], "update"); - const withWhere = proxy.where({ id: 1 }); - - await withWhere.run(); - expect(whereResult.run).toHaveBeenCalled(); - expect(builder.run).not.toHaveBeenCalled(); - }); - - it("wraps the builder returned by .withoutWhere() — not the original", async () => { - const builder = createMockBuilder(); - const withoutWhereResult = createMockBuilder(); - builder.withoutWhere = vi.fn(() => withoutWhereResult); - - const proxy = createMutationProxy(builder, [], "delete"); - const withoutWhere = proxy.withoutWhere(); - - await withoutWhere.run(); - expect(withoutWhereResult.run).toHaveBeenCalled(); - expect(builder.run).not.toHaveBeenCalled(); - }); - - it("original proxy remains blocked after calling .where() on a branch", () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "update"); - - proxy.where({ id: 1 }); - expect(() => proxy.run()).toThrow("UPDATE without WHERE is blocked"); - }); - - it("original proxy remains blocked after calling .withoutWhere() on a branch", () => { - const builder = createMockBuilder(); - const proxy = createMutationProxy(builder, [], "delete"); - - proxy.withoutWhere(); - expect(() => proxy.run()).toThrow("DELETE without WHERE is blocked"); - }); -}); diff --git a/packages/server/src/tools/createMutationProxy.ts b/packages/server/src/tools/createMutationProxy.ts deleted file mode 100644 index 310d94d..0000000 --- a/packages/server/src/tools/createMutationProxy.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { recordMutationWithCascades } from "./recordMutation"; -import { createQueryProxy, type ProxyConfig } from "./createQueryProxy"; - -export function createMutationProxy( - builder: Record, - warnings: string[], - mutationType: "update" | "delete", - tableName?: string, - tablesWritten?: Set, - cascadeGraph?: Map>, - initialState = { whereSet: false, withoutWhereSet: false }, -): unknown { - const config: ProxyConfig = { - onMethod: { - where: (target, args, proxyState, factory) => { - return factory(target.where(...args), { ...proxyState, whereSet: true }); - }, - withoutWhere: (target, _args, proxyState, factory) => { - warnings.push(`[EdgePod] Unfiltered ${mutationType} executed via .withoutWhere().`); - return factory(target.withoutWhere(), { ...proxyState, withoutWhereSet: true }); - }, - }, - onExecute: (target, prop, args, proxyState) => { - if (prop === "prepare") { - throw new Error(`[EdgePod] .prepare() is not supported for ${mutationType}s.`); - } - if (!proxyState.whereSet && !proxyState.withoutWhereSet) { - throw new Error( - `[EdgePod] ${mutationType.toUpperCase()} without WHERE is blocked. If intentional, chain .withoutWhere().`, - ); - } - if (tableName && tableName !== "unknown" && tablesWritten) { - recordMutationWithCascades(tableName, tablesWritten, cascadeGraph ?? new Map()); - } - return target[prop](...args); - }, - }; - - return createQueryProxy(builder, initialState, config); -} diff --git a/packages/server/src/tools/createQueryProxy.ts b/packages/server/src/tools/createQueryProxy.ts deleted file mode 100644 index 373f451..0000000 --- a/packages/server/src/tools/createQueryProxy.ts +++ /dev/null @@ -1,57 +0,0 @@ -const EXECUTION_METHODS = ["run", "all", "get", "values", "execute", "prepare"]; - -type ProxyFactory = (builder: any, state: Record) => unknown; - -export type ProxyConfig = { - onMethod: Record< - string, - (target: any, args: unknown[], state: Record, factory: ProxyFactory) => unknown - >; - onExecute: ( - target: any, - prop: string, - args: unknown[], - state: Record, - ) => unknown; -}; - -export function createQueryProxy( - builder: any, - state: Record, - config: ProxyConfig, -): unknown { - const factory: ProxyFactory = (b, s) => createQueryProxy(b, s, config); - - return new Proxy(builder, { - get(target: any, prop: string) { - // 1. Specific method intercepts (.limit, .where, .values, etc.) - const methodHandler = config.onMethod[prop]; - if (methodHandler) { - return function (...args: unknown[]) { - return methodHandler(target, args, state, factory); - }; - } - - // 2. Execution intercepts (.then, .run, .all, etc.) - if (prop === "then" || EXECUTION_METHODS.includes(prop)) { - return function (...args: unknown[]) { - return config.onExecute(target, prop, args, state); - }; - } - - // 3. Generic builder-returning method wrap - const value = target[prop]; - if (typeof value === "function") { - return function (...args: unknown[]) { - const result = value.apply(target, args); - if (result && typeof result === "object" && typeof result.then === "function") { - return factory(result, { ...state }); - } - return result; - }; - } - - return value; - }, - }); -} diff --git a/packages/server/src/tools/createSafetyProxy.ts b/packages/server/src/tools/createSafetyProxy.ts new file mode 100644 index 0000000..b7432d4 --- /dev/null +++ b/packages/server/src/tools/createSafetyProxy.ts @@ -0,0 +1,109 @@ +import { getTableName } from "drizzle-orm"; +import { createBuilderProxy } from "./createBuilderProxy"; +import { addListener } from "./tracking"; +import type { EdgePodSessionMap } from "../types"; + +const FORBIDDEN = ["run", "all", "get", "values", "execute"]; + +export type TrackContext = { + sessionId: string; + activeSessions: EdgePodSessionMap; + tablesRead: Set; + tablesWritten: Set; + rowIds: Map>; + cascadeGraph: Map>; + warnings: string[]; + pkMap: Map; +}; + +export function createSafetyProxy(rawDb: any, ctx: TrackContext): any { + return new Proxy(rawDb, { + get(target: any, prop: string) { + if (FORBIDDEN.includes(prop)) + throw new Error( + `[EdgePod] Raw SQL via 'ctx.db.${prop}()' is blocked. Use ctx.db.select()/ctx.db.update() or ctx.unsafeRawDb.${prop}().`, + ); + + if (prop === "insert") + return (table: unknown, ...rest: unknown[]) => + createBuilderProxy(target.insert(table, ...rest), ctx, { + type: "insert", + tableName: getTableName(table as any) ?? "unknown", + }); + + if (prop === "update") + return (table: unknown, ...rest: unknown[]) => + createBuilderProxy(target.update(table, ...rest), ctx, { + type: "update", + tableName: getTableName(table as any) ?? "unknown", + }); + + if (prop === "delete") + return (table: unknown, ...rest: unknown[]) => + createBuilderProxy(target.delete(table, ...rest), ctx, { + type: "delete", + tableName: getTableName(table as any) ?? "unknown", + }); + + if (prop === "select" || prop === "selectDistinct") + return (...args: unknown[]) => + createBuilderProxy(target[prop](...args), ctx, { type: "select" }); + + if (prop === "query") { + const q = target.query; + return q ? createQueryApiProxy(q, ctx) : undefined; + } + + const v = target[prop]; + return typeof v === "function" ? v.bind(target) : v; + }, + }); +} + +function createQueryApiProxy(queryObject: any, ctx: TrackContext) { + return new Proxy(queryObject, { + get(_target: any, tableProp: string) { + const tableApi = (_target as any)[tableProp]; + if (!tableApi) return undefined; + return new Proxy(tableApi, { + get(t: any, method: string) { + if (method === "findMany") + return (opts: Record = {}) => { + addListener(tableProp, ctx); + trackWithRelations(opts, ctx); + const limit = + typeof opts.limit === "number" && Number.isFinite(opts.limit) + ? Math.max(0, Math.min(opts.limit, 1000)) + : 1000; + if (typeof opts.limit === "number" && opts.limit > 1000) + ctx.warnings.push(`Query limit of ${opts.limit} overridden to 1000.`); + return t.findMany({ ...opts, limit }).then((res: unknown[]) => { + if (Array.isArray(res) && res.length === 1000) + ctx.warnings.push( + "Query returned exactly 1000 rows — there may be more results. Use .limit() and .offset() to paginate.", + ); + return res; + }); + }; + if (method === "findFirst") + return (opts: Record = {}) => { + addListener(tableProp, ctx); + trackWithRelations(opts, ctx); + return t.findFirst(opts); + }; + const v = t[method]; + return typeof v === "function" ? v.bind(t) : v; + }, + }); + }, + }); +} + +function trackWithRelations(opts: Record, ctx: TrackContext) { + const withOpt = opts.with as Record | undefined; + if (!withOpt) return; + for (const relation of Object.keys(withOpt)) { + addListener(relation, ctx); + trackWithRelations(withOpt[relation] as Record, ctx); + } +} diff --git a/packages/server/src/tools/createSelectProxy.test.ts b/packages/server/src/tools/createSelectProxy.test.ts deleted file mode 100644 index 7b9455f..0000000 --- a/packages/server/src/tools/createSelectProxy.test.ts +++ /dev/null @@ -1,216 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from "vitest"; -import { createSelectProxy } from "./createSelectProxy"; -import { hashTableName } from "./hashTableName"; -import type { EdgePodSessionMap } from "../types"; - -vi.mock("drizzle-orm", () => ({ - getTableName: vi.fn((t: { name?: string } | null) => t?.name ?? "unknown"), -})); - -function createMockBuilder( - options: { resultData?: Record[]; limit?: number } = {}, -) { - const { resultData = [], limit: initialLimit } = options; - let currentLimit = initialLimit; - - const builder: Record = { - limit: vi.fn(function (n: number) { - currentLimit = n; - const opts: { resultData: Record[]; limit: number } = { - resultData, - limit: n, - }; - return createMockBuilder(opts); - }), - where: vi.fn(function () { - return builder; - }), - from: vi.fn(function () { - return builder; - }), - leftJoin: vi.fn(function (_table: unknown) { - const opts: { resultData: Record[]; limit?: number } = { resultData }; - if (currentLimit !== undefined) opts.limit = currentLimit; - return createMockBuilder(opts); - }), - innerJoin: vi.fn(function (_table: unknown) { - const opts: { resultData: Record[]; limit?: number } = { resultData }; - if (currentLimit !== undefined) opts.limit = currentLimit; - return createMockBuilder(opts); - }), - rightJoin: vi.fn(function () { - return builder; - }), - fullJoin: vi.fn(function () { - return builder; - }), - // oxlint-disable-next-line unicorn/no-thenable - then: vi.fn(function (resolve: (v: unknown) => void, reject: (e: unknown) => void) { - const finalLimit = currentLimit ?? 1000; - return Promise.resolve(resultData.slice(0, finalLimit)).then(resolve, reject); - }), - }; - - return builder; -} - -function createMockJoinTable() { - return { name: "joined_table" }; -} - -describe("createSelectProxy", () => { - let tablesRead: Set; - let warnings: string[]; - let activeSessions: EdgePodSessionMap; - const sessionId = "test-session"; - - beforeEach(() => { - tablesRead = new Set(); - warnings = []; - activeSessions = new Map(); - activeSessions.set(sessionId, { - socket: {} as WebSocket, - listeningToTables: new Set(), - }); - }); - - it("auto-applies max limit when none set", async () => { - const builder = createMockBuilder({ resultData: Array(2000).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const result = await proxy; - - expect(result).toHaveLength(1000); - }); - - it("respects user-set limit under max", async () => { - const builder = createMockBuilder({ resultData: Array(100).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const withLimit = proxy.limit(50); - const result = await withLimit; - - expect(result).toHaveLength(50); - }); - - it("caps limit at max", async () => { - const builder = createMockBuilder({ resultData: Array(5000).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const withLimit = proxy.limit(5000); - const result = await withLimit; - - expect(result).toHaveLength(1000); - }); - - it("adds warning when user limit exceeds max", async () => { - const builder = createMockBuilder({ resultData: [] }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - await proxy.limit(5000); - - expect(warnings).toHaveLength(1); - expect(warnings[0]).toContain("5000"); - expect(warnings[0]).toContain("1000"); - }); - - it("tracks table reads on join methods", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.leftJoin(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("tracks table reads on innerJoin", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.innerJoin(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("adds warning when result hits max limit", async () => { - const builder = createMockBuilder({ resultData: Array(1000).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - await proxy; - - expect(warnings).toHaveLength(1); - expect(warnings[0]).toContain("1000 rows"); - expect(warnings[0]).toContain("paginate"); - }); - - it("does not add warning when result is under limit", async () => { - const builder = createMockBuilder({ resultData: Array(100).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - await proxy; - - expect(warnings).toHaveLength(0); - }); - - it("tracks table reads on rightJoin", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.rightJoin(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("tracks table reads on fullJoin", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.fullJoin(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("tracks table reads on from", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.from(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("registers listening tables on session for joins", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.leftJoin(joinTable, {}); - - const session = activeSessions.get(sessionId); - expect(session?.listeningToTables.has(hashTableName("joined_table"))).toBe(true); - }); - - it("preserves proxy through chained method calls", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const withWhere = proxy.where({ id: 1 }); - expect(withWhere).toBeDefined(); - expect(typeof withWhere.then).toBe("function"); - }); - - it("original proxy still applies max limit after .limit() on a branch", async () => { - const builder = createMockBuilder({ resultData: Array(2000).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - proxy.limit(50); - const result = await proxy; - - expect(result).toHaveLength(1000); - }); -}); diff --git a/packages/server/src/tools/createSelectProxy.ts b/packages/server/src/tools/createSelectProxy.ts deleted file mode 100644 index fc3f24d..0000000 --- a/packages/server/src/tools/createSelectProxy.ts +++ /dev/null @@ -1,82 +0,0 @@ -import { getTableName } from "drizzle-orm"; -import { EdgePodSessionMap } from "../types"; -import { checkResultWarnings } from "./checkResultWarnings"; -import { createQueryProxy, type ProxyConfig } from "./createQueryProxy"; -import { hashTableName } from "./hashTableName"; - -function trackTable( - table: unknown, - tablesRead: Set, - activeSessions: EdgePodSessionMap, - sessionId: string, -) { - const tableName = getTableName(table as any) ?? "unknown"; - if (tableName === "unknown") return; - const session = activeSessions.get(sessionId); - if (session) session.listeningToTables.add(hashTableName(tableName)); - tablesRead.add(tableName); -} - -export function createSelectProxy( - builder: Record, - sessionId: string, - activeSessions: EdgePodSessionMap, - tablesRead: Set, - warnings: string[], - maxLimit: number, -): unknown { - const config: ProxyConfig = { - onMethod: { - limit: (target, args, state, factory) => { - const n = args[0] as number; - if (n > maxLimit) { - warnings.push(`Query limit of ${n} overridden to ${maxLimit}.`); - } - const clamped = Math.max(0, Math.min(n, maxLimit)); - return factory(target.limit(clamped), { ...state, limitSet: true }); - }, - from: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.from(...args), { ...state }); - }, - leftJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.leftJoin(...args), { ...state }); - }, - innerJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.innerJoin(...args), { ...state }); - }, - rightJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.rightJoin(...args), { ...state }); - }, - fullJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.fullJoin(...args), { ...state }); - }, - crossJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.crossJoin(...args), { ...state }); - }, - }, - onExecute: (target, prop, args, state) => { - const finalBuilder = state.limitSet ? target : target.limit(maxLimit); - if (prop === "then") { - const [resolve, reject] = args; - return finalBuilder.then((result: unknown[]) => { - checkResultWarnings(result, warnings, maxLimit); - return (resolve as (v: unknown) => void)(result); - }, reject); - } - if (prop === "prepare") { - return finalBuilder[prop](...args); - } - const result = finalBuilder[prop](...args); - checkResultWarnings(result, warnings, maxLimit); - return result; - }, - }; - - return createQueryProxy(builder, { limitSet: false }, config); -} diff --git a/packages/server/src/tools/createTrackedDb.test.ts b/packages/server/src/tools/createTrackedDb.test.ts deleted file mode 100644 index a573028..0000000 --- a/packages/server/src/tools/createTrackedDb.test.ts +++ /dev/null @@ -1,295 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from "vitest"; -import { createTrackedDb } from "./createTrackedDb"; -import { hashTableName } from "./hashTableName"; -import type { EdgePodSessionMap, RawDrizzleDb } from "../types"; - -vi.mock("drizzle-orm", () => ({ - getTableName: vi.fn((t: { name?: string } | null) => t?.name ?? "unknown"), -})); - -function createUpdateBuilder() { - return { - set: vi.fn(function () { - return { - where: vi.fn(function () { - return { run: vi.fn(() => Promise.resolve({ changes: 1 })) }; - }), - withoutWhere: vi.fn(function () { - return { run: vi.fn(() => Promise.resolve({ changes: 1 })) }; - }), - run: vi.fn(() => Promise.resolve({ changes: 1 })), - }; - }), - where: vi.fn(function () { - return { run: vi.fn(() => Promise.resolve({ changes: 1 })) }; - }), - withoutWhere: vi.fn(function () { - return { run: vi.fn(() => Promise.resolve({ changes: 1 })) }; - }), - run: vi.fn(() => Promise.resolve({ changes: 1 })), - }; -} - -function createDeleteBuilder() { - return { - where: vi.fn(function () { - return { run: vi.fn(() => Promise.resolve({ changes: 1 })) }; - }), - withoutWhere: vi.fn(function () { - return { run: vi.fn(() => Promise.resolve({ changes: 1 })) }; - }), - run: vi.fn(() => Promise.resolve({ changes: 1 })), - }; -} - -function createInsertBuilder() { - const builder: Record = { - values: vi.fn(function () { - return builder; - }), - // oxlint-disable-next-line unicorn/no-thenable - then: vi.fn(function (resolve: (v: unknown) => void, reject: (e: unknown) => void) { - return Promise.resolve({ inserted: true }).then(resolve, reject); - }), - }; - return builder; -} - -function createMockDb() { - const db: Record = { - select: vi.fn(() => createSelectBuilder()), - selectDistinct: vi.fn(() => createSelectBuilder()), - insert: vi.fn((_table: unknown) => createInsertBuilder()), - update: vi.fn((_table: unknown) => createUpdateBuilder()), - delete: vi.fn((_table: unknown) => createDeleteBuilder()), - query: { - users: createQueryTableApi("users"), - posts: createQueryTableApi("posts"), - }, - run: vi.fn(() => Promise.resolve()), - all: vi.fn(() => Promise.resolve([])), - get: vi.fn(() => Promise.resolve(null)), - values: vi.fn(() => Promise.resolve([])), - execute: vi.fn(() => Promise.resolve()), - }; - - return db; -} - -function createSelectBuilder() { - const builder: Record = { - limit: vi.fn(function () { - return builder; - }), - where: vi.fn(function () { - return builder; - }), - from: vi.fn(function (_table: unknown) { - return builder; - }), - // oxlint-disable-next-line unicorn/no-thenable - then: vi.fn(function (resolve: (v: unknown) => void) { - resolve([{ id: 1 }]); - return Promise.resolve([{ id: 1 }]); - }), - }; - return builder; -} - -function createQueryTableApi(tableName: string) { - return { - findMany: vi.fn(function (opts: Record = {}) { - const limit = (opts.limit as number) ?? 1000; - return Promise.resolve(Array(Math.min(limit, 10)).fill({ id: 1, table: tableName })); - }), - findFirst: vi.fn(function () { - return Promise.resolve({ id: 1, table: tableName }); - }), - }; -} - -describe("createTrackedDb", () => { - let tablesRead: Set; - let tablesWritten: Set; - let warnings: string[]; - let activeSessions: EdgePodSessionMap; - const sessionId = "test-session"; - - beforeEach(() => { - tablesRead = new Set(); - tablesWritten = new Set(); - warnings = []; - activeSessions = new Map(); - activeSessions.set(sessionId, { - socket: {} as WebSocket, - listeningToTables: new Set(), - }); - }); - - function createProxy(cascadeGraph?: Map>) { - const mockDb = createMockDb(); - const proxy = createTrackedDb( - mockDb as unknown as RawDrizzleDb, - sessionId, - activeSessions, - tablesRead, - tablesWritten, - cascadeGraph ?? new Map(), - warnings, - ); - return { proxy, mockDb }; - } - - it("blocks raw SQL methods", () => { - const { proxy } = createProxy(); - - expect(() => (proxy as any).run()).toThrow("ctx.db.run"); - expect(() => (proxy as any).all()).toThrow("ctx.db.all"); - expect(() => (proxy as any).get()).toThrow("ctx.db.get"); - expect(() => (proxy as any).values()).toThrow("ctx.db.values"); - expect(() => (proxy as any).execute()).toThrow("ctx.db.execute"); - }); - - it("tracks insert as table write", async () => { - const { proxy } = createProxy(); - const usersTable = { name: "users" }; - - await (proxy as any).insert(usersTable).values({ name: "test" }); - - expect(tablesWritten.has("users")).toBe(true); - }); - - it("tracks update as table write", async () => { - const { proxy } = createProxy(); - const usersTable = { name: "users" }; - - await (proxy as any).update(usersTable).set({ name: "updated" }).where({ id: 1 }).run(); - - expect(tablesWritten.has("users")).toBe(true); - }); - - it("tracks delete as table write", async () => { - const { proxy } = createProxy(); - const usersTable = { name: "users" }; - - await (proxy as any).delete(usersTable).where({ id: 1 }).run(); - - expect(tablesWritten.has("users")).toBe(true); - }); - - it("propagates cascades on delete", async () => { - const cascadeGraph = new Map>(); - cascadeGraph.set("users", new Set(["posts", "comments"])); - - const { proxy } = createProxy(cascadeGraph); - const usersTable = { name: "users" }; - - await (proxy as any).delete(usersTable).where({ id: 1 }).run(); - - expect(tablesWritten.has("users")).toBe(true); - expect(tablesWritten.has("posts")).toBe(true); - expect(tablesWritten.has("comments")).toBe(true); - }); - - it("does not propagate cascades on insert", async () => { - const cascadeGraph = new Map>(); - cascadeGraph.set("users", new Set(["posts"])); - - const { proxy } = createProxy(cascadeGraph); - const usersTable = { name: "users" }; - - await (proxy as any).insert(usersTable).values({ name: "test" }); - - expect(tablesWritten.has("users")).toBe(true); - expect(tablesWritten.has("posts")).toBe(false); - }); - - it("does not propagate cascades on update", async () => { - const cascadeGraph = new Map>(); - cascadeGraph.set("users", new Set(["posts"])); - - const { proxy } = createProxy(cascadeGraph); - const usersTable = { name: "users" }; - - await (proxy as any).update(usersTable).set({ name: "updated" }).where({ id: 1 }).run(); - - expect(tablesWritten.has("users")).toBe(true); - expect(tablesWritten.has("posts")).toBe(false); - }); - - it("tracks select as table read via query.findMany", async () => { - const { proxy } = createProxy(); - - await (proxy as any).query.users.findMany(); - - expect(tablesRead.has("users")).toBe(true); - }); - - it("tracks select as table read via query.findFirst", async () => { - const { proxy } = createProxy(); - - await (proxy as any).query.users.findFirst(); - - expect(tablesRead.has("users")).toBe(true); - }); - - it("registers listening tables on session via query.findMany", async () => { - const { proxy } = createProxy(); - - await (proxy as any).query.users.findMany(); - - const session = activeSessions.get(sessionId); - expect(session?.listeningToTables.has(hashTableName("users"))).toBe(true); - }); - - it("applies limit to query.findMany", async () => { - const { proxy } = createProxy(); - - const result = await (proxy as any).query.users.findMany({ limit: 5 }); - - expect(Array.isArray(result)).toBe(true); - }); - - it("caps query.findMany limit at max", async () => { - const { proxy, mockDb } = createProxy(); - - await (proxy as any).query.users.findMany({ limit: 5000 }); - - expect(warnings).toHaveLength(1); - expect(warnings[0]).toContain("5000"); - expect(warnings[0]).toContain("1000"); - expect(mockDb.query.users.findMany).toHaveBeenCalledWith( - expect.objectContaining({ limit: 1000 }), - ); - }); - - it("blocks bulk insert exceeding max limit", () => { - const { proxy } = createProxy(); - const usersTable = { name: "users" }; - const rows = Array(1001).fill({ name: "test" }); - - expect(() => (proxy as any).insert(usersTable).values(rows)).toThrow("Bulk insert blocked"); - }); - - it("allows bulk insert at or under max limit", () => { - const { proxy } = createProxy(); - const usersTable = { name: "users" }; - const rows = Array(1000).fill({ name: "test" }); - - expect(() => (proxy as any).insert(usersTable).values(rows)).not.toThrow(); - }); - - it("delegates selectDistinct to select proxy", () => { - const { proxy } = createProxy(); - - const builder = (proxy as any).selectDistinct(); - expect(typeof builder.then).toBe("function"); - }); - - it("binds non-tracked methods to target", () => { - const { proxy } = createProxy(); - - const existingMethod = (proxy as any).select; - expect(typeof existingMethod).toBe("function"); - }); -}); diff --git a/packages/server/src/tools/createTrackedDb.ts b/packages/server/src/tools/createTrackedDb.ts deleted file mode 100644 index 15e77ea..0000000 --- a/packages/server/src/tools/createTrackedDb.ts +++ /dev/null @@ -1,197 +0,0 @@ -import { getTableName } from "drizzle-orm"; -import { RawDrizzleDb, EdgePodSessionMap } from "../types"; -import { checkResultWarnings } from "./checkResultWarnings"; -import { createSelectProxy } from "./createSelectProxy"; -import { createMutationProxy } from "./createMutationProxy"; -import { hashTableName } from "./hashTableName"; -import { recordMutationWithCascades } from "./recordMutation"; -import { createQueryProxy, type ProxyConfig } from "./createQueryProxy"; - -const FORBIDDEN_RAW_METHODS = ["run", "all", "get", "values", "execute"]; -const MAX_LIMIT = 1000; - -function createInsertProxy( - builder: Record, - maxLimit: number, - tableName: string, - tablesWritten: Set, -): unknown { - const config: ProxyConfig = { - onMethod: { - values: (target, args, _state, factory) => { - const rows = args[0] as unknown[]; - if (Array.isArray(rows) && rows.length > maxLimit) { - throw new Error( - `[EdgePod] Bulk insert blocked: ${rows.length} rows > ${maxLimit}. Split into smaller batches.`, - ); - } - return factory(target.values(rows), {}); - }, - }, - onExecute: (target, prop, args) => { - if (prop === "prepare") { - throw new Error("[EdgePod] .prepare() is not supported for inserts."); - } - if (tableName !== "unknown") { - recordMutationWithCascades(tableName, tablesWritten, new Map()); - } - return target[prop](...args); - }, - }; - - return createQueryProxy(builder, {}, config); -} - -function createUpdateBuilderProxy( - builder: Record, - warnings: string[], - tableName: string, - tablesWritten: Set, -): unknown { - const config: ProxyConfig = { - onMethod: { - set: (target, args, _state, _factory) => { - const base = target.set(...args); - return createMutationProxy(base, warnings, "update", tableName, tablesWritten); - }, - }, - onExecute: (target, prop, args) => target[prop](...args), - }; - - return createQueryProxy(builder, {}, config); -} - -/** - * Wraps a Drizzle instance in a Proxy to automatically track which tables - * are read (for subscriptions) and written (for invalidations). - */ -export function createTrackedDb>( - realDb: RawDrizzleDb, - sessionId: string, - activeSessions: EdgePodSessionMap, - tablesRead: Set, - tablesWritten: Set, - cascadeGraph: Map>, - warnings: string[], -): unknown { - return new Proxy(realDb as any, { - get(target: any, prop: string) { - if (FORBIDDEN_RAW_METHODS.includes(prop)) { - throw new Error( - `[EdgePod] Raw SQL via 'ctx.db.${prop}()' is blocked. Use ctx.db.select()/ctx.db.update(). ` + - `For raw SQL, use ctx.unsafeRawDb.${prop}() and call ctx.invalidate() manually.`, - ); - } - - if (prop === "insert") { - return function (table: unknown, ...restArgs: unknown[]) { - const tableName = getTableName(table as any) ?? "unknown"; - const builder = target[prop].apply(target, [table, ...restArgs]); - return createInsertProxy(builder, MAX_LIMIT, tableName, tablesWritten); - }; - } - - if (prop === "update") { - return function (table: unknown, ...restArgs: unknown[]) { - const tableName = getTableName(table as any) ?? "unknown"; - const builder = target[prop].apply(target, [table, ...restArgs]); - return createUpdateBuilderProxy(builder, warnings, tableName, tablesWritten); - }; - } - - if (prop === "delete") { - return function (table: unknown, ...restArgs: unknown[]) { - const tableName = getTableName(table as any) ?? "unknown"; - const builder = target[prop].apply(target, [table, ...restArgs]); - return createMutationProxy( - builder, - warnings, - "delete", - tableName, - tablesWritten, - cascadeGraph, - ); - }; - } - - if (prop === "query") { - const queryObject = target.query; - if (!queryObject) return undefined; - return new Proxy(queryObject, { - get(queryTarget: any, tableProp: string) { - const tableApi = queryTarget[tableProp]; - if (!tableApi) return undefined; - const session = activeSessions.get(sessionId); - if (session) session.listeningToTables.add(hashTableName(tableProp)); - tablesRead.add(tableProp); - return new Proxy(tableApi, { - get(tableTarget: any, method: string) { - if (method === "findMany") { - return function (opts: Record = {}) { - const limit = - typeof opts.limit === "number" && Number.isFinite(opts.limit) - ? Math.max(0, Math.min(opts.limit, MAX_LIMIT)) - : MAX_LIMIT; - if (typeof opts.limit === "number" && opts.limit > MAX_LIMIT) { - warnings.push(`Query limit of ${opts.limit} overridden to ${MAX_LIMIT}.`); - } - trackWithRelations(opts, tablesRead, activeSessions, sessionId); - return tableTarget.findMany({ ...opts, limit }).then((result: unknown[]) => { - checkResultWarnings(result, warnings, MAX_LIMIT); - return result; - }); - }; - } - if (method === "findFirst") { - return function (opts: Record = {}) { - trackWithRelations(opts, tablesRead, activeSessions, sessionId); - return tableTarget.findFirst(opts); - }; - } - const value = tableTarget[method]; - return typeof value === "function" ? value.bind(tableTarget) : value; - }, - }); - }, - }); - } - - if (prop === "select" || prop === "selectDistinct") { - return function (...args: unknown[]) { - return createSelectProxy( - target[prop].apply(target, args), - sessionId, - activeSessions, - tablesRead, - warnings, - MAX_LIMIT, - ); - }; - } - - const value = target[prop]; - return typeof value === "function" ? value.bind(target) : value; - }, - }); -} - -function trackWithRelations( - opts: Record, - tablesRead: Set, - activeSessions: EdgePodSessionMap, - sessionId: string, -) { - const withOpt = opts.with as Record | undefined; - if (!withOpt) return; - for (const relation of Object.keys(withOpt)) { - const session = activeSessions.get(sessionId); - if (session) session.listeningToTables.add(hashTableName(relation)); - tablesRead.add(relation); - trackWithRelations( - withOpt[relation] as Record, - tablesRead, - activeSessions, - sessionId, - ); - } -} diff --git a/packages/server/src/tools/createTrackedRawDb.ts b/packages/server/src/tools/createTrackedRawDb.ts new file mode 100644 index 0000000..af3473c --- /dev/null +++ b/packages/server/src/tools/createTrackedRawDb.ts @@ -0,0 +1,37 @@ +import { trackExec } from "./tracking"; +import type { TrackContext } from "./createSafetyProxy"; + +const RAW_EXEC = ["run", "all", "get", "values", "execute"]; + +export function createTrackedRawDb(rawDb: any, ctx: TrackContext): any { + return new Proxy(rawDb, { + get(target: any, prop: string) { + if (RAW_EXEC.includes(prop)) + return (...args: unknown[]) => { + let sql: string; + let params: unknown[]; + + const firstArg = args[0]; + if ( + firstArg && + typeof firstArg === "object" && + "queryChunks" in firstArg && + target.dialect?.sqlToQuery + ) { + const query = target.dialect.sqlToQuery(firstArg); + sql = query.sql; + params = query.params; + } else { + sql = String(firstArg); + params = args.slice(1).filter((p) => p !== undefined); + } + + trackExec({ toSQL: () => ({ sql, params }) }, ctx); + const method = target[prop] as Function; + return method.call(target, ...args); + }; + const v = target[prop]; + return typeof v === "function" ? v.bind(target) : v; + }, + }); +} diff --git a/packages/server/src/tools/parseSqlTracking.test.ts b/packages/server/src/tools/parseSqlTracking.test.ts new file mode 100644 index 0000000..6623de4 --- /dev/null +++ b/packages/server/src/tools/parseSqlTracking.test.ts @@ -0,0 +1,298 @@ +import { describe, it, expect } from "vitest"; +import { parseSqlTracking } from "./parseSqlTracking"; + +describe("parseSqlTracking — query type", () => { + it("detects SELECT", () => { + const r = parseSqlTracking('select "id" from "users"', []); + expect(r.queryType).toBe("select"); + }); + + it("detects INSERT", () => { + const r = parseSqlTracking('insert into "users" ("name") values (?)', ["test"]); + expect(r.queryType).toBe("insert"); + }); + + it("detects UPDATE", () => { + const r = parseSqlTracking('update "users" set "name" = ? where "id" = ?', ["new", 1]); + expect(r.queryType).toBe("update"); + }); + + it("detects DELETE", () => { + const r = parseSqlTracking('delete from "users" where "id" = ?', [1]); + expect(r.queryType).toBe("delete"); + }); + + it("returns 'unknown' for garbage SQL", () => { + const r = parseSqlTracking("not sql at all", []); + expect(r.queryType).toBe("unknown"); + expect(r.tablesRead).toEqual([]); + expect(r.tablesWritten).toEqual([]); + expect(r.whereIds).toEqual([]); + }); +}); + +describe("parseSqlTracking — table extraction", () => { + it("extracts table from SELECT", () => { + const r = parseSqlTracking('select * from "users"', []); + expect(r.tablesRead).toEqual(["users"]); + }); + + it("extracts multiple tables from JOIN", () => { + const r = parseSqlTracking( + 'select * from "users" left join "posts" on "posts"."user_id" = "users"."id"', + [], + ); + expect(r.tablesRead).toEqual(["users", "posts"]); + }); + + it("extracts tables from comma-joined FROM", () => { + const r = parseSqlTracking('select * from "users", "posts"', []); + expect(r.tablesRead).toEqual(["users", "posts"]); + }); + + it("extracts table from INSERT", () => { + const r = parseSqlTracking('insert into "users" ("name") values (?)', ["test"]); + expect(r.tablesWritten).toEqual(["users"]); + }); + + it("extracts table from UPDATE", () => { + const r = parseSqlTracking('update "users" set "name" = ?', ["test"]); + expect(r.tablesWritten).toEqual(["users"]); + }); + + it("extracts table from DELETE", () => { + const r = parseSqlTracking('delete from "users"', []); + expect(r.tablesWritten).toEqual(["users"]); + }); + + it("detects subquery tables in WHERE IN (SELECT)", () => { + const r = parseSqlTracking( + 'select * from "users" where "id" in (select "user_id" from "banned")', + [], + ); + expect(r.tablesRead).toEqual(["users", "banned"]); + }); + + it("detects subquery tables in WHERE EXISTS", () => { + const r = parseSqlTracking( + 'select * from "users" where exists (select 1 from "orders" where "orders"."user_id" = "users"."id")', + [], + ); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("orders"); + expect(r.tablesRead).toHaveLength(2); + }); + + it("detects subquery tables in INSERT...SELECT", () => { + const r = parseSqlTracking('insert into "logs" select * from "old_logs"', []); + expect(r.tablesWritten).toEqual(["logs"]); + expect(r.tablesRead).toEqual(["old_logs"]); + }); + + it("detects subquery tables in UPDATE with scalar subquery", () => { + const r = parseSqlTracking( + 'update "users" set "name" = (select "name" from "profiles" where "profiles"."id" = "users"."id")', + [], + ); + expect(r.tablesWritten).toEqual(["users"]); + expect(r.tablesRead).toEqual(["profiles"]); + }); + + it("detects subquery tables in DELETE with IN subquery", () => { + const r = parseSqlTracking( + 'delete from "users" where "id" in (select "id" from "banned_users")', + [], + ); + expect(r.tablesWritten).toEqual(["users"]); + expect(r.tablesRead).toEqual(["banned_users"]); + }); + + it("detects multiple subquery tables with JOIN in outer query", () => { + const r = parseSqlTracking( + 'select * from "users" join "posts" on "users"."id" = "posts"."user_id" where "posts"."id" in (select "post_id" from "comments")', + [], + ); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("posts"); + expect(r.tablesRead).toContain("comments"); + expect(r.tablesRead).toHaveLength(3); + }); + + it("does not track subquery tables as reads for write ops when same as main table", () => { + const r = parseSqlTracking('insert into "logs" select * from "logs"', []); + expect(r.tablesWritten).toEqual(["logs"]); + expect(r.tablesRead).toEqual([]); + }); +}); + +describe("parseSqlTracking — CTEs, compounds, subqueries", () => { + it("excludes CTE aliases from tablesRead", () => { + const r = parseSqlTracking("WITH cte AS (SELECT * FROM users) SELECT * FROM cte", []); + expect(r.tablesRead).toEqual(["users"]); + }); + + it("excludes CTE aliases from JOIN in main query", () => { + const r = parseSqlTracking( + "WITH a AS (SELECT * FROM users) SELECT * FROM a JOIN b ON a.id = b.id", + [], + ); + expect(r.tablesRead).toEqual(["users", "b"]); + }); + + it("collects tables from scalar subquery in SELECT list (no outer FROM)", () => { + const r = parseSqlTracking("SELECT (SELECT name FROM profiles LIMIT 1) AS n", []); + expect(r.tablesRead).toEqual(["profiles"]); + }); + + it("collects tables from scalar subquery in SELECT list (with outer FROM)", () => { + const r = parseSqlTracking("SELECT (SELECT count FROM stats) AS c FROM users", []); + expect(r.tablesRead).toEqual(["users", "stats"]); + }); + + it("collects all table references from UNION", () => { + const r = parseSqlTracking("SELECT * FROM users UNION SELECT * FROM admins", []); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("admins"); + expect(r.tablesRead).toHaveLength(2); + }); + + it("collects tables from subquery in HAVING", () => { + const r = parseSqlTracking( + "SELECT user_id, count(*) FROM orders GROUP BY user_id HAVING count(*) > (SELECT avg(count) FROM stats)", + [], + ); + expect(r.tablesRead).toContain("orders"); + expect(r.tablesRead).toContain("stats"); + }); + + it("collects tables from subquery in ORDER BY", () => { + const r = parseSqlTracking( + "SELECT * FROM users ORDER BY (SELECT score FROM leaderboard LIMIT 1)", + [], + ); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("leaderboard"); + }); + + it("collects tables from derived table in JOIN", () => { + const r = parseSqlTracking( + "SELECT * FROM users JOIN (SELECT * FROM posts) AS p ON users.id = p.user_id", + [], + ); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("posts"); + }); + + it("collects tables from deeply nested subqueries", () => { + const r = parseSqlTracking( + "SELECT * FROM users WHERE id IN (SELECT id FROM (SELECT * FROM banned))", + [], + ); + expect(r.tablesRead).toEqual(["users", "banned"]); + }); + + it("collects subquery tables for INSERT...SELECT with compound", () => { + const r = parseSqlTracking( + "INSERT INTO logs SELECT * FROM old_logs UNION SELECT * FROM archived_logs", + [], + ); + expect(r.tablesWritten).toEqual(["logs"]); + expect(r.tablesRead).toContain("old_logs"); + expect(r.tablesRead).toContain("archived_logs"); + }); + + it("collects CTE tables for write ops with subquery", () => { + const r = parseSqlTracking( + "WITH cte AS (SELECT * FROM source) INSERT INTO target SELECT * FROM cte", + [], + ); + expect(r.tablesWritten).toEqual(["target"]); + expect(r.tablesRead).toEqual(["source"]); + }); +}); + +describe("parseSqlTracking — WHERE ID extraction", () => { + it("extracts simple WHERE id = ?", () => { + const r = parseSqlTracking('select * from "users" where "users"."id" = ?', [42]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].tableHint).toBe("users"); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[0].paramIndices).toEqual([0]); + }); + + it("extracts WHERE id = ? from UPDATE", () => { + const r = parseSqlTracking('update "users" set "name" = ? where "users"."id" = ?', ["new", 1]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].tableHint).toBe("users"); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[0].paramIndices).toEqual([1]); + }); + + it("extracts WHERE id IN (?, ?)", () => { + const r = parseSqlTracking('select * from "users" where "users"."id" in (?, ?)', [1, 2]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].tableHint).toBe("users"); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[0].paramIndices).toEqual([0, 1]); + }); + + it("handles unqualified column in WHERE", () => { + const r = parseSqlTracking('select * from "users" where "id" = ?', [42]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].tableHint).toBe(""); + expect(r.whereIds[0].column).toBe("id"); + }); + + it("returns empty whereIds for queries without WHERE", () => { + const r = parseSqlTracking('select * from "users"', []); + expect(r.whereIds).toEqual([]); + }); + + it("returns empty whereIds for INSERT", () => { + const r = parseSqlTracking('insert into "users" ("name") values (?)', ["test"]); + expect(r.whereIds).toEqual([]); + }); +}); + +describe("parseSqlTracking — raw SQL (uppercase)", () => { + it("parses uppercase SELECT", () => { + const r = parseSqlTracking("SELECT * FROM users WHERE users.id = ?", [1]); + expect(r.queryType).toBe("select"); + expect(r.tablesRead).toEqual(["users"]); + expect(r.whereIds).toHaveLength(1); + }); + + it("parses uppercase UPDATE", () => { + const r = parseSqlTracking("UPDATE users SET name = ? WHERE users.id = ?", ["new", 1]); + expect(r.queryType).toBe("update"); + expect(r.tablesWritten).toEqual(["users"]); + expect(r.whereIds).toHaveLength(1); + }); +}); + +describe("parseSqlTracking — edge cases", () => { + it("handles multiple AND conditions in WHERE", () => { + const r = parseSqlTracking( + 'select * from "users" where "users"."id" = ? and "users"."name" = ?', + [1, "test"], + ); + expect(r.whereIds).toHaveLength(2); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[1].column).toBe("name"); + }); + + it("ignores params in SET clause (not WHERE)", () => { + const r = parseSqlTracking('update "users" set "name" = ? where "users"."id" = ?', ["new", 1]); + // Only the WHERE param should be extracted, not the SET param + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].column).toBe("id"); + }); + + it("returns empty for parse errors", () => { + const r = parseSqlTracking("not valid sqlite", []); + expect(r.queryType).toBe("unknown"); + expect(r.tablesRead).toEqual([]); + expect(r.tablesWritten).toEqual([]); + expect(r.whereIds).toEqual([]); + }); +}); diff --git a/packages/server/src/tools/parseSqlTracking.ts b/packages/server/src/tools/parseSqlTracking.ts new file mode 100644 index 0000000..6f2ceb8 --- /dev/null +++ b/packages/server/src/tools/parseSqlTracking.ts @@ -0,0 +1,165 @@ +import { parseStmt, traverse } from "sqlite3-parser"; + +export type ParsedQuery = { + queryType: "select" | "insert" | "update" | "delete" | "unknown"; + tablesRead: string[]; + tablesWritten: string[]; + whereIds: Array<{ tableHint: string; column: string; paramIndices: number[] }>; +}; + +function getTableName(node: any): string | null { + if (!node) return null; + if (node.tblName?.objName?.text) return node.tblName.objName.text; + if (node.objName?.text) return node.objName.text; + return null; +} + +export function parseSqlTracking(sql: string, params: unknown[]): ParsedQuery { + const result = parseStmt(sql); + if (result.status === "error") { + return { queryType: "unknown", tablesRead: [], tablesWritten: [], whereIds: [] }; + } + + const root = result.root as any; + let queryType: ParsedQuery["queryType"] = "unknown"; + let mainTable: string | null = null; + + if (root.type === "SelectStmt") { + queryType = "select"; + } else if (root.type === "InsertStmt") { + queryType = "insert"; + mainTable = getTableName(root); + } else if (root.type === "UpdateStmt") { + queryType = "update"; + mainTable = getTableName(root); + } else if (root.type === "DeleteStmt") { + queryType = "delete"; + mainTable = getTableName(root); + } + + // Collect all CTE aliases so we can exclude them from the table list + const cteAliases = new Set(); + traverse(root, { + enter(node: any) { + if (node.type === "CommonTableExpr" && node.tblName?.text) { + cteAliases.add(node.tblName.text); + } + }, + }); + + // Collect ALL table references from every SelectFrom node in the AST. + // SelectFrom appears in top-level SELECTs (SelectStmt → body → select), + // subqueries (Select → select), UNION compounds (CompoundSelect → select), + // and INSERT...SELECT (InsertStmt → ... → Select → select → SelectFrom). + const allTables = new Set(); + traverse(root, { + enter(node: any) { + if (node.type !== "SelectFrom" || !node.from) return; + const items = node.from.select + ? Array.isArray(node.from.select) + ? node.from.select + : [node.from.select] + : []; + for (const item of items) { + const name = getTableName(item); + if (name && !cteAliases.has(name)) allTables.add(name); + } + if (node.from.joins) { + for (const join of node.from.joins) { + const name = getTableName(join.table); + if (name && !cteAliases.has(name)) allTables.add(name); + } + } + }, + }); + + const tablesRead: string[] = []; + const tablesWritten: string[] = []; + const whereIds: Array<{ tableHint: string; column: string; paramIndices: number[] }> = []; + + if (queryType === "select") { + tablesRead.push(...allTables); + } else if (queryType === "insert" || queryType === "update" || queryType === "delete") { + if (mainTable) tablesWritten.push(mainTable); + for (const t of allTables) { + if (t !== mainTable) tablesRead.push(t); + } + } + + // Map ? positions to param indices by collecting all VariableExpr order + const varExprs: Array<{ offset: number; node: any }> = []; + traverse(root, { + enter(node: any) { + if (node.type === "VariableExpr" && node.name === "?") { + varExprs.push({ offset: node.span?.offset ?? -1, node }); + } + }, + }); + varExprs.sort((a, b) => a.offset - b.offset); + + const paramIndexForOffset = new Map(); + varExprs.forEach((ve, i) => { + paramIndexForOffset.set(ve.offset, i); + }); + + // Collect WHERE conditions with param references + const visited = new Set(); + traverse(root, { + enter(node: any, _parent?: any) { + if (visited.has(node)) return; + visited.add(node); + + // "id = ?" pattern + if (node.type === "BinaryExpr" && node.op === "Equals") { + const columnName = extractColumnName(node.left); + if (!columnName) return; + const paramOffset = extractParamOffset(node.right); + if (paramOffset === -1) return; + const pIdx = paramIndexForOffset.get(paramOffset); + if (pIdx !== undefined && pIdx < params.length) { + const tableHint = extractTableHint(node.left); + whereIds.push({ tableHint, column: columnName, paramIndices: [pIdx] }); + } + } + + // "id IN (?, ?)" pattern + if (node.type === "InListExpr" && Array.isArray(node.rhs)) { + const columnName = extractColumnName(node.lhs); + if (!columnName) return; + const indices: number[] = []; + for (const item of node.rhs) { + const paramOffset = extractParamOffset(item); + if (paramOffset === -1) continue; + const pIdx = paramIndexForOffset.get(paramOffset); + if (pIdx !== undefined && pIdx < params.length) { + indices.push(pIdx); + } + } + if (indices.length > 0) { + const tableHint = extractTableHint(node.lhs); + whereIds.push({ tableHint, column: columnName, paramIndices: indices }); + } + } + }, + }); + + return { queryType, tablesRead, tablesWritten, whereIds }; +} + +function extractColumnName(node: any): string | null { + if (!node) return null; + if (node.type === "Id") return node.name; + if (node.type === "QualifiedExpr" && node.column) return node.column.text ?? node.column.name; + return null; +} + +function extractTableHint(node: any): string { + if (!node) return ""; + if (node.type === "QualifiedExpr" && node.table) return node.table.text ?? node.table.name ?? ""; + return ""; +} + +function extractParamOffset(node: any): number { + if (!node || node.type !== "VariableExpr") return -1; + return node.span?.offset ?? -1; +} diff --git a/packages/server/src/tools/proxy.integration.test.ts b/packages/server/src/tools/proxy.integration.test.ts deleted file mode 100644 index f7be44d..0000000 --- a/packages/server/src/tools/proxy.integration.test.ts +++ /dev/null @@ -1,201 +0,0 @@ -import { describe, it, expect, beforeEach } from "vitest"; -import Database from "better-sqlite3"; -import { drizzle } from "drizzle-orm/better-sqlite3"; -import { sqliteTable, integer, text } from "drizzle-orm/sqlite-core"; -import { eq } from "drizzle-orm"; -import { createTrackedDb } from "./createTrackedDb"; -import type { RawDrizzleDb, EdgePodSessionMap } from "../types"; - -const users = sqliteTable("users", { - id: integer("id").primaryKey(), - name: text("name").notNull(), -}); - -const posts = sqliteTable("posts", { - id: integer("id").primaryKey(), - title: text("title").notNull(), - userId: integer("user_id").notNull(), -}); - -function setup() { - const sqlite = new Database(":memory:"); - const db = drizzle({ client: sqlite, schema: { users, posts } }); - sqlite.exec(` - CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL); - CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT NOT NULL, user_id INTEGER NOT NULL); - `); - const tablesRead = new Set(); - const tablesWritten = new Set(); - const warnings: string[] = []; - const activeSessions: EdgePodSessionMap = new Map(); - activeSessions.set("test-session", { - socket: {} as WebSocket, - listeningToTables: new Set(), - }); - - const trackedDb = createTrackedDb( - db as unknown as RawDrizzleDb, - "test-session", - activeSessions, - tablesRead, - tablesWritten, - new Map(), - warnings, - ); - - return { db: trackedDb as any, tablesRead, tablesWritten, warnings }; -} - -describe("proxy integration — limit enforcement", () => { - it("clamps negative limit to 0 (async)", async () => { - const { db } = setup(); - const result = await db.select().from(users).limit(-1); - expect(Array.isArray(result)).toBe(true); - expect(result).toHaveLength(0); - }); - - it("clamps negative limit to 0 (sync)", () => { - const { db } = setup(); - const result = db.select().from(users).limit(-1).all(); - expect(Array.isArray(result)).toBe(true); - expect(result).toHaveLength(0); - }); - - it("caps limit at 1000 when exceeding max", async () => { - const { db, warnings } = setup(); - await db.select().from(users).limit(5000); - expect(warnings).toHaveLength(1); - expect(warnings[0]).toContain("5000"); - expect(warnings[0]).toContain("1000"); - }); - - it("auto-applies default limit when none set", async () => { - const { db } = setup(); - const result = await db.select().from(users); - expect(Array.isArray(result)).toBe(true); - }); -}); - -describe("proxy integration — WHERE enforcement", () => { - it("blocks update without WHERE", () => { - const { db } = setup(); - expect(() => db.update(users).set({ name: "changed" }).run()).toThrow( - "UPDATE without WHERE is blocked", - ); - }); - - it("allows update with WHERE", async () => { - const { db } = setup(); - const result = await db.update(users).set({ name: "changed" }).where(eq(users.id, 1)).run(); - expect(result).toBeDefined(); - }); - - it("blocks delete without WHERE", () => { - const { db } = setup(); - expect(() => db.delete(users).run()).toThrow("DELETE without WHERE is blocked"); - }); - - it("allows delete with WHERE", async () => { - const { db } = setup(); - const result = await db.delete(users).where(eq(users.id, 1)).run(); - expect(result).toBeDefined(); - }); -}); - -describe("proxy integration — table tracking", () => { - it("tracks insert as table write (async)", async () => { - const { db, tablesWritten } = setup(); - await db.insert(users).values({ name: "test" }); - expect(tablesWritten.has("users")).toBe(true); - }); - - it("tracks insert as table write (sync)", () => { - const { db, tablesWritten } = setup(); - db.insert(users).values({ name: "test" }).run(); - expect(tablesWritten.has("users")).toBe(true); - }); - - it("tracks update as table write", async () => { - const { db, tablesWritten } = setup(); - await db.update(users).set({ name: "changed" }).where(eq(users.id, 1)).run(); - expect(tablesWritten.has("users")).toBe(true); - }); - - it("tracks delete as table write", async () => { - const { db, tablesWritten } = setup(); - await db.delete(users).where(eq(users.id, 1)).run(); - expect(tablesWritten.has("users")).toBe(true); - }); - - it("tracks select as table read (async)", async () => { - const { db, tablesRead } = setup(); - await db.select().from(users); - expect(tablesRead.has("users")).toBe(true); - }); - - it("tracks select as table read (sync)", () => { - const { db, tablesRead } = setup(); - db.select().from(users).all(); - expect(tablesRead.has("users")).toBe(true); - }); - - it("tracks join tables as reads", async () => { - const { db, tablesRead } = setup(); - await db.select().from(users).leftJoin(posts, eq(users.id, posts.userId)); - expect(tablesRead.has("users")).toBe(true); - expect(tablesRead.has("posts")).toBe(true); - }); -}); - -describe("proxy integration — insert chaining", () => { - it("insert with .returning() records mutation", async () => { - const { db, tablesWritten } = setup(); - const result = await db.insert(users).values({ name: "test" }).returning(); - expect(Array.isArray(result)).toBe(true); - expect(tablesWritten.has("users")).toBe(true); - }); - - it("insert bulk at max limit succeeds", async () => { - const { db } = setup(); - const rows = Array(1000).fill({ name: "test" }); - await expect(db.insert(users).values(rows)).resolves.toBeDefined(); - }); - - it("insert bulk over max limit throws", () => { - const { db } = setup(); - const rows = Array(1001).fill({ name: "test" }); - expect(() => db.insert(users).values(rows)).toThrow("Bulk insert blocked"); - }); -}); - -describe("proxy integration — prepare", () => { - it("blocks insert .prepare()", () => { - const { db } = setup(); - expect(() => db.insert(users).values({ name: "test" }).prepare()).toThrow( - ".prepare() is not supported for inserts", - ); - }); - - it("blocks update .prepare()", () => { - const { db } = setup(); - expect(() => - db.update(users).set({ name: "changed" }).where(eq(users.id, 1)).prepare(), - ).toThrow(".prepare() is not supported for updates"); - }); - - it("blocks delete .prepare()", () => { - const { db } = setup(); - expect(() => db.delete(users).where(eq(users.id, 1)).prepare()).toThrow( - ".prepare() is not supported for deletes", - ); - }); - - it(".prepare() on select returns statement with limit enforced", async () => { - const { db } = setup(); - const prepared = db.select().from(users).prepare(); - expect(typeof prepared.execute).toBe("function"); - expect(typeof prepared.all).toBe("function"); - const result = await prepared.execute(); - expect(Array.isArray(result)).toBe(true); - }); -}); diff --git a/packages/server/src/tools/recordMutation.ts b/packages/server/src/tools/recordMutation.ts deleted file mode 100644 index fbb7cdc..0000000 --- a/packages/server/src/tools/recordMutation.ts +++ /dev/null @@ -1,14 +0,0 @@ -export function recordMutationWithCascades( - tableName: string, - tablesWritten: Set, - cascadeGraph: Map>, -) { - if (tablesWritten.has(tableName)) return; - tablesWritten.add(tableName); - const children = cascadeGraph.get(tableName); - if (children) { - for (const child of children) { - recordMutationWithCascades(child, tablesWritten, cascadeGraph); - } - } -} diff --git a/packages/server/src/tools/safetyProxy.integration.test.ts b/packages/server/src/tools/safetyProxy.integration.test.ts new file mode 100644 index 0000000..77964a2 --- /dev/null +++ b/packages/server/src/tools/safetyProxy.integration.test.ts @@ -0,0 +1,555 @@ +import { describe, it, expect, beforeEach } from "vitest"; +import Database from "better-sqlite3"; +import { drizzle } from "drizzle-orm/better-sqlite3"; +import { sqliteTable, integer, text } from "drizzle-orm/sqlite-core"; +import { eq, inArray, relations, sql } from "drizzle-orm"; +import { createSafetyProxy, type TrackContext } from "./createSafetyProxy"; +import { createTrackedRawDb } from "./createTrackedRawDb"; +import type { EdgePodSessionMap } from "../types"; + +const users = sqliteTable("users", { + id: integer("id").primaryKey(), + name: text("name").notNull(), +}); + +const posts = sqliteTable("posts", { + id: integer("id").primaryKey(), + title: text("title").notNull(), + userId: integer("user_id").notNull(), +}); + +const usersRelations = relations(users, ({ many }) => ({ + posts: many(posts), +})); + +const postsRelations = relations(posts, ({ one }) => ({ + user: one(users, { + fields: [posts.userId], + references: [users.id], + }), +})); + +function setup() { + const sqlite = new Database(":memory:"); + const db = drizzle({ client: sqlite, schema: { users, posts } }); + sqlite.exec(` + CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL); + CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT NOT NULL, user_id INTEGER NOT NULL); + `); + const tablesRead = new Set(); + const tablesWritten = new Set(); + const rowIds = new Map>(); + const warnings: string[] = []; + const activeSessions: EdgePodSessionMap = new Map(); + activeSessions.set("test-session", { + socket: {} as WebSocket, + listeningToTables: new Set(), + }); + + const trackCtx: TrackContext = { + sessionId: "test-session", + activeSessions, + tablesRead, + tablesWritten, + rowIds, + cascadeGraph: new Map(), + warnings, + pkMap: new Map([["users", ["id"]]]), + }; + + const proxy = createSafetyProxy(db as any, trackCtx); + const rawDb = createTrackedRawDb(db as any, trackCtx); + + return { db: proxy as any, rawDb, tablesRead, tablesWritten, rowIds, warnings, trackCtx }; +} + +describe("safety proxy — limit enforcement", () => { + it("clamps negative limit to 0 (async)", async () => { + const { db } = setup(); + const result = await db.select().from(users).limit(-1); + expect(Array.isArray(result)).toBe(true); + expect(result).toHaveLength(0); + }); + + it("caps limit at 1000 when exceeding max", async () => { + const { db, warnings } = setup(); + await db.select().from(users).limit(5000); + expect(warnings).toHaveLength(1); + expect(warnings[0]).toContain("5000"); + expect(warnings[0]).toContain("1000"); + }); + + it("auto-applies default limit when none set", async () => { + const { db } = setup(); + const result = await db.select().from(users); + expect(Array.isArray(result)).toBe(true); + }); +}); + +describe("safety proxy — WHERE enforcement", () => { + it("blocks update without WHERE", () => { + const { db } = setup(); + expect(() => db.update(users).set({ name: "changed" }).run()).toThrow( + "UPDATE without WHERE is blocked", + ); + }); + + it("allows update with WHERE", async () => { + const { db } = setup(); + const result = await db.update(users).set({ name: "changed" }).where(eq(users.id, 1)).run(); + expect(result).toBeDefined(); + }); + + it("blocks delete without WHERE", () => { + const { db } = setup(); + expect(() => db.delete(users).run()).toThrow("DELETE without WHERE is blocked"); + }); + + it("allows delete with WHERE", async () => { + const { db } = setup(); + const result = await db.delete(users).where(eq(users.id, 1)).run(); + expect(result).toBeDefined(); + }); +}); + +describe("safety proxy — table tracking", () => { + it("tracks insert as table write (async)", async () => { + const { db, tablesWritten } = setup(); + await db.insert(users).values({ name: "test" }); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks insert as table write (sync)", () => { + const { db, tablesWritten } = setup(); + db.insert(users).values({ name: "test" }).run(); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks update as table write", async () => { + const { db, tablesWritten } = setup(); + await db.update(users).set({ name: "changed" }).where(eq(users.id, 1)).run(); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks delete as table write", async () => { + const { db, tablesWritten } = setup(); + await db.delete(users).where(eq(users.id, 1)).run(); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks select as table read (async)", async () => { + const { db, tablesRead } = setup(); + await db.select().from(users); + expect(tablesRead.has("users")).toBe(true); + }); + + it("tracks select as table read (sync)", () => { + const { db, tablesRead } = setup(); + db.select().from(users).all(); + expect(tablesRead.has("users")).toBe(true); + }); + + it("tracks join tables as reads", async () => { + const { db, tablesRead } = setup(); + await db.select().from(users).leftJoin(posts, eq(users.id, posts.userId)); + expect(tablesRead.has("users")).toBe(true); + expect(tablesRead.has("posts")).toBe(true); + }); +}); + +describe("safety proxy — insert chaining", () => { + it("insert with .returning() records mutation", async () => { + const { db, tablesWritten } = setup(); + const result = await db.insert(users).values({ name: "test" }).returning(); + expect(Array.isArray(result)).toBe(true); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("insert bulk at max limit succeeds", async () => { + const { db } = setup(); + const rows = Array(1000).fill({ name: "test" }); + await expect(db.insert(users).values(rows)).resolves.toBeDefined(); + }); + + it("insert bulk over max limit throws", () => { + const { db } = setup(); + const rows = Array(1001).fill({ name: "test" }); + expect(() => db.insert(users).values(rows)).toThrow("Bulk insert blocked"); + }); +}); + +describe("safety proxy — prepare", () => { + it("blocks insert .prepare()", () => { + const { db } = setup(); + expect(() => db.insert(users).values({ name: "test" }).prepare()).toThrow( + ".prepare() is not supported for inserts", + ); + }); + + it("blocks update .prepare()", () => { + const { db } = setup(); + expect(() => + db.update(users).set({ name: "changed" }).where(eq(users.id, 1)).prepare(), + ).toThrow(".prepare() is not supported for updates"); + }); + + it("blocks delete .prepare()", () => { + const { db } = setup(); + expect(() => db.delete(users).where(eq(users.id, 1)).prepare()).toThrow( + ".prepare() is not supported for deletes", + ); + }); + + it(".prepare() on select returns statement with limit enforced", async () => { + const { db } = setup(); + const prepared = db.select().from(users).prepare(); + expect(typeof prepared.execute).toBe("function"); + expect(typeof prepared.all).toBe("function"); + const result = await prepared.execute(); + expect(Array.isArray(result)).toBe(true); + }); +}); + +describe("safety proxy — row ID tracking", () => { + it("extracts WHERE IDs from update", async () => { + const { db, rowIds } = setup(); + await db.update(users).set({ name: "changed" }).where(eq(users.id, 42)).run(); + expect(rowIds.has("users")).toBe(true); + const ids = [...rowIds.get("users")!]; + // 42 hashed with djb2 should produce a stable hash + expect(ids).toHaveLength(1); + }); + + it("extracts WHERE IDs from delete", async () => { + const { db, rowIds } = setup(); + await db.delete(users).where(eq(users.id, 7)).run(); + expect(rowIds.has("users")).toBe(true); + const ids = [...rowIds.get("users")!]; + expect(ids).toHaveLength(1); + }); + + it("skips row IDs for inserts without WHERE", async () => { + const { db, rowIds } = setup(); + await db.insert(users).values({ name: "test" }); + expect(rowIds.size).toBe(0); + }); +}); + +describe("safety proxy — async error propagation", () => { + it("propagates async DB errors via await without unhandled rejection", async () => { + const { db } = setup(); + await expect(db.insert(users).values({ name: null as any })).rejects.toThrow(); + }); + + it("fires warnRowLimit through async (then) path", async () => { + const { db, warnings } = setup(); + const rows = Array.from({ length: 1000 }, (_, i) => ({ name: `user-${i}` })); + await db.insert(users).values(rows); + warnings.length = 0; + await db.select().from(users); + expect(warnings).toHaveLength(1); + expect(warnings[0]).toContain("1000 rows"); + }); +}); + +describe("safety proxy — relation tracking (with option)", () => { + function setupWithRelations() { + const sqlite = new Database(":memory:"); + const db = drizzle({ + client: sqlite, + schema: { users, posts, usersRelations, postsRelations }, + }); + sqlite.exec(` + CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL); + CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT NOT NULL, user_id INTEGER NOT NULL); + `); + const tablesRead = new Set(); + const tablesWritten = new Set(); + const rowIds = new Map>(); + const warnings: string[] = []; + const activeSessions: EdgePodSessionMap = new Map(); + activeSessions.set("test-session", { + socket: {} as WebSocket, + listeningToTables: new Set(), + }); + + const trackCtx: TrackContext = { + sessionId: "test-session", + activeSessions, + tablesRead, + tablesWritten, + rowIds, + cascadeGraph: new Map(), + warnings, + pkMap: new Map([["users", ["id"]]]), + }; + + const proxy = createSafetyProxy(db as any, trackCtx); + + return { db: proxy as any, tablesRead, tablesWritten, warnings, trackCtx }; + } + + it("tracks root and relation tables from findMany with flat with", async () => { + const { db, tablesRead } = setupWithRelations(); + try { + await db.query.users.findMany({ with: { posts: true } }); + } catch { + // Drizzle may throw if relations not processed, tablesRead is still populated + } + expect(tablesRead.has("users")).toBe(true); + expect(tablesRead.has("posts")).toBe(true); + }); + + it("tracks root and relation from findFirst with with", async () => { + const { db, tablesRead } = setupWithRelations(); + try { + await db.query.users.findFirst({ with: { posts: true } }); + } catch { + // ignore + } + expect(tablesRead.has("users")).toBe(true); + expect(tablesRead.has("posts")).toBe(true); + }); + + it("tracks nested relation tables recursively", async () => { + const { db, tablesRead } = setupWithRelations(); + try { + await db.query.users.findMany({ + with: { + posts: { + with: { user: true }, + }, + }, + }); + } catch { + // ignore + } + expect(tablesRead.has("users")).toBe(true); + expect(tablesRead.has("posts")).toBe(true); + // "user" is the relation name from postsRelations — even though it's + // the same physical table, the relation is a distinct tracking key + expect(tablesRead.has("user")).toBe(true); + }); + + it("tracks nothing from findMany without with option", async () => { + const { db, tablesRead } = setupWithRelations(); + await db.query.users.findMany(); + expect(tablesRead.has("users")).toBe(true); + expect(tablesRead.has("posts")).toBe(false); + }); + + it("records relation tables in listeningToTables", async () => { + const { db, trackCtx } = setupWithRelations(); + try { + await db.query.users.findMany({ with: { posts: true } }); + } catch { + // ignore + } + const session = trackCtx.activeSessions.get("test-session"); + // listeningToTables uses hashed names — verify at least root + relation + expect(session?.listeningToTables.size).toBeGreaterThanOrEqual(2); + }); +}); + +describe("safety proxy — raw SQL tracking", () => { + it("tracks writes via Drizzle SQL template in rawDb.run", async () => { + const { rawDb, tablesWritten } = setup(); + rawDb.run(sql`INSERT INTO users (name) VALUES (${"test"})`); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks reads via Drizzle SQL template in rawDb.all", async () => { + const { rawDb, tablesRead } = setup(); + rawDb.all(sql`SELECT * FROM users`); + expect(tablesRead.has("users")).toBe(true); + }); + + it("tracks writes via raw string SQL in rawDb.run", async () => { + const { rawDb, tablesWritten } = setup(); + rawDb.run("INSERT INTO users (name) VALUES ('test')"); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks reads via raw string SQL in rawDb.all", async () => { + const { rawDb, tablesRead } = setup(); + rawDb.all("SELECT * FROM users"); + expect(tablesRead.has("users")).toBe(true); + }); +}); + +describe("safety proxy — cascade isolation", () => { + function setupWithCascade() { + const sqlite = new Database(":memory:"); + const db = drizzle({ client: sqlite, schema: { users, posts } }); + sqlite.exec(` + CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL); + CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT NOT NULL, user_id INTEGER NOT NULL); + `); + const tablesRead = new Set(); + const tablesWritten = new Set(); + const rowIds = new Map>(); + const warnings: string[] = []; + const activeSessions: EdgePodSessionMap = new Map(); + activeSessions.set("test-session", { + socket: {} as WebSocket, + listeningToTables: new Set(), + }); + + // Cascade graph: users → posts (simulating FK with onDelete: cascade) + const cascadeGraph = new Map>(); + cascadeGraph.set("users", new Set(["posts"])); + + const trackCtx: TrackContext = { + sessionId: "test-session", + activeSessions, + tablesRead, + tablesWritten, + rowIds, + cascadeGraph, + warnings, + pkMap: new Map([ + ["users", ["id"]], + ["posts", ["id"]], + ]), + }; + + const proxy = createSafetyProxy(db as any, trackCtx); + const rawDb = createTrackedRawDb(db as any, trackCtx); + + return { db: proxy as any, rawDb, tablesRead, tablesWritten, warnings, trackCtx }; + } + + it("does NOT cascade on insert", async () => { + const { db, tablesWritten } = setupWithCascade(); + await db.insert(users).values({ name: "test" }); + expect(tablesWritten.has("users")).toBe(true); + expect(tablesWritten.has("posts")).toBe(false); + }); + + it("does NOT cascade on update", async () => { + const { db, tablesWritten } = setupWithCascade(); + await db.update(users).set({ name: "x" }).where(eq(users.id, 1)).run(); + expect(tablesWritten.has("users")).toBe(true); + expect(tablesWritten.has("posts")).toBe(false); + }); + + it("DOES cascade on delete", async () => { + const { db, tablesWritten } = setupWithCascade(); + await db.delete(users).where(eq(users.id, 1)).run(); + expect(tablesWritten.has("users")).toBe(true); + expect(tablesWritten.has("posts")).toBe(true); + }); + + it("tracks raw SQL delete with cascade", async () => { + const { rawDb, tablesWritten } = setupWithCascade(); + rawDb.run(sql`DELETE FROM users WHERE id = 1`); + expect(tablesWritten.has("users")).toBe(true); + expect(tablesWritten.has("posts")).toBe(true); + }); + + it("tracks raw SQL insert without cascade", async () => { + const { rawDb, tablesWritten } = setupWithCascade(); + rawDb.run(sql`INSERT INTO users (name) VALUES (${"test"})`); + expect(tablesWritten.has("users")).toBe(true); + expect(tablesWritten.has("posts")).toBe(false); + }); +}); + +describe("safety proxy — .withoutWhere() escape hatch", () => { + it("allows update without WHERE via .withoutWhere()", async () => { + const { db, tablesWritten, warnings } = setup(); + const result = await db.update(users).set({ name: "changed" }).withoutWhere().run(); + expect(result).toBeDefined(); + expect(tablesWritten.has("users")).toBe(true); + expect(warnings).toHaveLength(1); + expect(warnings[0]).toContain("Unfiltered update"); + expect(warnings[0]).toContain(".withoutWhere()"); + }); + + it("allows delete without WHERE via .withoutWhere()", async () => { + const { db, tablesWritten, warnings } = setup(); + const result = await db.delete(users).withoutWhere().run(); + expect(result).toBeDefined(); + expect(tablesWritten.has("users")).toBe(true); + expect(warnings).toHaveLength(1); + expect(warnings[0]).toContain("Unfiltered delete"); + expect(warnings[0]).toContain(".withoutWhere()"); + }); +}); + +describe("safety proxy — row ID PK filtering", () => { + function setupWithPkMap() { + const sqlite = new Database(":memory:"); + const db = drizzle({ client: sqlite, schema: { users } }); + sqlite.exec(`CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL)`); + const tablesRead = new Set(); + const tablesWritten = new Set(); + const rowIds = new Map>(); + const warnings: string[] = []; + const activeSessions: EdgePodSessionMap = new Map(); + activeSessions.set("test-session", { + socket: {} as WebSocket, + listeningToTables: new Set(), + }); + + const trackCtx: TrackContext = { + sessionId: "test-session", + activeSessions, + tablesRead, + tablesWritten, + rowIds, + cascadeGraph: new Map(), + warnings, + pkMap: new Map([["users", ["id"]]]), + }; + + const proxy = createSafetyProxy(db as any, trackCtx); + + return { db: proxy as any, rowIds, tablesWritten }; + } + + it("records row ID for WHERE on PK column", async () => { + const { db, rowIds } = setupWithPkMap(); + await db.update(users).set({ name: "changed" }).where(eq(users.id, 42)).run(); + expect(rowIds.has("users")).toBe(true); + expect(rowIds.get("users")!.size).toBe(1); + }); + + it("does NOT record row ID for WHERE on non-PK column", async () => { + const { db, rowIds } = setupWithPkMap(); + await db.update(users).set({ name: "changed" }).where(eq(users.name, "John")).run(); + expect(rowIds.size).toBe(0); + }); + + it("records row ID for DELETE on PK column", async () => { + const { db, rowIds } = setupWithPkMap(); + await db.delete(users).where(eq(users.id, 7)).run(); + expect(rowIds.has("users")).toBe(true); + expect(rowIds.get("users")!.size).toBe(1); + }); + + it("does NOT record row ID for DELETE on non-PK column", async () => { + const { db, rowIds } = setupWithPkMap(); + await db.delete(users).where(eq(users.name, "John")).run(); + expect(rowIds.size).toBe(0); + }); + + it("records row ID for WHERE IN on PK column", async () => { + const { db, rowIds } = setupWithPkMap(); + await db + .delete(users) + .where(inArray(users.id, [1, 2, 3])) + .run(); + expect(rowIds.has("users")).toBe(true); + expect(rowIds.get("users")!.size).toBe(3); + }); + + it("does NOT record row ID for WHERE IN on non-PK column", async () => { + const { db, rowIds } = setupWithPkMap(); + await db + .delete(users) + .where(inArray(users.name, ["Alice", "Bob"])) + .run(); + expect(rowIds.size).toBe(0); + }); +}); diff --git a/packages/server/src/tools/tracking.ts b/packages/server/src/tools/tracking.ts new file mode 100644 index 0000000..4918e04 --- /dev/null +++ b/packages/server/src/tools/tracking.ts @@ -0,0 +1,74 @@ +import { parseSqlTracking } from "./parseSqlTracking"; +import { hashTableName } from "./hashTableName"; +import type { TrackContext } from "./createSafetyProxy"; + +export function cascadeWrite(table: string, written: Set, graph: Map>) { + if (written.has(table)) return; + written.add(table); + for (const child of graph.get(table) ?? []) cascadeWrite(child, written, graph); +} + +export function addListener(table: string, ctx: TrackContext) { + if (table === "unknown") return; + ctx.activeSessions.get(ctx.sessionId)?.listeningToTables.add(hashTableName(table)); + ctx.tablesRead.add(table); +} + +export function recordWhereIds( + parsed: ReturnType, + params: unknown[], + ctx: TrackContext, +) { + for (const wid of parsed.whereIds) { + let table = wid.tableHint || ""; + if (!table) { + const first = parsed.tablesWritten[0]; + if (first) table = first; + } + if (!table) continue; + const pkCols = ctx.pkMap.get(table); + if (pkCols && !pkCols.includes(wid.column)) continue; + for (const idx of wid.paramIndices) { + if (idx < params.length) { + const hashed = hashTableName(String(params[idx])); + let ids = ctx.rowIds.get(table); + if (!ids) { + ids = new Set(); + ctx.rowIds.set(table, ids); + } + ids.add(hashed); + } + } + } +} + +export function trackExec(builder: any, ctx: TrackContext, tableHint?: string, queryType?: string) { + try { + const { sql, params } = builder.toSQL(); + const parsed = parseSqlTracking(sql, params); + for (const t of parsed.tablesRead) addListener(t, ctx); + for (const t of parsed.tablesWritten) { + if (parsed.queryType === "delete") { + cascadeWrite(t, ctx.tablesWritten, ctx.cascadeGraph); + } else { + ctx.tablesWritten.add(t); + } + } + recordWhereIds(parsed, params, ctx); + } catch { + if (tableHint) { + if (queryType === "delete") { + cascadeWrite(tableHint, ctx.tablesWritten, ctx.cascadeGraph); + } else { + ctx.tablesWritten.add(tableHint); + } + } + } +} + +export function warnRowLimit(result: unknown, warnings: string[]) { + if (Array.isArray(result) && result.length === 1000) + warnings.push( + "Query returned exactly 1000 rows — there may be more results. Use .limit() and .offset() to paginate.", + ); +} diff --git a/packages/server/src/types/index.ts b/packages/server/src/types/index.ts index 53134aa..5d162f1 100644 --- a/packages/server/src/types/index.ts +++ b/packages/server/src/types/index.ts @@ -31,6 +31,7 @@ export type RpcRequest = { export type RpcMeta = { read: string[]; changed: string[]; + rows?: Record; }; export type RpcResponse = { @@ -67,9 +68,8 @@ export type EdgePodContext< log: Logger; // A per-request logger with traceId bound — use for structured, traceable output - // Manual Reactivity Escape Hatches + // Manual Reactivity Escape Hatch subscribeTo: (tables: string[]) => void; - invalidate: (tables: string[]) => void; // Environment Variables (for Stripe keys, Resend API, etc.) env: Cloudflare.Env & TEnv; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 95a34e7..574c04e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -112,6 +112,9 @@ importers: neverthrow: specifier: ^8.2.0 version: 8.2.0 + sqlite3-parser: + specifier: ^0.7.1 + version: 0.7.1 devDependencies: '@types/better-sqlite3': specifier: ^7.6.13 @@ -3149,6 +3152,10 @@ packages: spdx-license-ids@3.0.23: resolution: {integrity: sha512-CWLcCCH7VLu13TgOH+r8p1O/Znwhqv/dbb6lqWy67G+pT1kHmeD/+V36AVb/vq8QMIQwVShJ6Ssl5FPh0fuSdw==} + sqlite3-parser@0.7.1: + resolution: {integrity: sha512-+KcAIcmD9xk4Sz3hNsEI7QEGAMGl3s5PyigIf6ri/u1DzAuAmi5YZInjtnXKqHvx9ySl5e0RYILfCVhTdAeKJg==} + hasBin: true + sqlstring@2.3.3: resolution: {integrity: sha512-qC9iz2FlN7DQl3+wjwn3802RTyjCx7sDvfQEXchwa6CWOx07/WVfh91gBmQ9fahw8snwGEWU3xGzOt4tFyHLxg==} engines: {node: '>= 0.6'} @@ -6099,6 +6106,8 @@ snapshots: spdx-license-ids@3.0.23: {} + sqlite3-parser@0.7.1: {} + sqlstring@2.3.3: optional: true