diff --git a/db/database.py b/db/database.py index af3f39a..a8bbf3a 100644 --- a/db/database.py +++ b/db/database.py @@ -387,6 +387,38 @@ def _ensure_agenda_isolation_schema() -> None: "CREATE INDEX IF NOT EXISTS idx_deep_insights_agenda ON deep_insights(agenda_id)", best_effort_if_locked=_use_pg(), ) + # agenda_token_ledger lives in schema_agenda*.sql for fresh databases, but + # the best-effort schema-file replay can be skipped on existing PG + # databases (an earlier failed statement aborts the transaction), so + # create it explicitly here for pre-existing DBs. + if _use_pg(): + ledger_ddl = """ + CREATE TABLE IF NOT EXISTS agenda_token_ledger ( + id BIGSERIAL PRIMARY KEY, + agenda_id INTEGER NOT NULL, + operation TEXT NOT NULL, + tokens INTEGER NOT NULL DEFAULT 0, + cost_usd REAL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + )""" + else: + ledger_ddl = """ + CREATE TABLE IF NOT EXISTS agenda_token_ledger ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agenda_id INTEGER NOT NULL, + operation TEXT NOT NULL, + tokens INTEGER NOT NULL DEFAULT 0, + cost_usd REAL, + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (agenda_id) REFERENCES research_agendas(id) + )""" + _execute_startup_statement(conn, ledger_ddl, best_effort_if_locked=_use_pg()) + _execute_startup_statement( + conn, + "CREATE INDEX IF NOT EXISTS idx_agenda_token_ledger_agenda" + " ON agenda_token_ledger(agenda_id, created_at DESC)", + best_effort_if_locked=_use_pg(), + ) conn.commit() @@ -633,7 +665,13 @@ def init_db(): try: get_conn().execute(s) except Exception: - pass + # On Postgres a failed statement aborts + # the whole transaction; roll back so the + # remaining statements still execute. + try: + get_conn().rollback() + except Exception: + pass get_conn().commit() except Exception: pass @@ -656,7 +694,13 @@ def init_db(): try: get_conn().execute(s) except Exception: - pass + # On Postgres a failed statement aborts + # the whole transaction; roll back so the + # remaining statements still execute. + try: + get_conn().rollback() + except Exception: + pass get_conn().commit() except Exception: pass @@ -677,7 +721,13 @@ def init_db(): try: get_conn().execute(s) except Exception: - pass + # On Postgres a failed statement aborts + # the whole transaction; roll back so the + # remaining statements still execute. + try: + get_conn().rollback() + except Exception: + pass get_conn().commit() except Exception: pass @@ -694,7 +744,13 @@ def init_db(): try: get_conn().execute(s) except Exception: - pass + # On Postgres a failed statement aborts + # the whole transaction; roll back so the + # remaining statements still execute. + try: + get_conn().rollback() + except Exception: + pass get_conn().commit() except Exception: pass diff --git a/tests/test_agenda_budget.py b/tests/test_agenda_budget.py index 4a6ddad..2346ef8 100644 --- a/tests/test_agenda_budget.py +++ b/tests/test_agenda_budget.py @@ -344,3 +344,37 @@ def test_agenda_insights_endpoint_isolated(self): if __name__ == "__main__": unittest.main() + + +class LedgerMigrationTest(BudgetTestBase): + """Existing DBs created before agenda_token_ledger must get the table + from _ensure_agenda_isolation_schema, not only from the schema files + (on Postgres the best-effort schema replay can be skipped when an + earlier statement aborts the transaction).""" + + def test_ensure_schema_recreates_ledger_table(self): + db = self.db + db.execute("DROP TABLE agenda_token_ledger") + db.get_conn().commit() + rows = db.fetchall( + "SELECT name FROM sqlite_master WHERE type='table' AND name='agenda_token_ledger'" + ) + self.assertEqual(rows, []) + + db._ensure_agenda_isolation_schema() + + rows = db.fetchall( + "SELECT name FROM sqlite_master WHERE type='table' AND name='agenda_token_ledger'" + ) + self.assertEqual(len(rows), 1) + + from agents import agenda_loader, agenda_budget + + agenda = agenda_loader.parse_agenda(dict(SAMPLE_AGENDA)) + agenda_id = agenda_loader.save_agenda(agenda) + agenda_budget.record_usage(agenda_id, "test_op", tokens=42) + total = db.fetchone( + "SELECT SUM(tokens) AS t FROM agenda_token_ledger WHERE agenda_id = ?", + (agenda_id,), + ) + self.assertEqual(total["t"], 42)