diff --git a/AGENTS.md b/AGENTS.md index 1d427a7..c40182b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -3,7 +3,7 @@ ## General Behaviour - **Ask questions if unsure, do not assume anything.** When requirements are ambiguous, ask for clarification before writing code. -- **Keep files under 150 lines** (soft limit). Files above 200 lines must be refactored into smaller modules (hard limit). +- **Keep files under 150-200 lines** (soft limit). Files above 250 lines must be refactored into smaller modules (hard limit). - **No Python.** For helper scripts, use Node.js (plain `.mjs` files). Never reach for Python, shell scripts beyond simple one-liners, or other runtimes. - **Do not edit auto-generated files.** Files like `routeTree.gen.ts` (TanStack Router), `worker-configuration.d.ts`, or any file with a `// This file is auto-generated` header must never be manually edited — they are overwritten by tooling. - **Do not edit shadcn/ui files.** Files under `src/components/ui/` are installed and managed by the shadcn CLI. Never modify them — override styles at the call site instead. diff --git a/packages/server/package.json b/packages/server/package.json index 2f432b0..e2649d4 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -48,7 +48,8 @@ "@logtape/logtape": "^2.0.5", "drizzle-orm": "^0.45.2", "jose": "^6.2.3", - "neverthrow": "^8.2.0" + "neverthrow": "^8.2.0", + "sqlite3-parser": "^0.7.1" }, "devDependencies": { "@types/better-sqlite3": "^7.6.13", diff --git a/packages/server/src/test-utils/createTestDb.ts b/packages/server/src/test-utils/createTestDb.ts new file mode 100644 index 0000000..c7fff3d --- /dev/null +++ b/packages/server/src/test-utils/createTestDb.ts @@ -0,0 +1,11 @@ +import type { DurableObjectStorage } from "@cloudflare/workers-types"; +import Database from "better-sqlite3"; +import { drizzle } from "drizzle-orm/durable-sqlite"; +import { createMockDOStorage } from "./mockDOStorage"; + +export function createTestDb>(schema: TSchema) { + const sqlite = new Database(":memory:"); + const storage = createMockDOStorage(sqlite); + const db = drizzle(storage as unknown as DurableObjectStorage, { schema }); + return { db, sqlite, storage }; +} diff --git a/packages/server/src/test-utils/mockDOStorage.ts b/packages/server/src/test-utils/mockDOStorage.ts new file mode 100644 index 0000000..cbfb600 --- /dev/null +++ b/packages/server/src/test-utils/mockDOStorage.ts @@ -0,0 +1,110 @@ +import DatabaseCtor from "better-sqlite3"; + +type Database = InstanceType; + +export type SqlStorageCursor = { + toArray(): T[]; + next(): IteratorResult; + raw(): SqlStorageCursor; + [Symbol.iterator](): IterableIterator; +}; + +export type MockDOStorage = { + sql: { + exec(sql: string, ...bindings: unknown[]): SqlStorageCursor; + databaseSize: number; + }; + transactionSync(callback: () => T): T; +}; + +function makeEmptyCursor(): SqlStorageCursor { + const doneResult: IteratorResult = { done: true, value: undefined }; + const emptyIter: IterableIterator = { + next: () => doneResult, + [Symbol.iterator]: () => emptyIter, + }; + return { + toArray: () => [], + next: () => doneResult, + raw: () => makeEmptyCursor(), + [Symbol.iterator]: () => emptyIter, + }; +} + +function makeCursor( + sqlite: Database, + sql: string, + params: unknown[], + raw = false, +): SqlStorageCursor { + const stmt = sqlite.prepare(sql); + if (raw) stmt.raw(true); + + let rows: T[]; + try { + rows = (params.length > 0 ? stmt.all(...params) : stmt.all()) as T[]; + } catch { + // Non-SELECT statement (INSERT, UPDATE, DELETE, CREATE, etc.) + if (params.length > 0) { + stmt.run(...params); + } else { + stmt.run(); + } + return makeEmptyCursor(); + } + + let index = 0; + const iter: IterableIterator = { + next() { + if (index < rows.length) { + return { done: false, value: rows[index++] } as IteratorResult; + } + return { done: true, value: undefined } as IteratorResult; + }, + [Symbol.iterator]() { + return iter; + }, + }; + + return { + toArray() { + return rows; + }, + next() { + return iter.next(); + }, + raw() { + return makeCursor(sqlite, sql, params, true); + }, + [Symbol.iterator]() { + return iter; + }, + }; +} + +export function createMockDOStorage(sqlite: Database): MockDOStorage { + return { + sql: { + exec(sql: string, ...params: unknown[]) { + return makeCursor(sqlite, sql, params); + }, + get databaseSize() { + try { + const pageCount = sqlite.prepare("PRAGMA page_count").get() as { + page_count: number; + }; + const pageSize = sqlite.prepare("PRAGMA page_size").get() as { + page_size: number; + }; + return pageCount.page_count * pageSize.page_size; + } catch { + return 0; + } + }, + }, + transactionSync(callback: () => T): T { + const tx = sqlite.transaction(callback); + return tx(); + }, + }; +} diff --git a/packages/server/src/tools/createMutationProxy.ts b/packages/server/src/tools/createMutationProxy.ts index 310d94d..316411d 100644 --- a/packages/server/src/tools/createMutationProxy.ts +++ b/packages/server/src/tools/createMutationProxy.ts @@ -1,13 +1,9 @@ -import { recordMutationWithCascades } from "./recordMutation"; import { createQueryProxy, type ProxyConfig } from "./createQueryProxy"; export function createMutationProxy( builder: Record, warnings: string[], mutationType: "update" | "delete", - tableName?: string, - tablesWritten?: Set, - cascadeGraph?: Map>, initialState = { whereSet: false, withoutWhereSet: false }, ): unknown { const config: ProxyConfig = { @@ -29,9 +25,6 @@ export function createMutationProxy( `[EdgePod] ${mutationType.toUpperCase()} without WHERE is blocked. If intentional, chain .withoutWhere().`, ); } - if (tableName && tableName !== "unknown" && tablesWritten) { - recordMutationWithCascades(tableName, tablesWritten, cascadeGraph ?? new Map()); - } return target[prop](...args); }, }; diff --git a/packages/server/src/tools/createSelectProxy.test.ts b/packages/server/src/tools/createSelectProxy.test.ts index 7b9455f..0badcd7 100644 --- a/packages/server/src/tools/createSelectProxy.test.ts +++ b/packages/server/src/tools/createSelectProxy.test.ts @@ -1,11 +1,5 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; import { createSelectProxy } from "./createSelectProxy"; -import { hashTableName } from "./hashTableName"; -import type { EdgePodSessionMap } from "../types"; - -vi.mock("drizzle-orm", () => ({ - getTableName: vi.fn((t: { name?: string } | null) => t?.name ?? "unknown"), -})); function createMockBuilder( options: { resultData?: Record[]; limit?: number } = {}, @@ -28,15 +22,11 @@ function createMockBuilder( from: vi.fn(function () { return builder; }), - leftJoin: vi.fn(function (_table: unknown) { - const opts: { resultData: Record[]; limit?: number } = { resultData }; - if (currentLimit !== undefined) opts.limit = currentLimit; - return createMockBuilder(opts); + leftJoin: vi.fn(function () { + return builder; }), - innerJoin: vi.fn(function (_table: unknown) { - const opts: { resultData: Record[]; limit?: number } = { resultData }; - if (currentLimit !== undefined) opts.limit = currentLimit; - return createMockBuilder(opts); + innerJoin: vi.fn(function () { + return builder; }), rightJoin: vi.fn(function () { return builder; @@ -54,29 +44,16 @@ function createMockBuilder( return builder; } -function createMockJoinTable() { - return { name: "joined_table" }; -} - describe("createSelectProxy", () => { - let tablesRead: Set; let warnings: string[]; - let activeSessions: EdgePodSessionMap; - const sessionId = "test-session"; beforeEach(() => { - tablesRead = new Set(); warnings = []; - activeSessions = new Map(); - activeSessions.set(sessionId, { - socket: {} as WebSocket, - listeningToTables: new Set(), - }); }); it("auto-applies max limit when none set", async () => { const builder = createMockBuilder({ resultData: Array(2000).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); + const proxy = createSelectProxy(builder, warnings, 1000); const result = await proxy; @@ -85,7 +62,7 @@ describe("createSelectProxy", () => { it("respects user-set limit under max", async () => { const builder = createMockBuilder({ resultData: Array(100).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); + const proxy = createSelectProxy(builder, warnings, 1000); const withLimit = proxy.limit(50); const result = await withLimit; @@ -95,7 +72,7 @@ describe("createSelectProxy", () => { it("caps limit at max", async () => { const builder = createMockBuilder({ resultData: Array(5000).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); + const proxy = createSelectProxy(builder, warnings, 1000); const withLimit = proxy.limit(5000); const result = await withLimit; @@ -105,7 +82,7 @@ describe("createSelectProxy", () => { it("adds warning when user limit exceeds max", async () => { const builder = createMockBuilder({ resultData: [] }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); + const proxy = createSelectProxy(builder, warnings, 1000); await proxy.limit(5000); @@ -114,29 +91,9 @@ describe("createSelectProxy", () => { expect(warnings[0]).toContain("1000"); }); - it("tracks table reads on join methods", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.leftJoin(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("tracks table reads on innerJoin", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.innerJoin(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - it("adds warning when result hits max limit", async () => { const builder = createMockBuilder({ resultData: Array(1000).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); + const proxy = createSelectProxy(builder, warnings, 1000); await proxy; @@ -147,57 +104,16 @@ describe("createSelectProxy", () => { it("does not add warning when result is under limit", async () => { const builder = createMockBuilder({ resultData: Array(100).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); + const proxy = createSelectProxy(builder, warnings, 1000); await proxy; expect(warnings).toHaveLength(0); }); - it("tracks table reads on rightJoin", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.rightJoin(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("tracks table reads on fullJoin", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.fullJoin(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("tracks table reads on from", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.from(joinTable, {}); - - expect(tablesRead.has("joined_table")).toBe(true); - }); - - it("registers listening tables on session for joins", () => { - const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); - - const joinTable = createMockJoinTable(); - proxy.leftJoin(joinTable, {}); - - const session = activeSessions.get(sessionId); - expect(session?.listeningToTables.has(hashTableName("joined_table"))).toBe(true); - }); - it("preserves proxy through chained method calls", () => { const builder = createMockBuilder(); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); + const proxy = createSelectProxy(builder, warnings, 1000); const withWhere = proxy.where({ id: 1 }); expect(withWhere).toBeDefined(); @@ -206,7 +122,7 @@ describe("createSelectProxy", () => { it("original proxy still applies max limit after .limit() on a branch", async () => { const builder = createMockBuilder({ resultData: Array(2000).fill({ id: 1 }) }); - const proxy = createSelectProxy(builder, sessionId, activeSessions, tablesRead, warnings, 1000); + const proxy = createSelectProxy(builder, warnings, 1000); proxy.limit(50); const result = await proxy; diff --git a/packages/server/src/tools/createSelectProxy.ts b/packages/server/src/tools/createSelectProxy.ts index fc3f24d..cb9cec9 100644 --- a/packages/server/src/tools/createSelectProxy.ts +++ b/packages/server/src/tools/createSelectProxy.ts @@ -1,27 +1,8 @@ -import { getTableName } from "drizzle-orm"; -import { EdgePodSessionMap } from "../types"; import { checkResultWarnings } from "./checkResultWarnings"; import { createQueryProxy, type ProxyConfig } from "./createQueryProxy"; -import { hashTableName } from "./hashTableName"; - -function trackTable( - table: unknown, - tablesRead: Set, - activeSessions: EdgePodSessionMap, - sessionId: string, -) { - const tableName = getTableName(table as any) ?? "unknown"; - if (tableName === "unknown") return; - const session = activeSessions.get(sessionId); - if (session) session.listeningToTables.add(hashTableName(tableName)); - tablesRead.add(tableName); -} export function createSelectProxy( builder: Record, - sessionId: string, - activeSessions: EdgePodSessionMap, - tablesRead: Set, warnings: string[], maxLimit: number, ): unknown { @@ -35,30 +16,6 @@ export function createSelectProxy( const clamped = Math.max(0, Math.min(n, maxLimit)); return factory(target.limit(clamped), { ...state, limitSet: true }); }, - from: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.from(...args), { ...state }); - }, - leftJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.leftJoin(...args), { ...state }); - }, - innerJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.innerJoin(...args), { ...state }); - }, - rightJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.rightJoin(...args), { ...state }); - }, - fullJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.fullJoin(...args), { ...state }); - }, - crossJoin: (target, args, state, factory) => { - trackTable(args[0], tablesRead, activeSessions, sessionId); - return factory(target.crossJoin(...args), { ...state }); - }, }, onExecute: (target, prop, args, state) => { const finalBuilder = state.limitSet ? target : target.limit(maxLimit); diff --git a/packages/server/src/tools/createTrackedClient.test.ts b/packages/server/src/tools/createTrackedClient.test.ts new file mode 100644 index 0000000..6dcc08e --- /dev/null +++ b/packages/server/src/tools/createTrackedClient.test.ts @@ -0,0 +1,128 @@ +import { describe, it, expect } from "vitest"; +import Database from "better-sqlite3"; +import { createMockDOStorage } from "../test-utils/mockDOStorage"; +import { createTrackedClient, recordCascades } from "./createTrackedClient"; + +describe("recordCascades", () => { + it("records a table and its cascade children", () => { + const tablesWritten = new Set(); + const cascadeGraph = new Map>(); + cascadeGraph.set("users", new Set(["posts", "comments"])); + cascadeGraph.set("posts", new Set(["likes"])); + + recordCascades("users", tablesWritten, cascadeGraph); + + expect(tablesWritten.has("users")).toBe(true); + expect(tablesWritten.has("posts")).toBe(true); + expect(tablesWritten.has("comments")).toBe(true); + expect(tablesWritten.has("likes")).toBe(true); + }); + + it("does not duplicate already-recorded tables", () => { + const tablesWritten = new Set(); + tablesWritten.add("posts"); + const cascadeGraph = new Map>(); + cascadeGraph.set("users", new Set(["posts"])); + + recordCascades("users", tablesWritten, cascadeGraph); + + expect(tablesWritten.has("users")).toBe(true); + expect(tablesWritten.has("posts")).toBe(true); + expect(tablesWritten.size).toBe(2); + }); + + it("handles empty cascade graph", () => { + const tablesWritten = new Set(); + recordCascades("users", tablesWritten, new Map()); + expect(tablesWritten.has("users")).toBe(true); + }); +}); + +describe("createTrackedClient", () => { + function setup() { + const sqlite = new Database(":memory:"); + const storage = createMockDOStorage(sqlite); + const tablesRead = new Set(); + const tablesWritten = new Set(); + const cascadeGraph = new Map>(); + + const tracked = createTrackedClient( + storage as unknown as DurableObjectStorage, + tablesRead, + tablesWritten, + cascadeGraph, + ); + + return { tracked, sqlite, tablesRead, tablesWritten, cascadeGraph }; + } + + it("tracks SELECT as table read", () => { + const { tracked, sqlite, tablesRead } = setup(); + sqlite.exec("CREATE TABLE users (id INTEGER PRIMARY KEY)"); + + tracked.sql.exec('SELECT * FROM "users"'); + + expect(tablesRead.has("users")).toBe(true); + }); + + it("tracks INSERT as table write", () => { + const { tracked, sqlite, tablesWritten } = setup(); + sqlite.exec("CREATE TABLE users (id INTEGER PRIMARY KEY)"); + + tracked.sql.exec('INSERT INTO "users" ("id") VALUES (?)', [1]); + + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks UPDATE as table write", () => { + const { tracked, sqlite, tablesWritten } = setup(); + sqlite.exec("CREATE TABLE users (id INTEGER PRIMARY KEY)"); + + tracked.sql.exec('UPDATE "users" SET "id" = ?', [2]); + + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks DELETE as table write", () => { + const { tracked, sqlite, tablesWritten } = setup(); + sqlite.exec("CREATE TABLE users (id INTEGER PRIMARY KEY)"); + + tracked.sql.exec('DELETE FROM "users" WHERE "id" = ?', [1]); + + expect(tablesWritten.has("users")).toBe(true); + }); + + it("propagates cascades on write", () => { + const { tracked, sqlite, tablesWritten, cascadeGraph } = setup(); + sqlite.exec("CREATE TABLE users (id INTEGER PRIMARY KEY)"); + sqlite.exec("CREATE TABLE posts (id INTEGER PRIMARY KEY)"); + cascadeGraph.set("users", new Set(["posts"])); + + tracked.sql.exec('DELETE FROM "users" WHERE "id" = ?', [1]); + + expect(tablesWritten.has("users")).toBe(true); + expect(tablesWritten.has("posts")).toBe(true); + }); + + it("tracks JOIN tables as reads", () => { + const { tracked, sqlite, tablesRead } = setup(); + sqlite.exec("CREATE TABLE users (id INTEGER PRIMARY KEY)"); + sqlite.exec("CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER)"); + + tracked.sql.exec('SELECT * FROM "users" LEFT JOIN "posts" ON "posts"."user_id" = "users"."id"'); + + expect(tablesRead.has("users")).toBe(true); + expect(tablesRead.has("posts")).toBe(true); + }); + + it("tracks tables inside a transaction", () => { + const { tracked, sqlite, tablesRead } = setup(); + sqlite.exec("CREATE TABLE users (id INTEGER PRIMARY KEY)"); + + tracked.transactionSync(() => { + tracked.sql.exec('SELECT * FROM "users"'); + }); + + expect(tablesRead.has("users")).toBe(true); + }); +}); diff --git a/packages/server/src/tools/createTrackedClient.ts b/packages/server/src/tools/createTrackedClient.ts new file mode 100644 index 0000000..f13f2f5 --- /dev/null +++ b/packages/server/src/tools/createTrackedClient.ts @@ -0,0 +1,49 @@ +import { parseSqlTracking } from "./parseSqlTracking"; + +export function recordCascades( + tableName: string, + tablesWritten: Set, + cascadeGraph: Map>, +) { + if (tablesWritten.has(tableName)) return; + tablesWritten.add(tableName); + const children = cascadeGraph.get(tableName); + if (children) { + for (const child of children) { + recordCascades(child, tablesWritten, cascadeGraph); + } + } +} + +export function createTrackedClient( + storage: DurableObjectStorage, + tablesRead: Set, + tablesWritten: Set, + cascadeGraph: Map>, +): DurableObjectStorage { + return new Proxy(storage, { + get(target, prop, receiver) { + if (prop === "sql") { + const sql = Reflect.get(target, prop, receiver); + return new Proxy(sql, { + get(sqlTarget, sqlProp, sqlReceiver) { + const value = Reflect.get(sqlTarget, sqlProp, sqlReceiver); + if (sqlProp === "exec" && typeof value === "function") { + return (sqlStr: string, ...params: unknown[]) => { + const parsed = parseSqlTracking(sqlStr, params); + for (const t of parsed.tablesRead) tablesRead.add(t); + for (const t of parsed.tablesWritten) { + recordCascades(t, tablesWritten, cascadeGraph); + } + return value.apply(sqlTarget, [sqlStr, ...params]); + }; + } + return typeof value === "function" ? value.bind(sqlTarget) : value; + }, + }); + } + const value = Reflect.get(target, prop, receiver); + return typeof value === "function" ? value.bind(target) : value; + }, + }) as DurableObjectStorage; +} diff --git a/packages/server/src/tools/createTrackedDb.test.ts b/packages/server/src/tools/createTrackedDb.test.ts index a573028..a3908d6 100644 --- a/packages/server/src/tools/createTrackedDb.test.ts +++ b/packages/server/src/tools/createTrackedDb.test.ts @@ -124,6 +124,7 @@ describe("createTrackedDb", () => { socket: {} as WebSocket, listeningToTables: new Set(), }); + vi.spyOn(console, "warn").mockImplementation(() => {}); }); function createProxy(cascadeGraph?: Map>) { @@ -150,89 +151,6 @@ describe("createTrackedDb", () => { expect(() => (proxy as any).execute()).toThrow("ctx.db.execute"); }); - it("tracks insert as table write", async () => { - const { proxy } = createProxy(); - const usersTable = { name: "users" }; - - await (proxy as any).insert(usersTable).values({ name: "test" }); - - expect(tablesWritten.has("users")).toBe(true); - }); - - it("tracks update as table write", async () => { - const { proxy } = createProxy(); - const usersTable = { name: "users" }; - - await (proxy as any).update(usersTable).set({ name: "updated" }).where({ id: 1 }).run(); - - expect(tablesWritten.has("users")).toBe(true); - }); - - it("tracks delete as table write", async () => { - const { proxy } = createProxy(); - const usersTable = { name: "users" }; - - await (proxy as any).delete(usersTable).where({ id: 1 }).run(); - - expect(tablesWritten.has("users")).toBe(true); - }); - - it("propagates cascades on delete", async () => { - const cascadeGraph = new Map>(); - cascadeGraph.set("users", new Set(["posts", "comments"])); - - const { proxy } = createProxy(cascadeGraph); - const usersTable = { name: "users" }; - - await (proxy as any).delete(usersTable).where({ id: 1 }).run(); - - expect(tablesWritten.has("users")).toBe(true); - expect(tablesWritten.has("posts")).toBe(true); - expect(tablesWritten.has("comments")).toBe(true); - }); - - it("does not propagate cascades on insert", async () => { - const cascadeGraph = new Map>(); - cascadeGraph.set("users", new Set(["posts"])); - - const { proxy } = createProxy(cascadeGraph); - const usersTable = { name: "users" }; - - await (proxy as any).insert(usersTable).values({ name: "test" }); - - expect(tablesWritten.has("users")).toBe(true); - expect(tablesWritten.has("posts")).toBe(false); - }); - - it("does not propagate cascades on update", async () => { - const cascadeGraph = new Map>(); - cascadeGraph.set("users", new Set(["posts"])); - - const { proxy } = createProxy(cascadeGraph); - const usersTable = { name: "users" }; - - await (proxy as any).update(usersTable).set({ name: "updated" }).where({ id: 1 }).run(); - - expect(tablesWritten.has("users")).toBe(true); - expect(tablesWritten.has("posts")).toBe(false); - }); - - it("tracks select as table read via query.findMany", async () => { - const { proxy } = createProxy(); - - await (proxy as any).query.users.findMany(); - - expect(tablesRead.has("users")).toBe(true); - }); - - it("tracks select as table read via query.findFirst", async () => { - const { proxy } = createProxy(); - - await (proxy as any).query.users.findFirst(); - - expect(tablesRead.has("users")).toBe(true); - }); - it("registers listening tables on session via query.findMany", async () => { const { proxy } = createProxy(); @@ -292,4 +210,13 @@ describe("createTrackedDb", () => { const existingMethod = (proxy as any).select; expect(typeof existingMethod).toBe("function"); }); + + it("logs warning when realDb.$client is missing", () => { + const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}); + createProxy(); + expect(warnSpy).toHaveBeenCalledWith( + "[EdgePod] Unable to wire SQL tracking: realDb.$client is missing or invalid.", + ); + warnSpy.mockRestore(); + }); }); diff --git a/packages/server/src/tools/createTrackedDb.ts b/packages/server/src/tools/createTrackedDb.ts index 15e77ea..d088191 100644 --- a/packages/server/src/tools/createTrackedDb.ts +++ b/packages/server/src/tools/createTrackedDb.ts @@ -1,21 +1,15 @@ -import { getTableName } from "drizzle-orm"; import { RawDrizzleDb, EdgePodSessionMap } from "../types"; import { checkResultWarnings } from "./checkResultWarnings"; import { createSelectProxy } from "./createSelectProxy"; import { createMutationProxy } from "./createMutationProxy"; import { hashTableName } from "./hashTableName"; -import { recordMutationWithCascades } from "./recordMutation"; +import { createTrackedClient } from "./createTrackedClient"; import { createQueryProxy, type ProxyConfig } from "./createQueryProxy"; const FORBIDDEN_RAW_METHODS = ["run", "all", "get", "values", "execute"]; const MAX_LIMIT = 1000; -function createInsertProxy( - builder: Record, - maxLimit: number, - tableName: string, - tablesWritten: Set, -): unknown { +function createInsertProxy(builder: Record, maxLimit: number): unknown { const config: ProxyConfig = { onMethod: { values: (target, args, _state, factory) => { @@ -32,9 +26,6 @@ function createInsertProxy( if (prop === "prepare") { throw new Error("[EdgePod] .prepare() is not supported for inserts."); } - if (tableName !== "unknown") { - recordMutationWithCascades(tableName, tablesWritten, new Map()); - } return target[prop](...args); }, }; @@ -42,17 +33,12 @@ function createInsertProxy( return createQueryProxy(builder, {}, config); } -function createUpdateBuilderProxy( - builder: Record, - warnings: string[], - tableName: string, - tablesWritten: Set, -): unknown { +function createUpdateBuilderProxy(builder: Record, warnings: string[]): unknown { const config: ProxyConfig = { onMethod: { set: (target, args, _state, _factory) => { const base = target.set(...args); - return createMutationProxy(base, warnings, "update", tableName, tablesWritten); + return createMutationProxy(base, warnings, "update"); }, }, onExecute: (target, prop, args) => target[prop](...args), @@ -74,8 +60,27 @@ export function createTrackedDb>( cascadeGraph: Map>, warnings: string[], ): unknown { - return new Proxy(realDb as any, { - get(target: any, prop: string) { + // Wire in client-level SQL tracking if the db exposes its underlying storage + const client = (realDb as unknown as Record).$client; + if (!client || typeof client !== "object" || !("sql" in client)) { + console.warn("[EdgePod] Unable to wire SQL tracking: realDb.$client is missing or invalid."); + } else { + const trackedClient = createTrackedClient( + client as DurableObjectStorage, + tablesRead, + tablesWritten, + cascadeGraph, + ); + const session = (realDb as unknown as Record).session; + if (!session || typeof session !== "object") { + console.warn("[EdgePod] Unable to wire SQL tracking: realDb.session is missing."); + } else { + (session as Record).client = trackedClient; + } + } + + return new Proxy(realDb as unknown as Record, { + get(target: Record, prop: string) { if (FORBIDDEN_RAW_METHODS.includes(prop)) { throw new Error( `[EdgePod] Raw SQL via 'ctx.db.${prop}()' is blocked. Use ctx.db.select()/ctx.db.update(). ` + @@ -84,48 +89,37 @@ export function createTrackedDb>( } if (prop === "insert") { - return function (table: unknown, ...restArgs: unknown[]) { - const tableName = getTableName(table as any) ?? "unknown"; - const builder = target[prop].apply(target, [table, ...restArgs]); - return createInsertProxy(builder, MAX_LIMIT, tableName, tablesWritten); + return function (...args: unknown[]) { + const builder = (target[prop] as (...a: unknown[]) => unknown).apply(target, args); + return createInsertProxy(builder as Record, MAX_LIMIT); }; } if (prop === "update") { - return function (table: unknown, ...restArgs: unknown[]) { - const tableName = getTableName(table as any) ?? "unknown"; - const builder = target[prop].apply(target, [table, ...restArgs]); - return createUpdateBuilderProxy(builder, warnings, tableName, tablesWritten); + return function (...args: unknown[]) { + const builder = (target[prop] as (...a: unknown[]) => unknown).apply(target, args); + return createUpdateBuilderProxy(builder as Record, warnings); }; } if (prop === "delete") { - return function (table: unknown, ...restArgs: unknown[]) { - const tableName = getTableName(table as any) ?? "unknown"; - const builder = target[prop].apply(target, [table, ...restArgs]); - return createMutationProxy( - builder, - warnings, - "delete", - tableName, - tablesWritten, - cascadeGraph, - ); + return function (...args: unknown[]) { + const builder = (target[prop] as (...a: unknown[]) => unknown).apply(target, args); + return createMutationProxy(builder as Record, warnings, "delete"); }; } if (prop === "query") { const queryObject = target.query; if (!queryObject) return undefined; - return new Proxy(queryObject, { - get(queryTarget: any, tableProp: string) { + return new Proxy(queryObject as Record, { + get(queryTarget: Record, tableProp: string) { const tableApi = queryTarget[tableProp]; if (!tableApi) return undefined; const session = activeSessions.get(sessionId); if (session) session.listeningToTables.add(hashTableName(tableProp)); - tablesRead.add(tableProp); - return new Proxy(tableApi, { - get(tableTarget: any, method: string) { + return new Proxy(tableApi as Record, { + get(tableTarget: Record, method: string) { if (method === "findMany") { return function (opts: Record = {}) { const limit = @@ -135,8 +129,14 @@ export function createTrackedDb>( if (typeof opts.limit === "number" && opts.limit > MAX_LIMIT) { warnings.push(`Query limit of ${opts.limit} overridden to ${MAX_LIMIT}.`); } - trackWithRelations(opts, tablesRead, activeSessions, sessionId); - return tableTarget.findMany({ ...opts, limit }).then((result: unknown[]) => { + trackWithRelations(opts, activeSessions, sessionId); + const promise = ( + tableTarget.findMany as (...a: unknown[]) => Promise + )({ + ...opts, + limit, + }); + return promise.then((result: unknown[]) => { checkResultWarnings(result, warnings, MAX_LIMIT); return result; }); @@ -144,8 +144,8 @@ export function createTrackedDb>( } if (method === "findFirst") { return function (opts: Record = {}) { - trackWithRelations(opts, tablesRead, activeSessions, sessionId); - return tableTarget.findFirst(opts); + trackWithRelations(opts, activeSessions, sessionId); + return (tableTarget.findFirst as (...a: unknown[]) => unknown)(opts); }; } const value = tableTarget[method]; @@ -158,14 +158,8 @@ export function createTrackedDb>( if (prop === "select" || prop === "selectDistinct") { return function (...args: unknown[]) { - return createSelectProxy( - target[prop].apply(target, args), - sessionId, - activeSessions, - tablesRead, - warnings, - MAX_LIMIT, - ); + const builder = (target[prop] as (...a: unknown[]) => unknown).apply(target, args); + return createSelectProxy(builder as Record, warnings, MAX_LIMIT); }; } @@ -177,7 +171,6 @@ export function createTrackedDb>( function trackWithRelations( opts: Record, - tablesRead: Set, activeSessions: EdgePodSessionMap, sessionId: string, ) { @@ -186,12 +179,6 @@ function trackWithRelations( for (const relation of Object.keys(withOpt)) { const session = activeSessions.get(sessionId); if (session) session.listeningToTables.add(hashTableName(relation)); - tablesRead.add(relation); - trackWithRelations( - withOpt[relation] as Record, - tablesRead, - activeSessions, - sessionId, - ); + trackWithRelations(withOpt[relation] as Record, activeSessions, sessionId); } } diff --git a/packages/server/src/tools/parseSqlTracking.test.ts b/packages/server/src/tools/parseSqlTracking.test.ts new file mode 100644 index 0000000..5b97007 --- /dev/null +++ b/packages/server/src/tools/parseSqlTracking.test.ts @@ -0,0 +1,323 @@ +import { describe, it, expect } from "vitest"; +import { parseSqlTracking } from "./parseSqlTracking"; + +describe("parseSqlTracking — query type", () => { + it("detects SELECT", () => { + const r = parseSqlTracking('select "id" from "users"', []); + expect(r.queryType).toBe("select"); + }); + + it("detects INSERT", () => { + const r = parseSqlTracking('insert into "users" ("name") values (?)', ["test"]); + expect(r.queryType).toBe("insert"); + }); + + it("detects UPDATE", () => { + const r = parseSqlTracking('update "users" set "name" = ? where "id" = ?', ["new", 1]); + expect(r.queryType).toBe("update"); + }); + + it("detects DELETE", () => { + const r = parseSqlTracking('delete from "users" where "id" = ?', [1]); + expect(r.queryType).toBe("delete"); + }); + + it("returns 'unknown' for garbage SQL", () => { + const r = parseSqlTracking("not sql at all", []); + expect(r.queryType).toBe("unknown"); + expect(r.tablesRead).toEqual([]); + expect(r.tablesWritten).toEqual([]); + expect(r.whereIds).toEqual([]); + }); +}); + +describe("parseSqlTracking — table extraction", () => { + it("extracts table from SELECT", () => { + const r = parseSqlTracking('select * from "users"', []); + expect(r.tablesRead).toEqual(["users"]); + }); + + it("extracts multiple tables from JOIN", () => { + const r = parseSqlTracking( + 'select * from "users" left join "posts" on "posts"."user_id" = "users"."id"', + [], + ); + expect(r.tablesRead).toEqual(["users", "posts"]); + }); + + it("extracts tables from comma-joined FROM", () => { + const r = parseSqlTracking('select * from "users", "posts"', []); + expect(r.tablesRead).toEqual(["users", "posts"]); + }); + + it("extracts table from INSERT", () => { + const r = parseSqlTracking('insert into "users" ("name") values (?)', ["test"]); + expect(r.tablesWritten).toEqual(["users"]); + }); + + it("extracts table from UPDATE", () => { + const r = parseSqlTracking('update "users" set "name" = ?', ["test"]); + expect(r.tablesWritten).toEqual(["users"]); + }); + + it("extracts table from DELETE", () => { + const r = parseSqlTracking('delete from "users"', []); + expect(r.tablesWritten).toEqual(["users"]); + }); + + it("detects subquery tables in WHERE IN (SELECT)", () => { + const r = parseSqlTracking( + 'select * from "users" where "id" in (select "user_id" from "banned")', + [], + ); + expect(r.tablesRead).toEqual(["users", "banned"]); + }); + + it("detects subquery tables in WHERE EXISTS", () => { + const r = parseSqlTracking( + 'select * from "users" where exists (select 1 from "orders" where "orders"."user_id" = "users"."id")', + [], + ); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("orders"); + expect(r.tablesRead).toHaveLength(2); + }); + + it("detects subquery tables in INSERT...SELECT", () => { + const r = parseSqlTracking('insert into "logs" select * from "old_logs"', []); + expect(r.tablesWritten).toEqual(["logs"]); + expect(r.tablesRead).toEqual(["old_logs"]); + }); + + it("detects subquery tables in UPDATE with scalar subquery", () => { + const r = parseSqlTracking( + 'update "users" set "name" = (select "name" from "profiles" where "profiles"."id" = "users"."id")', + [], + ); + expect(r.tablesWritten).toEqual(["users"]); + expect(r.tablesRead).toEqual(["profiles"]); + }); + + it("detects subquery tables in DELETE with IN subquery", () => { + const r = parseSqlTracking( + 'delete from "users" where "id" in (select "id" from "banned_users")', + [], + ); + expect(r.tablesWritten).toEqual(["users"]); + expect(r.tablesRead).toEqual(["banned_users"]); + }); + + it("detects multiple subquery tables with JOIN in outer query", () => { + const r = parseSqlTracking( + 'select * from "users" join "posts" on "users"."id" = "posts"."user_id" where "posts"."id" in (select "post_id" from "comments")', + [], + ); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("posts"); + expect(r.tablesRead).toContain("comments"); + expect(r.tablesRead).toHaveLength(3); + }); + + it("does not track subquery tables as reads for write ops when same as main table", () => { + const r = parseSqlTracking('insert into "logs" select * from "logs"', []); + expect(r.tablesWritten).toEqual(["logs"]); + expect(r.tablesRead).toEqual([]); + }); +}); + +describe("parseSqlTracking — CTEs, compounds, subqueries", () => { + it("excludes CTE aliases from tablesRead", () => { + const r = parseSqlTracking("WITH cte AS (SELECT * FROM users) SELECT * FROM cte", []); + expect(r.tablesRead).toEqual(["users"]); + }); + + it("excludes CTE aliases from JOIN in main query", () => { + const r = parseSqlTracking( + "WITH a AS (SELECT * FROM users) SELECT * FROM a JOIN b ON a.id = b.id", + [], + ); + expect(r.tablesRead).toEqual(["users", "b"]); + }); + + it("collects tables from scalar subquery in SELECT list (no outer FROM)", () => { + const r = parseSqlTracking("SELECT (SELECT name FROM profiles LIMIT 1) AS n", []); + expect(r.tablesRead).toEqual(["profiles"]); + }); + + it("collects tables from scalar subquery in SELECT list (with outer FROM)", () => { + const r = parseSqlTracking("SELECT (SELECT count FROM stats) AS c FROM users", []); + expect(r.tablesRead).toEqual(["users", "stats"]); + }); + + it("collects all table references from UNION", () => { + const r = parseSqlTracking("SELECT * FROM users UNION SELECT * FROM admins", []); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("admins"); + expect(r.tablesRead).toHaveLength(2); + }); + + it("collects tables from subquery in HAVING", () => { + const r = parseSqlTracking( + "SELECT user_id, count(*) FROM orders GROUP BY user_id HAVING count(*) > (SELECT avg(count) FROM stats)", + [], + ); + expect(r.tablesRead).toContain("orders"); + expect(r.tablesRead).toContain("stats"); + }); + + it("collects tables from subquery in ORDER BY", () => { + const r = parseSqlTracking( + "SELECT * FROM users ORDER BY (SELECT score FROM leaderboard LIMIT 1)", + [], + ); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("leaderboard"); + }); + + it("collects tables from derived table in JOIN", () => { + const r = parseSqlTracking( + "SELECT * FROM users JOIN (SELECT * FROM posts) AS p ON users.id = p.user_id", + [], + ); + expect(r.tablesRead).toContain("users"); + expect(r.tablesRead).toContain("posts"); + }); + + it("collects tables from deeply nested subqueries", () => { + const r = parseSqlTracking( + "SELECT * FROM users WHERE id IN (SELECT id FROM (SELECT * FROM banned))", + [], + ); + expect(r.tablesRead).toEqual(["users", "banned"]); + }); + + it("collects subquery tables for INSERT...SELECT with compound", () => { + const r = parseSqlTracking( + "INSERT INTO logs SELECT * FROM old_logs UNION SELECT * FROM archived_logs", + [], + ); + expect(r.tablesWritten).toEqual(["logs"]); + expect(r.tablesRead).toContain("old_logs"); + expect(r.tablesRead).toContain("archived_logs"); + }); + + it("collects CTE tables for write ops with subquery", () => { + const r = parseSqlTracking( + "WITH cte AS (SELECT * FROM source) INSERT INTO target SELECT * FROM cte", + [], + ); + expect(r.tablesWritten).toEqual(["target"]); + expect(r.tablesRead).toEqual(["source"]); + }); +}); + +describe("parseSqlTracking — WHERE ID extraction", () => { + it("extracts simple WHERE id = ?", () => { + const r = parseSqlTracking('select * from "users" where "users"."id" = ?', [42]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].tableHint).toBe("users"); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[0].paramIndices).toEqual([0]); + }); + + it("extracts WHERE id = ? from UPDATE", () => { + const r = parseSqlTracking('update "users" set "name" = ? where "users"."id" = ?', ["new", 1]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].tableHint).toBe("users"); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[0].paramIndices).toEqual([1]); + }); + + it("extracts WHERE id IN (?, ?)", () => { + const r = parseSqlTracking('select * from "users" where "users"."id" in (?, ?)', [1, 2]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].tableHint).toBe("users"); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[0].paramIndices).toEqual([0, 1]); + }); + + it("handles unqualified column in WHERE", () => { + const r = parseSqlTracking('select * from "users" where "id" = ?', [42]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].tableHint).toBe(""); + expect(r.whereIds[0].column).toBe("id"); + }); + + it("returns empty whereIds for queries without WHERE", () => { + const r = parseSqlTracking('select * from "users"', []); + expect(r.whereIds).toEqual([]); + }); + + it("returns empty whereIds for INSERT", () => { + const r = parseSqlTracking('insert into "users" ("name") values (?)', ["test"]); + expect(r.whereIds).toEqual([]); + }); +}); + +describe("parseSqlTracking — raw SQL (uppercase)", () => { + it("parses uppercase SELECT", () => { + const r = parseSqlTracking("SELECT * FROM users WHERE users.id = ?", [1]); + expect(r.queryType).toBe("select"); + expect(r.tablesRead).toEqual(["users"]); + expect(r.whereIds).toHaveLength(1); + }); + + it("parses uppercase UPDATE", () => { + const r = parseSqlTracking("UPDATE users SET name = ? WHERE users.id = ?", ["new", 1]); + expect(r.queryType).toBe("update"); + expect(r.tablesWritten).toEqual(["users"]); + expect(r.whereIds).toHaveLength(1); + }); +}); + +describe("parseSqlTracking — WHERE scoping", () => { + it("extracts WHERE id = ?", () => { + const r = parseSqlTracking('select * from "users" where "users"."id" = ?', [42]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[0].paramIndices).toEqual([0]); + }); + + it("ignores JOIN ON condition", () => { + const r = parseSqlTracking( + 'select * from "users" left join "posts" on "posts"."user_id" = ?', + [1], + ); + expect(r.tablesRead).toEqual(["users", "posts"]); + expect(r.whereIds).toEqual([]); + }); + + it("ignores HAVING condition", () => { + const r = parseSqlTracking( + 'select user_id, count(*) from "orders" group by user_id having count(*) > ?', + [5], + ); + expect(r.whereIds).toEqual([]); + }); + + it("ignores params in SET clause (not WHERE)", () => { + const r = parseSqlTracking('update "users" set "name" = ? where "users"."id" = ?', ["new", 1]); + expect(r.whereIds).toHaveLength(1); + expect(r.whereIds[0].column).toBe("id"); + }); +}); + +describe("parseSqlTracking — edge cases", () => { + it("handles multiple AND conditions in WHERE", () => { + const r = parseSqlTracking( + 'select * from "users" where "users"."id" = ? and "users"."name" = ?', + [1, "test"], + ); + expect(r.whereIds).toHaveLength(2); + expect(r.whereIds[0].column).toBe("id"); + expect(r.whereIds[1].column).toBe("name"); + }); + + it("returns empty for parse errors", () => { + const r = parseSqlTracking("not valid sqlite", []); + expect(r.queryType).toBe("unknown"); + expect(r.tablesRead).toEqual([]); + expect(r.tablesWritten).toEqual([]); + expect(r.whereIds).toEqual([]); + }); +}); diff --git a/packages/server/src/tools/parseSqlTracking.ts b/packages/server/src/tools/parseSqlTracking.ts new file mode 100644 index 0000000..f5444a3 --- /dev/null +++ b/packages/server/src/tools/parseSqlTracking.ts @@ -0,0 +1,216 @@ +import { + type BinaryExpr, + type DeleteStmt, + type Expr, + type FromClause, + type Id, + type InListExpr, + type InsertStmt, + type JoinedSelectTable, + type Name, + type QualifiedExpr, + type QualifiedName, + type SelectTable, + type UpdateStmt, + type VariableExpr, + parseStmt, + traverse, +} from "sqlite3-parser"; + +export type ParsedQuery = { + queryType: "select" | "insert" | "update" | "delete" | "unknown"; + tablesRead: string[]; + tablesWritten: string[]; + whereIds: Array<{ tableHint: string; column: string; paramIndices: number[] }>; +}; + +function getNameText(n: Name): string { + return n.text; +} + +function getQualifiedName(qn: QualifiedName): string { + return getNameText(qn.objName); +} + +function getTableNameFromStmt(stmt: DeleteStmt | InsertStmt | UpdateStmt): string | null { + return getQualifiedName(stmt.tblName); +} + +function getTableNameFromSelectTable(st: SelectTable): string | null { + if (st.type === "TableSelectTable" || st.type === "TableCallSelectTable") { + return getQualifiedName(st.tblName); + } + return null; +} + +function getTableNameFromFromClause(from: FromClause): string | null { + const select = from.select; + if (!select) return null; + return getTableNameFromSelectTable(select); +} + +function getTableNameFromJoined(j: JoinedSelectTable): string | null { + return getTableNameFromSelectTable(j.table); +} + +export function parseSqlTracking(sql: string, params: unknown[]): ParsedQuery { + const result = parseStmt(sql); + if (result.status === "error") { + return { queryType: "unknown", tablesRead: [], tablesWritten: [], whereIds: [] }; + } + + const root = result.root; + let queryType: ParsedQuery["queryType"] = "unknown"; + let mainTable: string | null = null; + + if (root.type === "SelectStmt") { + queryType = "select"; + } else if (root.type === "InsertStmt") { + queryType = "insert"; + mainTable = getTableNameFromStmt(root); + } else if (root.type === "UpdateStmt") { + queryType = "update"; + mainTable = getTableNameFromStmt(root); + } else if (root.type === "DeleteStmt") { + queryType = "delete"; + mainTable = getTableNameFromStmt(root); + } + + // Collect all CTE aliases so we can exclude them from the table list + const cteAliases = new Set(); + traverse(root, { + enter(node) { + if (node.type === "CommonTableExpr") { + cteAliases.add(getNameText(node.tblName)); + } + }, + }); + + // Collect ALL table references from every SelectFrom node in the AST. + const allTables = new Set(); + traverse(root, { + enter(node) { + if (node.type !== "SelectFrom") return; + const from = node.from; + if (!from) return; + const name = getTableNameFromFromClause(from); + if (name && !cteAliases.has(name)) allTables.add(name); + if (from.joins) { + for (const join of from.joins) { + const joinName = getTableNameFromJoined(join); + if (joinName && !cteAliases.has(joinName)) allTables.add(joinName); + } + } + }, + }); + + const tablesRead: string[] = []; + const tablesWritten: string[] = []; + const whereIds: Array<{ tableHint: string; column: string; paramIndices: number[] }> = []; + + if (queryType === "select") { + tablesRead.push(...allTables); + } else if (queryType === "insert" || queryType === "update" || queryType === "delete") { + if (mainTable) tablesWritten.push(mainTable); + for (const t of allTables) { + if (t !== mainTable) tablesRead.push(t); + } + } + + // Map ? positions to param indices by collecting all VariableExpr order + const varExprs: Array<{ offset: number; node: VariableExpr }> = []; + traverse(root, { + enter(node) { + if (node.type === "VariableExpr" && node.name === "?") { + varExprs.push({ offset: node.span.offset, node }); + } + }, + }); + varExprs.sort((a, b) => a.offset - b.offset); + + const paramIndexForOffset = new Map(); + varExprs.forEach((ve, i) => { + paramIndexForOffset.set(ve.offset, i); + }); + + // Collect WHERE expressions from all WHERE clauses in the AST + const whereExprs: Expr[] = []; + traverse(root, { + enter(node) { + if ((node.type === "DeleteStmt" || node.type === "UpdateStmt") && node.whereClause) { + whereExprs.push(node.whereClause); + } + if (node.type === "SelectFrom" && node.whereClause) { + whereExprs.push(node.whereClause); + } + }, + }); + + // Extract row IDs only from WHERE-clause expressions (not JOIN ON / HAVING) + for (const whereExpr of whereExprs) { + traverse(whereExpr, { + enter(node) { + // "id = ?" pattern + if (node.type === "BinaryExpr" && node.op === "Equals") { + const be = node as BinaryExpr; + const columnName = extractColumnName(be.left); + if (!columnName) return; + const paramOffset = extractParamOffset(be.right); + if (paramOffset === -1) return; + const pIdx = paramIndexForOffset.get(paramOffset); + if (pIdx !== undefined && pIdx < params.length) { + const tableHint = extractTableHint(be.left); + whereIds.push({ tableHint, column: columnName, paramIndices: [pIdx] }); + } + } + + // "id IN (?, ?)" pattern + if (node.type === "InListExpr") { + const ie = node as InListExpr; + if (!ie.rhs) return; + const columnName = extractColumnName(ie.lhs); + if (!columnName) return; + const indices: number[] = []; + for (const item of ie.rhs) { + const paramOffset = extractParamOffset(item); + if (paramOffset === -1) continue; + const pIdx = paramIndexForOffset.get(paramOffset); + if (pIdx !== undefined && pIdx < params.length) { + indices.push(pIdx); + } + } + if (indices.length > 0) { + const tableHint = extractTableHint(ie.lhs); + whereIds.push({ tableHint, column: columnName, paramIndices: indices }); + } + } + }, + }); + } + + return { queryType, tablesRead, tablesWritten, whereIds }; +} + +function extractColumnName(node: Expr | null): string | null { + if (!node) return null; + if (node.type === "Id") return (node as Id).name; + if (node.type === "QualifiedExpr") { + const qe = node as QualifiedExpr; + return getNameText(qe.column); + } + return null; +} + +function extractTableHint(node: Expr | null): string { + if (!node) return ""; + if (node.type === "QualifiedExpr") { + const qe = node as QualifiedExpr; + return getNameText(qe.table); + } + return ""; +} + +function extractParamOffset(node: Expr | null): number { + if (!node || node.type !== "VariableExpr") return -1; + return (node as VariableExpr).span.offset; +} diff --git a/packages/server/src/tools/proxy.integration.test.ts b/packages/server/src/tools/proxy.integration.test.ts index f7be44d..92c4cc9 100644 --- a/packages/server/src/tools/proxy.integration.test.ts +++ b/packages/server/src/tools/proxy.integration.test.ts @@ -1,10 +1,9 @@ import { describe, it, expect, beforeEach } from "vitest"; -import Database from "better-sqlite3"; -import { drizzle } from "drizzle-orm/better-sqlite3"; import { sqliteTable, integer, text } from "drizzle-orm/sqlite-core"; import { eq } from "drizzle-orm"; import { createTrackedDb } from "./createTrackedDb"; -import type { RawDrizzleDb, EdgePodSessionMap } from "../types"; +import { createTestDb } from "../test-utils/createTestDb"; +import type { EdgePodSessionMap } from "../types"; const users = sqliteTable("users", { id: integer("id").primaryKey(), @@ -18,8 +17,7 @@ const posts = sqliteTable("posts", { }); function setup() { - const sqlite = new Database(":memory:"); - const db = drizzle({ client: sqlite, schema: { users, posts } }); + const { db, sqlite } = createTestDb({ users, posts }); sqlite.exec(` CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL); CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT NOT NULL, user_id INTEGER NOT NULL); @@ -34,7 +32,7 @@ function setup() { }); const trackedDb = createTrackedDb( - db as unknown as RawDrizzleDb, + db, "test-session", activeSessions, tablesRead, @@ -43,7 +41,7 @@ function setup() { warnings, ); - return { db: trackedDb as any, tablesRead, tablesWritten, warnings }; + return { db: trackedDb as any, rawDb: db, tablesRead, tablesWritten, warnings }; } describe("proxy integration — limit enforcement", () => { @@ -86,8 +84,7 @@ describe("proxy integration — WHERE enforcement", () => { it("allows update with WHERE", async () => { const { db } = setup(); - const result = await db.update(users).set({ name: "changed" }).where(eq(users.id, 1)).run(); - expect(result).toBeDefined(); + await db.update(users).set({ name: "changed" }).where(eq(users.id, 1)).run(); }); it("blocks delete without WHERE", () => { @@ -97,12 +94,11 @@ describe("proxy integration — WHERE enforcement", () => { it("allows delete with WHERE", async () => { const { db } = setup(); - const result = await db.delete(users).where(eq(users.id, 1)).run(); - expect(result).toBeDefined(); + await db.delete(users).where(eq(users.id, 1)).run(); }); }); -describe("proxy integration — table tracking", () => { +describe("proxy integration — table tracking via client proxy", () => { it("tracks insert as table write (async)", async () => { const { db, tablesWritten } = setup(); await db.insert(users).values({ name: "test" }); @@ -158,7 +154,7 @@ describe("proxy integration — insert chaining", () => { it("insert bulk at max limit succeeds", async () => { const { db } = setup(); const rows = Array(1000).fill({ name: "test" }); - await expect(db.insert(users).values(rows)).resolves.toBeDefined(); + await db.insert(users).values(rows); }); it("insert bulk over max limit throws", () => { @@ -199,3 +195,40 @@ describe("proxy integration — prepare", () => { expect(Array.isArray(result)).toBe(true); }); }); + +describe("proxy integration — unsafeRawDb tracking", () => { + it("tracks raw SELECT on unsafeRawDb", () => { + const { rawDb, tablesRead } = setup(); + rawDb.select().from(users).all(); + expect(tablesRead.has("users")).toBe(true); + }); + + it("tracks raw INSERT on unsafeRawDb", () => { + const { rawDb, tablesWritten } = setup(); + rawDb.insert(users).values({ name: "test" }).run(); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks raw UPDATE on unsafeRawDb", () => { + const { rawDb, tablesWritten } = setup(); + rawDb.update(users).set({ name: "changed" }).where(eq(users.id, 1)).run(); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("tracks raw DELETE on unsafeRawDb", () => { + const { rawDb, tablesWritten } = setup(); + rawDb.delete(users).where(eq(users.id, 1)).run(); + expect(tablesWritten.has("users")).toBe(true); + }); + + it("unsafeRawDb bypasses safety enforcement (no WHERE block)", () => { + const { rawDb } = setup(); + expect(() => rawDb.delete(users).run()).not.toThrow(); + }); + + it("unsafeRawDb bypasses safety enforcement (no limit clamp)", () => { + const { rawDb, warnings } = setup(); + rawDb.select().from(users).limit(5000).all(); + expect(warnings).toHaveLength(0); + }); +}); diff --git a/packages/server/src/tools/recordMutation.ts b/packages/server/src/tools/recordMutation.ts deleted file mode 100644 index fbb7cdc..0000000 --- a/packages/server/src/tools/recordMutation.ts +++ /dev/null @@ -1,14 +0,0 @@ -export function recordMutationWithCascades( - tableName: string, - tablesWritten: Set, - cascadeGraph: Map>, -) { - if (tablesWritten.has(tableName)) return; - tablesWritten.add(tableName); - const children = cascadeGraph.get(tableName); - if (children) { - for (const child of children) { - recordMutationWithCascades(child, tablesWritten, cascadeGraph); - } - } -} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 95a34e7..574c04e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -112,6 +112,9 @@ importers: neverthrow: specifier: ^8.2.0 version: 8.2.0 + sqlite3-parser: + specifier: ^0.7.1 + version: 0.7.1 devDependencies: '@types/better-sqlite3': specifier: ^7.6.13 @@ -3149,6 +3152,10 @@ packages: spdx-license-ids@3.0.23: resolution: {integrity: sha512-CWLcCCH7VLu13TgOH+r8p1O/Znwhqv/dbb6lqWy67G+pT1kHmeD/+V36AVb/vq8QMIQwVShJ6Ssl5FPh0fuSdw==} + sqlite3-parser@0.7.1: + resolution: {integrity: sha512-+KcAIcmD9xk4Sz3hNsEI7QEGAMGl3s5PyigIf6ri/u1DzAuAmi5YZInjtnXKqHvx9ySl5e0RYILfCVhTdAeKJg==} + hasBin: true + sqlstring@2.3.3: resolution: {integrity: sha512-qC9iz2FlN7DQl3+wjwn3802RTyjCx7sDvfQEXchwa6CWOx07/WVfh91gBmQ9fahw8snwGEWU3xGzOt4tFyHLxg==} engines: {node: '>= 0.6'} @@ -6099,6 +6106,8 @@ snapshots: spdx-license-ids@3.0.23: {} + sqlite3-parser@0.7.1: {} + sqlstring@2.3.3: optional: true