Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ async def aio_close(self) -> None:

await self.pool_backend.close()

async def _aio_begin(self, use_savepoint: bool = False) -> Transaction:
_connection_context = connection_context.get()
if _connection_context is None:
raise Exception("This method can only be called within the aio_connection context manager")
tr = Transaction(_connection_context.connection, is_savepoint=use_savepoint)
await tr.begin()
return tr

async def aio_begin(self) -> Transaction:
return await self._aio_begin()

async def aio_savepoint(self) -> Transaction:
return await self._aio_begin(use_savepoint=True)

def aio_atomic(self) -> AbstractAsyncContextManager[None]:
"""Create an async context-manager which runs any queries in the wrapped block
in a transaction (or save-point if blocks are nested).
Expand Down
13 changes: 5 additions & 8 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ async def t3() -> None:

@dbs_all
async def test_transaction_manual_work(db: AioDatabase) -> None:
async with db.aio_connection() as connection:
tr = Transaction(connection)
await tr.begin()
async with db.aio_connection():
tr = await db.aio_begin()
await TestModel.aio_create(text="FOO")
assert await TestModel.aio_get_or_none(text="FOO") is not None
try:
Expand Down Expand Up @@ -180,14 +179,12 @@ async def test_savepoint_rollback(db: AioDatabase) -> None:

@dbs_all
async def test_savepoint_manual_work(db: AioDatabase) -> None:
async with db.aio_connection() as connection:
tr = Transaction(connection)
await tr.begin()
async with db.aio_connection():
tr = await db.aio_begin()
await TestModel.aio_create(text="FOO")
assert await TestModel.aio_get_or_none(text="FOO") is not None

savepoint = Transaction(connection, is_savepoint=True)
await savepoint.begin()
savepoint = await db.aio_savepoint()
try:
await TestModel.aio_create(text="FOO")
except: # noqa: E722
Expand Down
Loading