Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,34 +50,33 @@ class AioQueryMixin:
async def aio_execute(self, database: AioDatabase) -> Any:
return await database.aio_execute(self)

async def fetch_results(self, cursor: CursorProtocol) -> Any:
async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> Any:
return await fetch_models(cursor, self)


class AioModelDelete(peewee.ModelDelete, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol) -> list[Any] | int:
if self._returning:
class _AioWriteQueryMixin(AioQueryMixin):
async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> Any:
if self._return_cursor: # type: ignore
return await fetch_models(cursor, self)
return cursor.rowcount
return await database.aio_rows_affected(cursor)


class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol) -> list[Any] | int:
if self._returning:
return await fetch_models(cursor, self)
return cursor.rowcount
class AioModelDelete(peewee.ModelDelete, _AioWriteQueryMixin): ...


class AioModelUpdate(peewee.ModelUpdate, _AioWriteQueryMixin): ...


class AioModelInsert(peewee.ModelInsert, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol) -> list[Any] | Any | int:
if self._returning is not None and len(self._returning) > 1:
async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> list[Any] | Any | int:
if self._returning is None and database.returning_clause and self.table._primary_key: # type: ignore
self._returning = (self.table._primary_key,)
return await database.aio_last_insert_id(cursor, self)
if self._return_cursor:
return await fetch_models(cursor, self)

if self._returning:
row = await cursor.fetchone()
return row[0] if row else None
else:
return cursor.lastrowid
if self._as_rowcount:
return await database.aio_rows_affected(cursor)
return await database.aio_last_insert_id(cursor, self)


class AioModelRaw(peewee.ModelRaw, AioQueryMixin):
Expand All @@ -92,7 +91,7 @@ async def aio_peek(self, database: AioDatabase, n: int = 1) -> Any:
`peewee.SelectBase.peek <https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.peek>`_
"""

async def fetch_results(cursor: CursorProtocol) -> Any:
async def fetch_results(database: AioDatabase, cursor: CursorProtocol) -> Any:
return await fetch_models(cursor, self, n)

rows = await database.aio_execute(self, fetch_results=fetch_results)
Expand Down
45 changes: 25 additions & 20 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import contextlib
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
from contextlib import AbstractAsyncContextManager
from typing import Any

import peewee
from playhouse import postgres_ext as ext

from peewee_async.result_wrappers import fetch_models

from .connection import ConnectionContextManager, connection_context
from .pool import MysqlPoolBackend, PoolBackend, PostgresqlPoolBackend, PsycopgPoolBackend
from .transactions import Transaction
from .utils import FetchResults, __log__, aiomysql, aiopg, psycopg
from .utils import CursorProtocol, __log__

FetchResults = Callable[["AioDatabase", CursorProtocol], Awaitable[Any]]


class AioDatabase(peewee.Database):
Expand Down Expand Up @@ -162,7 +166,7 @@ async def aio_execute_sql(
async with connection.cursor() as cursor:
await cursor.execute(sql, params or ())
if fetch_results is not None:
return await fetch_results(cursor)
return await fetch_results(self, cursor)

async def aio_execute(self, query: Any, fetch_results: FetchResults | None = None) -> Any:
"""Execute *SELECT*, *INSERT*, *UPDATE* or *DELETE* query asyncronously.
Expand All @@ -178,8 +182,24 @@ async def aio_execute(self, query: Any, fetch_results: FetchResults | None = Non
fetch_results = fetch_results or getattr(query, "fetch_results", None)
return await self.aio_execute_sql(sql, params, fetch_results=fetch_results)

async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert) -> int:
return cursor.lastrowid

async def aio_rows_affected(self, cursor: CursorProtocol) -> int:
return cursor.rowcount


class Psycopg3Database(AioDatabase, ext.Psycopg3Database):
class AioPostgresDatabase(AioDatabase):
async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert) -> Any:
if query._query_type == peewee.Insert.SIMPLE:
try:
return (await cursor.fetchmany(1))[0][0]
except (IndexError, KeyError, TypeError):
return None
return await fetch_models(cursor, query)


class Psycopg3Database(AioPostgresDatabase, ext.Psycopg3Database):
"""Extension for `playhouse.Psycopg3Database` providing extra methods
for managing async connection based on psycopg3 pool backend.

Expand All @@ -204,13 +224,8 @@ class Psycopg3Database(AioDatabase, ext.Psycopg3Database):

pool_backend_cls = PsycopgPoolBackend

def init(self, database: str | None, **kwargs: Any) -> None:
if not psycopg:
raise Exception("Error, psycopg is not installed!")
super().init(database, **kwargs)


class PostgresqlDatabase(AioDatabase, ext.PostgresqlExtDatabase):
class PostgresqlDatabase(AioPostgresDatabase, ext.PostgresqlExtDatabase):
"""Extension for `playhouse.PostgresqlDatabase` providing extra methods
for managing async connection based on aiopg pool backend.

Expand Down Expand Up @@ -240,11 +255,6 @@ class PostgresqlDatabase(AioDatabase, ext.PostgresqlExtDatabase):
def init_pool_params_defaults(self) -> None:
self.pool_params.update({"enable_json": True, "enable_hstore": self._register_hstore})

def init(self, database: str | None, **kwargs: Any) -> None:
if not aiopg:
raise Exception("Error, aiopg is not installed!")
super().init(database, **kwargs)


class MySQLDatabase(AioDatabase, peewee.MySQLDatabase):
"""MySQL database driver providing **single drop-in sync**
Expand Down Expand Up @@ -274,8 +284,3 @@ class MySQLDatabase(AioDatabase, peewee.MySQLDatabase):

def init_pool_params_defaults(self) -> None:
self.pool_params.update({"autocommit": True})

def init(self, database: str | None, **kwargs: Any) -> None:
if not aiomysql:
raise Exception("Error, aiomysql is not installed!")
super().init(database, **kwargs)
18 changes: 17 additions & 1 deletion peewee_async/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,28 @@
import asyncio
from typing import Any, cast

from .utils import ConnectionProtocol, aiomysql, aiopg, format_dsn, psycopg, psycopg_pool
from .utils import ConnectionProtocol, ModuleRequired, aiomysql, aiopg, format_dsn, psycopg, psycopg_pool


class PoolBackend(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool."""

required_modules: list[str] = []

def __init__(self, *, database: str, **kwargs: Any) -> None:
self.check_required_backend()
self.pool: Any | None = None
self.database = database
self.connect_params = kwargs
self._connection_lock = asyncio.Lock()

def check_required_backend(self) -> None:
for module in self.required_modules:
try:
__import__(module)
except ImportError:
raise ModuleRequired(module) from None

@property
def is_connected(self) -> bool:
if self.pool is not None:
Expand Down Expand Up @@ -54,6 +64,8 @@ async def close(self) -> None:
class PostgresqlPoolBackend(PoolBackend):
"""Asynchronous database connection pool based on aiopg."""

required_modules = ["aiopg"]

async def create(self) -> None:
if "connect_timeout" in self.connect_params:
self.connect_params["timeout"] = self.connect_params.pop("connect_timeout")
Expand Down Expand Up @@ -83,6 +95,8 @@ def has_acquired_connections(self) -> bool:
class PsycopgPoolBackend(PoolBackend):
"""Asynchronous database connection pool based on psycopg + psycopg_pool."""

required_modules = ["psycopg", "psycopg_pool"]

async def create(self) -> None:
params = self.connect_params.copy()
pool = psycopg_pool.AsyncConnectionPool(
Expand Down Expand Up @@ -130,6 +144,8 @@ async def close(self) -> None:
class MysqlPoolBackend(PoolBackend):
"""Asynchronous database connection pool based on aiomysql."""

required_modules = ["aiomysql"]

async def create(self) -> None:
self.pool = await aiomysql.create_pool(db=self.database, **self.connect_params)

Expand Down
17 changes: 8 additions & 9 deletions peewee_async/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import logging
from collections.abc import Awaitable, Callable, Sequence
from collections.abc import Sequence
from contextlib import AbstractAsyncContextManager
from typing import Any, Protocol

try:
import aiopg
import psycopg2
except ImportError:
aiopg = None # type: ignore
psycopg2 = None

try:
import psycopg
import psycopg_pool
Expand All @@ -19,10 +16,8 @@

try:
import aiomysql
import pymysql
except ImportError:
aiomysql = None
pymysql = None # type: ignore

__log__ = logging.getLogger("peewee.async")
__log__.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -51,8 +46,12 @@ class ConnectionProtocol(Protocol):
def cursor(self, **kwargs: Any) -> AbstractAsyncContextManager[CursorProtocol]: ...


FetchResults = Callable[[CursorProtocol], Awaitable[Any]]


def format_dsn(protocol: str, host: str, port: str | int, user: str, password: str, path: str = "") -> str:
return f"{protocol}://{user}:{password}@{host}:{port}/{path}"


class ModuleRequired(Exception):
def __init__(self, package: str) -> None:
self.package = package
self.message = f"{package} is not installed"
super().__init__(self.message)
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ mysql = [
"cryptography>=46.0.5"
]
psycopg = [
"psycopg>=3.3.0",
"psycopg-pool>=3.3.0"
"psycopg[binary,pool]>=3.3.0",
]
docs = [
"Sphinx>=8.1.3",
Expand Down
25 changes: 25 additions & 0 deletions tests/aio_model/test_deleting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid

import pytest
from playhouse.shortcuts import model_to_dict

from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, dbs_postgres
Expand Down Expand Up @@ -50,3 +51,27 @@ async def test_delete__return_model(db: AioDatabase) -> None:

res = await TestModel.delete().returning(TestModel).aio_execute()
assert model_has_fields(res[0], {"id": m.id, "text": m.text, "data": m.data}) is True


@dbs_postgres
async def test_delete__return_dicts(db: AioDatabase) -> None:
m = await TestModel.aio_create(text="text", data="data")

res = await TestModel.delete().returning(TestModel).dicts().aio_execute()
assert res == [{"id": m.id, "text": "text", "data": "data"}]


@dbs_postgres
async def test_delete__return_tuples(db: AioDatabase) -> None:
m = await TestModel.aio_create(text="text", data="data")

res = await TestModel.delete().returning(TestModel).tuples().aio_execute()
assert res == [(m.id, "text", "data")]


@dbs_postgres
async def test_delete__return_field(db: AioDatabase) -> None:
await TestModel.aio_create(text="text", data="data")

res = await TestModel.delete().returning(TestModel.data).aio_execute()
assert model_to_dict(res[0]) == {"id": None, "text": None, "data": "data"}
Loading
Loading