Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions docs/peewee_async/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,30 @@ Databases

.. automethod:: peewee_async.databases.AioDatabase.aio_transaction

.. autoclass:: peewee_async.PsycopgDatabase
:members: init
.. automethod:: peewee_async.databases.AioDatabase.aio_create_tables

.. automethod:: peewee_async.databases.AioDatabase.aio_drop_tables

.. autoclass:: peewee_async.PooledPostgresqlDatabase
.. autoclass:: peewee_async.Psycopg3Database
:members: init

.. autoclass:: peewee_async.PooledPostgresqlExtDatabase
.. autoclass:: peewee_async.PostgresqlDatabase
:members: init

.. autoclass:: peewee_async.PooledMySQLDatabase
.. autoclass:: peewee_async.MySQLDatabase
:members: init

AioModel
++++++++++

.. autoclass:: peewee_async.AioModel

.. automethod:: peewee_async.AioModel.aio_create_table

.. automethod:: peewee_async.AioModel.aio_drop_table

.. automethod:: peewee_async.AioModel.aio_truncate_table

.. automethod:: peewee_async.AioModel.aio_get

.. automethod:: peewee_async.AioModel.aio_get_or_none
Expand Down
116 changes: 109 additions & 7 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,72 @@
from .utils import CursorProtocol


class AioSchemaManager(peewee.SchemaManager):
async def aio_create_table(self, safe: bool = True, **options: Any) -> None:
await self.database.aio_execute(self._create_table(safe=safe, **options))

async def aio_drop_table(self, safe: bool = True, **options: Any) -> None:
await self.database.aio_execute(self._drop_table(safe=safe, **options))

async def aio_truncate_table(self, restart_identity: bool = False, cascade: bool = False) -> None:
await self.database.aio_execute(self._truncate_table(restart_identity, cascade))

async def aio_create_indexes(self, safe: bool = True) -> None:
for query in self._create_indexes(safe=safe):
await self.database.aio_execute(query)

async def aio_drop_indexes(self, safe: bool = True) -> None:
for query in self._drop_indexes(safe=safe):
await self.database.aio_execute(query)

async def _aio_create_sequence(self, field: peewee.Field) -> Any:
self._check_sequences(field)
if not await self.database.aio_sequence_exists(field.sequence):
return self._create_context().literal("CREATE SEQUENCE ").sql(self._sequence_for_field(field))

async def aio_create_sequence(self, field: peewee.Field) -> None:
seq_ctx = await self._aio_create_sequence(field)
if seq_ctx is not None:
await self.database.aio_execute(seq_ctx)

async def aio_create_sequences(self) -> None:
if self.database.sequences:
for field in self.model._meta.sorted_fields:
if field.sequence:
await self.aio_create_sequence(field)

async def _aio_drop_sequence(self, field: peewee.Field) -> Any:
self._check_sequences(field)
if await self.database.aio_sequence_exists(field.sequence):
return self._create_context().literal("DROP SEQUENCE ").sql(self._sequence_for_field(field))

async def aio_drop_sequence(self, field: peewee.Field) -> None:
seq_ctx = await self._aio_drop_sequence(field)
if seq_ctx is not None:
self.database.aio_execute(seq_ctx)

async def aio_drop_sequences(self) -> None:
if self.database.sequences:
for field in self.model._meta.sorted_fields:
if field.sequence:
await self.aio_drop_sequence(field)

async def aio_create_all(self, safe: bool = True, **table_options: Any) -> None:
await self.aio_create_sequences()
await self.aio_create_table(safe, **table_options)
await self.aio_create_indexes(safe=safe)

async def aio_drop_all(self, safe: bool = True, drop_sequences: bool = True, **options: Any) -> None:
await self.aio_drop_table(safe, **options)
if drop_sequences:
await self.aio_drop_sequences()


async def aio_prefetch(sq: Any, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any:
"""Asynchronous version of `prefetch()`.

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#prefetch
http://docs.peewee-orm.com/en/4.0.0/peewee/api.html#prefetch
"""
if not subqueries:
return sq
Expand Down Expand Up @@ -228,6 +289,47 @@ class User(peewee_async.AioModel):
user = await User.aio_get(User.username == 'user')
"""

class Meta:
schema_manager_class = AioSchemaManager

@classmethod
async def aio_table_exists(cls) -> bool:
M = cls._meta
return cast("bool", await cls._schema.database.aio_table_exists(M.table.__name__, M.schema))

@classmethod
async def aio_create_table(cls, safe: bool = True, **options: Any) -> None:
"""
Async version of **peewee.Model.create_table**
https://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.create_table
"""

if safe and not cls._schema.database.safe_create_index and await cls.aio_table_exists():
return
if cls._meta.temporary:
options.setdefault("temporary", cls._meta.temporary)
await cls._schema.aio_create_all(safe, **options)

@classmethod
async def aio_drop_table(cls, safe: bool = True, drop_sequences: bool = True, **options: Any) -> None:
"""
Async version of **peewee.Model.drop_table**
https://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.drop_table
"""
if safe and not cls._schema.database.safe_drop_index and not await cls.aio_table_exists():
return
if cls._meta.temporary:
options.setdefault("temporary", cls._meta.temporary)
await cls._schema.aio_drop_all(safe, drop_sequences, **options)

@classmethod
async def aio_truncate_table(cls, **options: Any) -> None:
"""
Async version of **peewee.Model.truncate_table**
https://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.truncate_table
"""
await cls._schema.aio_truncate_table(**options)

@classmethod
def select(cls, *fields: Any) -> AioModelSelect:
is_default = not fields
Expand Down Expand Up @@ -265,7 +367,7 @@ async def aio_delete_instance(self, recursive: bool = False, delete_nullable: bo
Async version of **peewee.Model.delete_instance**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#Model.delete_instance
http://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.delete_instance
"""
if recursive:
dependencies = self.dependencies(delete_nullable)
Expand All @@ -282,7 +384,7 @@ async def aio_save(self, force_insert: bool = False, only: Any = None) -> int |
Async version of **peewee.Model.save**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#Model.save
http://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.save
"""
field_dict = self.__data__.copy()
if self._meta.primary_key is not False:
Expand Down Expand Up @@ -330,7 +432,7 @@ async def aio_get(cls, *query: Any, **filters: Any) -> Self:
"""Async version of **peewee.Model.get**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#Model.get
http://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.get
"""
sq = cls.select()
if query:
Expand All @@ -348,7 +450,7 @@ async def aio_get_or_none(cls, *query: Any, **filters: Any) -> Self | None:
Async version of **peewee.Model.get_or_none**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#Model.get_or_none
http://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.get_or_none
"""
try:
return await cls.aio_get(*query, **filters)
Expand All @@ -361,7 +463,7 @@ async def aio_create(cls, **query: Any) -> Self:
Async version of **peewee.Model.create**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#Model.create
http://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.create
"""
inst = cls(**query)
await inst.aio_save(force_insert=True)
Expand All @@ -373,7 +475,7 @@ async def aio_get_or_create(cls, **kwargs: Any) -> tuple[Self, bool]:
Async version of **peewee.Model.get_or_create**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#Model.get_or_create
http://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Model.get_or_create
"""
defaults = kwargs.pop("defaults", {})
query = cls.select()
Expand Down
84 changes: 80 additions & 4 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import contextlib
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
from typing import Any

Expand All @@ -16,6 +18,22 @@
FetchResults = Callable[["AioDatabase", CursorProtocol], Awaitable[Any]]


def fetchmany(count: int | None) -> FetchResults:

async def _fetch_results(db: AioDatabase, cursor: CursorProtocol) -> Sequence[Any]:
if count == 1:
return await cursor.fetchone()
if count is not None:
return await cursor.fetchmany(count)
return await cursor.fetchall()

return _fetch_results


fetchone = fetchmany(1)
fetchall = fetchmany(None)


class AioDatabase(peewee.Database):
"""Base async database driver providing **single drop-in sync**
connection and **async connections pool** interface.
Expand Down Expand Up @@ -158,7 +176,7 @@ def aio_connection(self) -> ConnectionContextManager:
return ConnectionContextManager(self.pool_backend)

async def aio_execute_sql(
self, sql: str, params: list[Any] | None = None, fetch_results: FetchResults | None = None
self, sql: str, params: Sequence[Any] | None = None, fetch_results: FetchResults | None = None
) -> Any:
__log__.debug((sql, params))
with peewee.__exception_wrapper__:
Expand Down Expand Up @@ -188,6 +206,35 @@ async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert)
async def aio_rows_affected(self, cursor: CursorProtocol) -> int:
return cursor.rowcount

async def aio_sequence_exists(self, seq: str) -> bool:
raise NotImplementedError

async def aio_get_tables(self, schema: str | None = None) -> list[str]:
raise NotImplementedError

async def aio_create_tables(self, models: list[Any], **options: Any) -> None:
"""
Async version of **peewee.Database.create_tables**
https://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Database.create_tables
"""
for model in peewee.sort_models(models):
await model.aio_create_table(**options)

async def aio_drop_tables(self, models: list[Any], **kwargs: Any) -> None:
"""
Async version of **peewee.Database.drop_tables**
https://docs.peewee-orm.com/en/4.0.0/peewee/api.html#Database.drop_tables
"""
for model in reversed(peewee.sort_models(models)):
await model.aio_drop_table(**kwargs)

async def aio_table_exists(self, table_name: Any, schema: str | None = None) -> bool:
if peewee.is_model(table_name):
model = table_name
table_name = model._meta.table_name
schema = model._meta.schema
return table_name in await self.aio_get_tables(schema=schema)


class AioPostgresDatabase(AioDatabase):
async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert) -> Any:
Expand All @@ -198,6 +245,24 @@ async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert)
return None
return await fetch_models(cursor, query)

async def aio_sequence_exists(self, sequence: str) -> bool:
res = await self.aio_execute_sql(
"""
SELECT COUNT(*) FROM pg_class, pg_namespace
WHERE relkind='S'
AND pg_class.relnamespace = pg_namespace.oid
AND relname=%s""",
[
sequence,
],
fetch_results=fetchone,
)
return bool(res[0])

async def aio_get_tables(self, schema: str | None = None) -> list[str]:
query = "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname = %s ORDER BY tablename"
return [row for (row,) in await self.aio_execute_sql(query, (schema or "public",), fetch_results=fetchall)]


class Psycopg3Database(AioPostgresDatabase, ext.Psycopg3Database):
"""Extension for `playhouse.Psycopg3Database` providing extra methods
Expand All @@ -219,6 +284,7 @@ class Psycopg3Database(AioPostgresDatabase, ext.Psycopg3Database):
)

See also:
https://docs.peewee-orm.com/en/4.0.0/peewee/api.html#PostgresqlDatabase
https://www.psycopg.org/psycopg3/docs/advanced/pool.html
"""

Expand Down Expand Up @@ -247,7 +313,8 @@ class PostgresqlDatabase(AioPostgresDatabase, ext.PostgresqlExtDatabase):
)

See also:
https://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
https://docs.peewee-orm.com/en/4.0.0/peewee/api.html#PostgresqlDatabase
https://aiopg.readthedocs.io/en/stable/
"""

pool_backend_cls = PostgresqlPoolBackend
Expand Down Expand Up @@ -277,10 +344,19 @@ class MySQLDatabase(AioDatabase, peewee.MySQLDatabase):
)

See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
https://docs.peewee-orm.com/en/4.0.0/peewee/api.html#MySQLDatabase
https://aiomysql.readthedocs.io/en/stable/
"""

pool_backend_cls = MysqlPoolBackend

def init_pool_params_defaults(self) -> None:
self.pool_params.update({"autocommit": True})

async def aio_get_tables(self, schema: str | None = None) -> list[str]:
query = (
"SELECT table_name FROM information_schema.tables "
"WHERE table_schema = DATABASE() AND table_type != %s "
"ORDER BY table_name"
)
return [row for (row,) in await self.aio_execute_sql(query, ("VIEW",), fetch_results=fetchall)]
2 changes: 1 addition & 1 deletion peewee_async/result_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class SyncCursorAdapter:
def __init__(self, rows: list[Any], description: Sequence[Any] | None) -> None:
def __init__(self, rows: Sequence[Any], description: Sequence[Any] | None) -> None:
self._rows = rows
self.description = description
self._idx = 0
Expand Down
6 changes: 3 additions & 3 deletions peewee_async/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@


class CursorProtocol(Protocol):
async def fetchone(self) -> Any: ...
async def fetchone(self) -> Sequence[Any]: ...

async def fetchall(self) -> list[Any]: ...
async def fetchall(self) -> Sequence[Any]: ...

async def fetchmany(self, size: int) -> list[Any]: ...
async def fetchmany(self, size: int) -> Sequence[Any]: ...

@property
def lastrowid(self) -> int: ...
Expand Down
Loading
Loading