From 1261c3a9a7ab3ffda8ba2acb97bafa68d1429267 Mon Sep 17 00:00:00 2001 From: kalombo Date: Sun, 15 Mar 2026 12:15:12 +0500 Subject: [PATCH 1/2] draft: add shortcuts for transaction --- peewee_async/databases.py | 15 +++++++++++++++ tests/test_transaction.py | 13 +++++-------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/peewee_async/databases.py b/peewee_async/databases.py index e1e2b98..9db8ce1 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -101,6 +101,21 @@ async def aio_close(self) -> None: await self.pool_backend.close() + async def _aio_begin(self, use_savepoint: bool = False) -> Transaction: + _connection_context = connection_context.get() + if _connection_context is None: + raise Exception("This method can only be called within the aio_connection context manager") + tr = Transaction(_connection_context.connection, is_savepoint=use_savepoint) + await tr.begin() + return tr + + async def aio_begin(self) -> Transaction: + return await self._aio_begin() + + async def aio_savepoint(self) -> Transaction: + return await self._aio_begin(use_savepoint=True) + + def aio_atomic(self) -> AbstractAsyncContextManager[None]: """Create an async context-manager which runs any queries in the wrapped block in a transaction (or save-point if blocks are nested). diff --git a/tests/test_transaction.py b/tests/test_transaction.py index a4e5557..6be5a77 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -110,9 +110,8 @@ async def t3() -> None: @dbs_all async def test_transaction_manual_work(db: AioDatabase) -> None: - async with db.aio_connection() as connection: - tr = Transaction(connection) - await tr.begin() + async with db.aio_connection(): + tr = await db.aio_begin() await TestModel.aio_create(text="FOO") assert await TestModel.aio_get_or_none(text="FOO") is not None try: @@ -180,14 +179,12 @@ async def test_savepoint_rollback(db: AioDatabase) -> None: @dbs_all async def test_savepoint_manual_work(db: AioDatabase) -> None: - async with db.aio_connection() as connection: - tr = Transaction(connection) - await tr.begin() + async with db.aio_connection(): + tr = await db.aio_begin() await TestModel.aio_create(text="FOO") assert await TestModel.aio_get_or_none(text="FOO") is not None - savepoint = Transaction(connection, is_savepoint=True) - await savepoint.begin() + savepoint = await db.aio_savepoint() try: await TestModel.aio_create(text="FOO") except: # noqa: E722 From 13f1d225cf6976c3fbb844d84a1eb9f20fe2aa78 Mon Sep 17 00:00:00 2001 From: kalombo Date: Sun, 15 Mar 2026 12:18:11 +0500 Subject: [PATCH 2/2] fix format --- peewee_async/databases.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/peewee_async/databases.py b/peewee_async/databases.py index 9db8ce1..1f41094 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -108,13 +108,12 @@ async def _aio_begin(self, use_savepoint: bool = False) -> Transaction: tr = Transaction(_connection_context.connection, is_savepoint=use_savepoint) await tr.begin() return tr - + async def aio_begin(self) -> Transaction: return await self._aio_begin() - + async def aio_savepoint(self) -> Transaction: return await self._aio_begin(use_savepoint=True) - def aio_atomic(self) -> AbstractAsyncContextManager[None]: """Create an async context-manager which runs any queries in the wrapped block