From cc7aee45493c181c21a95dc78725558bb562e0d3 Mon Sep 17 00:00:00 2001 From: Gorshkov Nikolay Date: Tue, 10 Mar 2026 13:42:28 +0500 Subject: [PATCH 1/4] feat: switch on peewee 4.0 --- peewee_async/__init__.py | 18 ++++++------ peewee_async/databases.py | 58 ++++++++++----------------------------- pyproject.toml | 8 +++--- tests/conftest.py | 3 +- tests/db_config.py | 10 +++---- tests/test_database.py | 29 +------------------- 6 files changed, 32 insertions(+), 94 deletions(-) diff --git a/peewee_async/__init__.py b/peewee_async/__init__.py index 4e00f5b..b9a7f64 100644 --- a/peewee_async/__init__.py +++ b/peewee_async/__init__.py @@ -21,10 +21,9 @@ from peewee_async.aio_model import AioModel, aio_prefetch from peewee_async.connection import connection_context from peewee_async.databases import ( - PooledMySQLDatabase, - PooledPostgresqlDatabase, - PooledPostgresqlExtDatabase, - PsycopgDatabase, + MySQLDatabase, + PostgresqlDatabase, + Psycopg3Database, ) from peewee_async.pool import MysqlPoolBackend, PostgresqlPoolBackend from peewee_async.transactions import Transaction @@ -34,8 +33,8 @@ __all__ = [ "PooledPostgresqlDatabase", - "PooledPostgresqlExtDatabase", - "PooledMySQLDatabase", + "PostgresqlDatabase", + "MySQLDatabase", "Transaction", "AioModel", "aio_prefetch", @@ -44,7 +43,6 @@ "MysqlPoolBackend", ] -register_database(PooledPostgresqlDatabase, "postgres+pool+async", "postgresql+pool+async") -register_database(PooledPostgresqlExtDatabase, "postgresext+pool+async", "postgresqlext+pool+async") -register_database(PsycopgDatabase, "psycopg+pool+async", "psycopg+pool+async") -register_database(PooledMySQLDatabase, "mysql+pool+async") +register_database(PostgresqlDatabase, "postgres+pool+async", "postgresql+pool+async") +register_database(Psycopg3Database, "psycopg+pool+async", "psycopg+pool+async") +register_database(MySQLDatabase, "mysql+pool+async") diff --git a/peewee_async/databases.py b/peewee_async/databases.py index 5b1bc34..114d338 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -1,12 +1,10 @@ import contextlib -import warnings from collections.abc import AsyncIterator, Iterator from contextlib import AbstractAsyncContextManager from typing import Any import peewee from playhouse import postgres_ext as ext -from playhouse.psycopg3_ext import Psycopg3Database from .connection import ConnectionContextManager, connection_context from .pool import MysqlPoolBackend, PoolBackend, PostgresqlPoolBackend, PsycopgPoolBackend @@ -22,17 +20,16 @@ class AioDatabase(peewee.Database): Example:: - database = PooledPostgresqlExtDatabase( + database = Psycopg3Database( 'database': 'postgres', 'host': '127.0.0.1', - 'port':5432, + 'port': 5432, 'password': 'postgres', 'user': 'postgres', 'pool_params': { - "minsize": 0, - "maxsize": 5, - "timeout": 30, - 'pool_recycle': 1.5 + "min_size": 0, + "max_size": 5, + 'max_lifetime': 15 } ) @@ -54,18 +51,6 @@ def init_pool_params_defaults(self) -> None: def init_pool_params(self) -> None: self.init_pool_params_defaults() - if "min_connections" in self.connect_params or "max_connections" in self.connect_params: - warnings.warn( - "`min_connections` and `max_connections` are deprecated, use `pool_params` instead.", - DeprecationWarning, - stacklevel=2, - ) - self.pool_params.update( - { - "minsize": self.connect_params.pop("min_connections", 1), - "maxsize": self.connect_params.pop("max_connections", 20), - } - ) pool_params = self.connect_params.pop("pool_params", {}) self.pool_params.update(pool_params) self.pool_params.update(self.connect_params) @@ -194,13 +179,13 @@ async def aio_execute(self, query: Any, fetch_results: FetchResults | None = Non return await self.aio_execute_sql(sql, params, fetch_results=fetch_results) -class PsycopgDatabase(AioDatabase, Psycopg3Database): - """Extension for `peewee.PostgresqlDatabase` providing extra methods +class Psycopg3Database(AioDatabase, ext.Psycopg3Database): + """Extension for `playhouse.Psycopg3Database` providing extra methods for managing async connection based on psycopg3 pool backend. Example:: - database = PsycopgDatabase( + database = Psycopg3Database( 'database': 'postgres', 'host': '127.0.0.1', 'port': 5432, @@ -225,14 +210,14 @@ def init(self, database: str | None, **kwargs: Any) -> None: super().init(database, **kwargs) -class PooledPostgresqlDatabase(AioDatabase, peewee.PostgresqlDatabase): - """Extension for `peewee.PostgresqlDatabase` providing extra methods +class PostgresqlDatabase(AioDatabase, ext.PostgresqlExtDatabase): + """Extension for `playhouse.PostgresqlDatabase` providing extra methods for managing async connection based on aiopg pool backend. Example:: - database = PooledPostgresqlExtDatabase( + database = PostgresqlDatabase( 'database': 'postgres', 'host': '127.0.0.1', 'port':5432, @@ -253,7 +238,7 @@ class PooledPostgresqlDatabase(AioDatabase, peewee.PostgresqlDatabase): pool_backend_cls = PostgresqlPoolBackend def init_pool_params_defaults(self) -> None: - self.pool_params.update({"enable_json": False, "enable_hstore": False}) + self.pool_params.update({"enable_json": True, "enable_hstore": self._register_hstore}) def init(self, database: str | None, **kwargs: Any) -> None: if not aiopg: @@ -261,28 +246,13 @@ def init(self, database: str | None, **kwargs: Any) -> None: super().init(database, **kwargs) -class PooledPostgresqlExtDatabase(PooledPostgresqlDatabase, ext.PostgresqlExtDatabase): - """PosgtreSQL database extended driver providing **single drop-in sync** - connection and **async connections pool** interface based on aiopg pool backend. - - JSON fields support is enabled by default, HStore supports is disabled by - default, but can be enabled through pool_params or with ``register_hstore=False`` argument. - - See also: - https://peewee.readthedocs.io/en/latest/peewee/playhouse.html#PostgresqlExtDatabase - """ - - def init_pool_params_defaults(self) -> None: - self.pool_params.update({"enable_json": True, "enable_hstore": self._register_hstore}) - - -class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase): +class MySQLDatabase(AioDatabase, peewee.MySQLDatabase): """MySQL database driver providing **single drop-in sync** connection and **async connections pool** interface. Example:: - database = PooledMySQLDatabase( + database = MySQLDatabase( 'database': 'mysql', 'host': '127.0.0.1', 'port': 3306, diff --git a/pyproject.toml b/pyproject.toml index 8a95260..7d3db4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ authors = [ requires-python = ">=3.10" readme = "README.md" dependencies = [ - "peewee>=3.15.4,<4", + "peewee>=4,<5", "typing-extensions>=4.12.2" ] @@ -19,12 +19,12 @@ postgresql = [ "aiopg>=1.4.0", ] mysql = [ - "aiomysql>=0.2.0", + "aiomysql>=0.3.2", "cryptography>=46.0.5" ] psycopg = [ - "psycopg>=3.2.0", - "psycopg-pool>=3.2.0" + "psycopg>=3.3.0", + "psycopg-pool>=3.3.0" ] docs = [ "Sphinx>=8.1.3", diff --git a/tests/conftest.py b/tests/conftest.py index fc1fa89..6bc8ffd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,8 +90,7 @@ async def db(request: pytest.FixtureRequest) -> AsyncGenerator[AioDatabase, None PG_DBS = [ - "postgres-pool", - "postgres-pool-ext", + "aiopg-pool", "psycopg-pool", ] diff --git a/tests/db_config.py b/tests/db_config.py index c40a1a8..942808d 100644 --- a/tests/db_config.py +++ b/tests/db_config.py @@ -31,15 +31,13 @@ } DB_DEFAULTS = { - "postgres-pool": PG_DEFAULTS, - "postgres-pool-ext": PG_DEFAULTS, + "aiopg-pool": PG_DEFAULTS, "psycopg-pool": PSYCOPG_DEFAULTS, "mysql-pool": MYSQL_DEFAULTS, } DB_CLASSES = { - "postgres-pool": peewee_async.PooledPostgresqlDatabase, - "postgres-pool-ext": peewee_async.PooledPostgresqlExtDatabase, - "psycopg-pool": peewee_async.PsycopgDatabase, - "mysql-pool": peewee_async.PooledMySQLDatabase, + "aiopg-pool": peewee_async.PostgresqlDatabase, + "psycopg-pool": peewee_async.Psycopg3Database, + "mysql-pool": peewee_async.MySQLDatabase, } diff --git a/tests/test_database.py b/tests/test_database.py index ec83780..f1d44bf 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -76,22 +76,6 @@ async def test_deferred_init(db_name: str) -> None: await database.aio_close() -@pytest.mark.parametrize("db_name", ["postgres-pool", "postgres-pool-ext", "mysql-pool"]) -async def test_deprecated_min_max_connections_param(db_name: str) -> None: - default_params = DB_DEFAULTS[db_name].copy() - del default_params["pool_params"] - default_params["min_connections"] = 1 - default_params["max_connections"] = 3 - db_cls = DB_CLASSES[db_name] - database = db_cls(**default_params) - await database.aio_connect() - - assert database.pool_backend.pool.minsize == 1 # type: ignore - assert database.pool_backend.pool.maxsize == 3 # type: ignore - - await database.aio_close() - - @dbs_mysql async def test_mysql_params(db: AioDatabase) -> None: async with db.aio_connection() as connection_1: @@ -101,18 +85,7 @@ async def test_mysql_params(db: AioDatabase) -> None: assert db.pool_backend.pool.maxsize == 5 # type: ignore -@pytest.mark.parametrize("db", ["postgres-pool"], indirect=["db"]) -async def test_pg_json_hstore__params(db: AioDatabase) -> None: - await db.aio_connect() - assert db.pool_backend.pool._enable_json is False # type: ignore - assert db.pool_backend.pool._enable_hstore is False # type: ignore - assert db.pool_backend.pool._timeout == 30 # type: ignore - assert db.pool_backend.pool._recycle == 1.5 # type: ignore - assert db.pool_backend.pool.minsize == 0 # type: ignore - assert db.pool_backend.pool.maxsize == 5 # type: ignore - - -@pytest.mark.parametrize("db", ["postgres-pool-ext"], indirect=["db"]) +@pytest.mark.parametrize("db", ["aiopg-pool"], indirect=["db"]) async def test_pg_ext_json_hstore__params(db: AioDatabase) -> None: await db.aio_connect() assert db.pool_backend.pool._enable_json is True # type: ignore From b14994200cadfdaa45881c7975a9491ddd796dae Mon Sep 17 00:00:00 2001 From: Gorshkov Nikolay Date: Tue, 10 Mar 2026 13:53:35 +0500 Subject: [PATCH 2/4] feat: add database param to fetch_results --- peewee_async/aio_model.py | 10 +++++----- peewee_async/databases.py | 8 +++++--- peewee_async/utils.py | 5 +---- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/peewee_async/aio_model.py b/peewee_async/aio_model.py index 382ca4d..18b23e1 100644 --- a/peewee_async/aio_model.py +++ b/peewee_async/aio_model.py @@ -50,26 +50,26 @@ class AioQueryMixin: async def aio_execute(self, database: AioDatabase) -> Any: return await database.aio_execute(self) - async def fetch_results(self, cursor: CursorProtocol) -> Any: + async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> Any: return await fetch_models(cursor, self) class AioModelDelete(peewee.ModelDelete, AioQueryMixin): - async def fetch_results(self, cursor: CursorProtocol) -> list[Any] | int: + async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> list[Any] | int: if self._returning: return await fetch_models(cursor, self) return cursor.rowcount class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin): - async def fetch_results(self, cursor: CursorProtocol) -> list[Any] | int: + async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> list[Any] | int: if self._returning: return await fetch_models(cursor, self) return cursor.rowcount class AioModelInsert(peewee.ModelInsert, AioQueryMixin): - async def fetch_results(self, cursor: CursorProtocol) -> list[Any] | Any | int: + async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> list[Any] | Any | int: if self._returning is not None and len(self._returning) > 1: return await fetch_models(cursor, self) @@ -92,7 +92,7 @@ async def aio_peek(self, database: AioDatabase, n: int = 1) -> Any: `peewee.SelectBase.peek `_ """ - async def fetch_results(cursor: CursorProtocol) -> Any: + async def fetch_results(database: AioDatabase, cursor: CursorProtocol) -> Any: return await fetch_models(cursor, self, n) rows = await database.aio_execute(self, fetch_results=fetch_results) diff --git a/peewee_async/databases.py b/peewee_async/databases.py index 114d338..fdbf430 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -1,5 +1,5 @@ import contextlib -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator from contextlib import AbstractAsyncContextManager from typing import Any @@ -9,7 +9,9 @@ from .connection import ConnectionContextManager, connection_context from .pool import MysqlPoolBackend, PoolBackend, PostgresqlPoolBackend, PsycopgPoolBackend from .transactions import Transaction -from .utils import FetchResults, __log__, aiomysql, aiopg, psycopg +from .utils import CursorProtocol, __log__, aiomysql, aiopg, psycopg + +FetchResults = Callable[["AioDatabase", CursorProtocol], Awaitable[Any]] class AioDatabase(peewee.Database): @@ -162,7 +164,7 @@ async def aio_execute_sql( async with connection.cursor() as cursor: await cursor.execute(sql, params or ()) if fetch_results is not None: - return await fetch_results(cursor) + return await fetch_results(self, cursor) async def aio_execute(self, query: Any, fetch_results: FetchResults | None = None) -> Any: """Execute *SELECT*, *INSERT*, *UPDATE* or *DELETE* query asyncronously. diff --git a/peewee_async/utils.py b/peewee_async/utils.py index 36a4643..0e43c9e 100644 --- a/peewee_async/utils.py +++ b/peewee_async/utils.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Sequence from contextlib import AbstractAsyncContextManager from typing import Any, Protocol @@ -51,8 +51,5 @@ class ConnectionProtocol(Protocol): def cursor(self, **kwargs: Any) -> AbstractAsyncContextManager[CursorProtocol]: ... -FetchResults = Callable[[CursorProtocol], Awaitable[Any]] - - def format_dsn(protocol: str, host: str, port: str | int, user: str, password: str, path: str = "") -> str: return f"{protocol}://{user}:{password}@{host}:{port}/{path}" From c7556efcd6dd56a8c3cf0b30b314f22aec152b3f Mon Sep 17 00:00:00 2001 From: Gorshkov Nikolay Date: Wed, 11 Mar 2026 17:26:50 +0500 Subject: [PATCH 3/4] fix(#328): make fetch_results like in peewee --- peewee_async/aio_model.py | 31 ++++---- peewee_async/databases.py | 22 +++++- tests/aio_model/test_deleting.py | 25 ++++++ tests/aio_model/test_inserting.py | 124 +++++++++++++++++------------- tests/aio_model/test_updating.py | 22 ++++-- tests/db_config.py | 17 ++-- 6 files changed, 159 insertions(+), 82 deletions(-) diff --git a/peewee_async/aio_model.py b/peewee_async/aio_model.py index 18b23e1..c0646f2 100644 --- a/peewee_async/aio_model.py +++ b/peewee_async/aio_model.py @@ -54,30 +54,29 @@ async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> return await fetch_models(cursor, self) -class AioModelDelete(peewee.ModelDelete, AioQueryMixin): - async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> list[Any] | int: - if self._returning: +class _AioWriteQueryMixin(AioQueryMixin): + async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> Any: + if self._return_cursor: # type: ignore return await fetch_models(cursor, self) - return cursor.rowcount + return await database.aio_rows_affected(cursor) -class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin): - async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> list[Any] | int: - if self._returning: - return await fetch_models(cursor, self) - return cursor.rowcount +class AioModelDelete(peewee.ModelDelete, _AioWriteQueryMixin): ... + + +class AioModelUpdate(peewee.ModelUpdate, _AioWriteQueryMixin): ... class AioModelInsert(peewee.ModelInsert, AioQueryMixin): async def fetch_results(self, database: AioDatabase, cursor: CursorProtocol) -> list[Any] | Any | int: - if self._returning is not None and len(self._returning) > 1: + if self._returning is None and database.returning_clause and self.table._primary_key: # type: ignore + self._returning = (self.table._primary_key,) + return await database.aio_last_insert_id(cursor, self) + if self._return_cursor: return await fetch_models(cursor, self) - - if self._returning: - row = await cursor.fetchone() - return row[0] if row else None - else: - return cursor.lastrowid + if self._as_rowcount: + return await database.aio_rows_affected(cursor) + return await database.aio_last_insert_id(cursor, self) class AioModelRaw(peewee.ModelRaw, AioQueryMixin): diff --git a/peewee_async/databases.py b/peewee_async/databases.py index fdbf430..d683085 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -6,6 +6,8 @@ import peewee from playhouse import postgres_ext as ext +from peewee_async.result_wrappers import fetch_models + from .connection import ConnectionContextManager, connection_context from .pool import MysqlPoolBackend, PoolBackend, PostgresqlPoolBackend, PsycopgPoolBackend from .transactions import Transaction @@ -180,8 +182,24 @@ async def aio_execute(self, query: Any, fetch_results: FetchResults | None = Non fetch_results = fetch_results or getattr(query, "fetch_results", None) return await self.aio_execute_sql(sql, params, fetch_results=fetch_results) + async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert) -> int: + return cursor.lastrowid + + async def aio_rows_affected(self, cursor: CursorProtocol) -> int: + return cursor.rowcount + + +class AioPgDatabase(AioDatabase): + async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert) -> Any: + if query._query_type == peewee.Insert.SIMPLE: + try: + return (await cursor.fetchmany(1))[0][0] + except (IndexError, KeyError, TypeError): + return None + return await fetch_models(cursor, query) + -class Psycopg3Database(AioDatabase, ext.Psycopg3Database): +class Psycopg3Database(AioPgDatabase, ext.Psycopg3Database): """Extension for `playhouse.Psycopg3Database` providing extra methods for managing async connection based on psycopg3 pool backend. @@ -212,7 +230,7 @@ def init(self, database: str | None, **kwargs: Any) -> None: super().init(database, **kwargs) -class PostgresqlDatabase(AioDatabase, ext.PostgresqlExtDatabase): +class PostgresqlDatabase(AioPgDatabase, ext.PostgresqlExtDatabase): """Extension for `playhouse.PostgresqlDatabase` providing extra methods for managing async connection based on aiopg pool backend. diff --git a/tests/aio_model/test_deleting.py b/tests/aio_model/test_deleting.py index 058b28a..bf67e81 100644 --- a/tests/aio_model/test_deleting.py +++ b/tests/aio_model/test_deleting.py @@ -1,6 +1,7 @@ import uuid import pytest +from playhouse.shortcuts import model_to_dict from peewee_async.databases import AioDatabase from tests.conftest import dbs_all, dbs_postgres @@ -50,3 +51,27 @@ async def test_delete__return_model(db: AioDatabase) -> None: res = await TestModel.delete().returning(TestModel).aio_execute() assert model_has_fields(res[0], {"id": m.id, "text": m.text, "data": m.data}) is True + + +@dbs_postgres +async def test_delete__return_dicts(db: AioDatabase) -> None: + m = await TestModel.aio_create(text="text", data="data") + + res = await TestModel.delete().returning(TestModel).dicts().aio_execute() + assert res == [{"id": m.id, "text": "text", "data": "data"}] + + +@dbs_postgres +async def test_delete__return_tuples(db: AioDatabase) -> None: + m = await TestModel.aio_create(text="text", data="data") + + res = await TestModel.delete().returning(TestModel).tuples().aio_execute() + assert res == [(m.id, "text", "data")] + + +@dbs_postgres +async def test_delete__return_field(db: AioDatabase) -> None: + await TestModel.aio_create(text="text", data="data") + + res = await TestModel.delete().returning(TestModel.data).aio_execute() + assert model_to_dict(res[0]) == {"id": None, "text": None, "data": "data"} diff --git a/tests/aio_model/test_inserting.py b/tests/aio_model/test_inserting.py index 6ddc669..00e9094 100644 --- a/tests/aio_model/test_inserting.py +++ b/tests/aio_model/test_inserting.py @@ -1,79 +1,80 @@ import uuid -import pytest - from peewee_async.databases import AioDatabase -from tests.conftest import dbs_all, dbs_postgres +from tests.conftest import dbs_all, dbs_mysql, dbs_postgres from tests.models import TestModel, UUIDTestModel from tests.utils import model_has_fields -pytestmark = pytest.mark.use_transaction +# pytestmark = pytest.mark.use_transaction -@dbs_all -async def test_insert_many(db: AioDatabase) -> None: - last_id = await TestModel.insert_many( +@dbs_postgres +async def test_insert_many__pg(db: AioDatabase) -> None: + text1 = f"Test {uuid.uuid4()}" + text2 = f"Test {uuid.uuid4()}" + result = await TestModel.insert_many( [ - {"text": f"Test {uuid.uuid4()}"}, - {"text": f"Test {uuid.uuid4()}"}, + {"id": 1, "text": text1}, + {"id": 2, "text": text2}, ] ).aio_execute() - res = await TestModel.select().aio_execute() + assert sorted(result) == [(1,), (2,)] - assert len(res) == 2 - assert last_id in [m.id for m in res] + assert await TestModel.aio_get_or_none(id=1, text=text1) is not None + assert await TestModel.aio_get_or_none(id=2, text=text2) is not None -@dbs_all -async def test_insert__return_id(db: AioDatabase) -> None: - last_id = await TestModel.insert(text=f"Test {uuid.uuid4()}").aio_execute() - - res = await TestModel.select().aio_execute() - obj = res[0] - assert last_id == obj.id - - -@dbs_postgres -async def test_insert_on_conflict_ignore__last_id_is_none(db: AioDatabase) -> None: - query = TestModel.insert(text="text").on_conflict_ignore() - await query.aio_execute() - - last_id = await query.aio_execute() +@dbs_mysql +async def test_insert_many__mysql(db: AioDatabase) -> None: + text1 = f"Test {uuid.uuid4()}" + text2 = f"Test {uuid.uuid4()}" + result = await TestModel.insert_many( + [ + {"id": 1, "text": text1}, + {"id": 2, "text": text2}, + ] + ).aio_execute() - assert last_id is None + assert result in [1, 2] + assert await TestModel.aio_get_or_none(id=1, text=text1) is not None + assert await TestModel.aio_get_or_none(id=2, text=text2) is not None @dbs_postgres -async def test_insert_on_conflict_ignore__return_model(db: AioDatabase) -> None: - query = TestModel.insert(text="text", data="data").on_conflict_ignore().returning(TestModel) +async def test_insert_many__return_model(db: AioDatabase) -> None: + texts = [f"text{n}" for n in range(2)] + query = TestModel.insert_many([{"text": text} for text in texts]).returning(TestModel) res = await query.aio_execute() - inserted = res[0] - res = await TestModel.select().aio_execute() - expected = res[0] - - assert model_has_fields(inserted, {"id": expected.id, "text": expected.text, "data": expected.data}) is True + texts = [m.text for m in res] + assert sorted(texts) == ["text0", "text1"] -@dbs_postgres -async def test_insert_on_conflict_ignore__inserted_once(db: AioDatabase) -> None: - query = TestModel.insert(text="text").on_conflict_ignore() - last_id = await query.aio_execute() +@dbs_all +async def test_insert__as_row_count(db: AioDatabase) -> None: + result = ( + await TestModel.insert_many( + [ + {"id": 1, "text": "text1"}, + {"id": 2, "text": "text2"}, + ] + ) + .as_rowcount() + .aio_execute() + ) - await query.aio_execute() + assert result == 2 - res = await TestModel.select().aio_execute() - assert len(res) == 1 - assert res[0].id == last_id +@dbs_all +async def test_insert__return_id(db: AioDatabase) -> None: + last_id = await TestModel.insert(text=f"Test {uuid.uuid4()}").aio_execute() -@dbs_postgres -async def test_insert__uuid_pk(db: AioDatabase) -> None: - query = UUIDTestModel.insert(text=f"Test {uuid.uuid4()}") - last_id = await query.aio_execute() - assert len(str(last_id)) == 36 + res = await TestModel.select().aio_execute() + obj = res[0] + assert last_id == obj.id @dbs_postgres @@ -89,11 +90,28 @@ async def test_insert__return_model(db: AioDatabase) -> None: @dbs_postgres -async def test_insert_many__return_model(db: AioDatabase) -> None: - texts = [f"text{n}" for n in range(2)] - query = TestModel.insert_many([{"text": text} for text in texts]).returning(TestModel) +async def test_insert__uuid_pk(db: AioDatabase) -> None: + uid = "f85d03b2-001c-4da6-92c5-c0c925af0f70" + query = UUIDTestModel.insert(id=uid, text=f"Test {uuid.uuid4()}") + last_id = await query.aio_execute() + assert str(last_id) == uid + + +@dbs_postgres +async def test_insert_on_conflict_ignore__last_id_is_none(db: AioDatabase) -> None: + await TestModel.aio_create(id=5, text="text") + last_id = await TestModel.insert(id=5, text="text").on_conflict_ignore().aio_execute() + assert last_id is None + + +@dbs_postgres +async def test_insert_on_conflict_ignore__return_model(db: AioDatabase) -> None: + query = TestModel.insert(text="text", data="data").on_conflict_ignore().returning(TestModel) res = await query.aio_execute() - texts = [m.text for m in res] - assert sorted(texts) == ["text0", "text1"] + inserted = res[0] + res = await TestModel.select().aio_execute() + expected = res[0] + + assert model_has_fields(inserted, {"id": expected.id, "text": expected.text, "data": expected.data}) is True diff --git a/tests/aio_model/test_updating.py b/tests/aio_model/test_updating.py index 45ff760..7135711 100644 --- a/tests/aio_model/test_updating.py +++ b/tests/aio_model/test_updating.py @@ -1,6 +1,7 @@ import uuid import pytest +from playhouse.shortcuts import model_to_dict from peewee_async.databases import AioDatabase from tests.conftest import dbs_all, dbs_postgres @@ -30,10 +31,21 @@ async def test_update__field_updated(db: AioDatabase) -> None: @dbs_postgres async def test_update__returning_model(db: AioDatabase) -> None: - await TestModel.aio_create(text="text1", data="data") - await TestModel.aio_create(text="text2", data="data") + m = await TestModel.aio_create(text="text1", data="data") new_data = "New_data" - wrapper = await TestModel.update(data=new_data).where(TestModel.data == "data").returning(TestModel).aio_execute() + res = await TestModel.update(data=new_data).where(TestModel.data == "data").returning(TestModel).aio_execute() - result = [m.data for m in wrapper] - assert [new_data, new_data] == result + assert model_to_dict(res[0]) == {"id": m.id, "text": "text1", "data": "New_data"} + + +@dbs_postgres +async def test_update__returning_namedtuples(db: AioDatabase) -> None: + m = await TestModel.aio_create(text="text1", data="data") + new_data = "New_data" + res = await TestModel.update(data=new_data).returning(TestModel).namedtuples().aio_execute() + + tup = res[0] + + assert tup.id == m.id + assert tup.text == "text1" + assert tup.data == "New_data" diff --git a/tests/db_config.py b/tests/db_config.py index 942808d..4ad51bf 100644 --- a/tests/db_config.py +++ b/tests/db_config.py @@ -30,14 +30,19 @@ "pool_params": {"minsize": 0, "maxsize": 5, "pool_recycle": 2}, } +AIOPG_POOL = "aiopg-pool" +PSYCOPG_POOL = "psycopg-pool" +MYSQL_POOL = "mysql-pool" + + DB_DEFAULTS = { - "aiopg-pool": PG_DEFAULTS, - "psycopg-pool": PSYCOPG_DEFAULTS, - "mysql-pool": MYSQL_DEFAULTS, + AIOPG_POOL: PG_DEFAULTS, + PSYCOPG_POOL: PSYCOPG_DEFAULTS, + MYSQL_POOL: MYSQL_DEFAULTS, } DB_CLASSES = { - "aiopg-pool": peewee_async.PostgresqlDatabase, - "psycopg-pool": peewee_async.Psycopg3Database, - "mysql-pool": peewee_async.MySQLDatabase, + AIOPG_POOL: peewee_async.PostgresqlDatabase, + PSYCOPG_POOL: peewee_async.Psycopg3Database, + MYSQL_POOL: peewee_async.MySQLDatabase, } From c2c90707e2431f936c604ec6c583fee9a952e967 Mon Sep 17 00:00:00 2001 From: Gorshkov Nikolay Date: Thu, 12 Mar 2026 10:03:47 +0500 Subject: [PATCH 4/4] chore: refactor reqired backend --- peewee_async/databases.py | 23 ++++------------------- peewee_async/pool.py | 18 +++++++++++++++++- peewee_async/utils.py | 12 +++++++----- pyproject.toml | 3 +-- 4 files changed, 29 insertions(+), 27 deletions(-) diff --git a/peewee_async/databases.py b/peewee_async/databases.py index d683085..bb70fcf 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -11,7 +11,7 @@ from .connection import ConnectionContextManager, connection_context from .pool import MysqlPoolBackend, PoolBackend, PostgresqlPoolBackend, PsycopgPoolBackend from .transactions import Transaction -from .utils import CursorProtocol, __log__, aiomysql, aiopg, psycopg +from .utils import CursorProtocol, __log__ FetchResults = Callable[["AioDatabase", CursorProtocol], Awaitable[Any]] @@ -189,7 +189,7 @@ async def aio_rows_affected(self, cursor: CursorProtocol) -> int: return cursor.rowcount -class AioPgDatabase(AioDatabase): +class AioPostgresDatabase(AioDatabase): async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert) -> Any: if query._query_type == peewee.Insert.SIMPLE: try: @@ -199,7 +199,7 @@ async def aio_last_insert_id(self, cursor: CursorProtocol, query: peewee.Insert) return await fetch_models(cursor, query) -class Psycopg3Database(AioPgDatabase, ext.Psycopg3Database): +class Psycopg3Database(AioPostgresDatabase, ext.Psycopg3Database): """Extension for `playhouse.Psycopg3Database` providing extra methods for managing async connection based on psycopg3 pool backend. @@ -224,13 +224,8 @@ class Psycopg3Database(AioPgDatabase, ext.Psycopg3Database): pool_backend_cls = PsycopgPoolBackend - def init(self, database: str | None, **kwargs: Any) -> None: - if not psycopg: - raise Exception("Error, psycopg is not installed!") - super().init(database, **kwargs) - -class PostgresqlDatabase(AioPgDatabase, ext.PostgresqlExtDatabase): +class PostgresqlDatabase(AioPostgresDatabase, ext.PostgresqlExtDatabase): """Extension for `playhouse.PostgresqlDatabase` providing extra methods for managing async connection based on aiopg pool backend. @@ -260,11 +255,6 @@ class PostgresqlDatabase(AioPgDatabase, ext.PostgresqlExtDatabase): def init_pool_params_defaults(self) -> None: self.pool_params.update({"enable_json": True, "enable_hstore": self._register_hstore}) - def init(self, database: str | None, **kwargs: Any) -> None: - if not aiopg: - raise Exception("Error, aiopg is not installed!") - super().init(database, **kwargs) - class MySQLDatabase(AioDatabase, peewee.MySQLDatabase): """MySQL database driver providing **single drop-in sync** @@ -294,8 +284,3 @@ class MySQLDatabase(AioDatabase, peewee.MySQLDatabase): def init_pool_params_defaults(self) -> None: self.pool_params.update({"autocommit": True}) - - def init(self, database: str | None, **kwargs: Any) -> None: - if not aiomysql: - raise Exception("Error, aiomysql is not installed!") - super().init(database, **kwargs) diff --git a/peewee_async/pool.py b/peewee_async/pool.py index 1e5eefa..9a29f04 100644 --- a/peewee_async/pool.py +++ b/peewee_async/pool.py @@ -2,18 +2,28 @@ import asyncio from typing import Any, cast -from .utils import ConnectionProtocol, aiomysql, aiopg, format_dsn, psycopg, psycopg_pool +from .utils import ConnectionProtocol, ModuleRequired, aiomysql, aiopg, format_dsn, psycopg, psycopg_pool class PoolBackend(metaclass=abc.ABCMeta): """Asynchronous database connection pool.""" + required_modules: list[str] = [] + def __init__(self, *, database: str, **kwargs: Any) -> None: + self.check_required_backend() self.pool: Any | None = None self.database = database self.connect_params = kwargs self._connection_lock = asyncio.Lock() + def check_required_backend(self) -> None: + for module in self.required_modules: + try: + __import__(module) + except ImportError: + raise ModuleRequired(module) from None + @property def is_connected(self) -> bool: if self.pool is not None: @@ -54,6 +64,8 @@ async def close(self) -> None: class PostgresqlPoolBackend(PoolBackend): """Asynchronous database connection pool based on aiopg.""" + required_modules = ["aiopg"] + async def create(self) -> None: if "connect_timeout" in self.connect_params: self.connect_params["timeout"] = self.connect_params.pop("connect_timeout") @@ -83,6 +95,8 @@ def has_acquired_connections(self) -> bool: class PsycopgPoolBackend(PoolBackend): """Asynchronous database connection pool based on psycopg + psycopg_pool.""" + required_modules = ["psycopg", "psycopg_pool"] + async def create(self) -> None: params = self.connect_params.copy() pool = psycopg_pool.AsyncConnectionPool( @@ -130,6 +144,8 @@ async def close(self) -> None: class MysqlPoolBackend(PoolBackend): """Asynchronous database connection pool based on aiomysql.""" + required_modules = ["aiomysql"] + async def create(self) -> None: self.pool = await aiomysql.create_pool(db=self.database, **self.connect_params) diff --git a/peewee_async/utils.py b/peewee_async/utils.py index 0e43c9e..4d5e948 100644 --- a/peewee_async/utils.py +++ b/peewee_async/utils.py @@ -5,11 +5,8 @@ try: import aiopg - import psycopg2 except ImportError: aiopg = None # type: ignore - psycopg2 = None - try: import psycopg import psycopg_pool @@ -19,10 +16,8 @@ try: import aiomysql - import pymysql except ImportError: aiomysql = None - pymysql = None # type: ignore __log__ = logging.getLogger("peewee.async") __log__.addHandler(logging.NullHandler()) @@ -53,3 +48,10 @@ def cursor(self, **kwargs: Any) -> AbstractAsyncContextManager[CursorProtocol]: def format_dsn(protocol: str, host: str, port: str | int, user: str, password: str, path: str = "") -> str: return f"{protocol}://{user}:{password}@{host}:{port}/{path}" + + +class ModuleRequired(Exception): + def __init__(self, package: str) -> None: + self.package = package + self.message = f"{package} is not installed" + super().__init__(self.message) diff --git a/pyproject.toml b/pyproject.toml index 8c198f9..efa4b78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,7 @@ mysql = [ "cryptography>=46.0.5" ] psycopg = [ - "psycopg>=3.3.0", - "psycopg-pool>=3.3.0" + "psycopg[binary,pool]>=3.3.0", ] docs = [ "Sphinx>=8.1.3",