Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sql_db_utils/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.1"
__version__ = "1.1.0"
45 changes: 43 additions & 2 deletions sql_db_utils/asyncio/session_management.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Any, AsyncGenerator, Union
from typing import Annotated, Any, AsyncGenerator, Callable, List, Union

from sqlalchemy import Engine, MetaData, NullPool, text
from sqlalchemy.exc import OperationalError
Expand All @@ -13,11 +13,13 @@


class SQLSessionManager:
__slots__ = ("_db_engines", "database_uri")
__slots__ = ("_db_engines", "database_uri", "_postcreate_auto", "_postcreate_manual")

def __init__(self, database_uri: Union[str, None] = None) -> None:
self._db_engines = {}
self.database_uri = database_uri or PostgresConfig.POSTGRES_URI
self._postcreate_auto: dict = {}
self._postcreate_manual: dict = {}

def __del__(self) -> None:
for engine in self._db_engines.values():
Expand Down Expand Up @@ -82,6 +84,7 @@ async def _get_engine(
await create_default_psql_dependencies(
metadata=metadata or DeclarativeBaseClassFactory(database).metadata, engine_obj=engine
)
await self.run_postcreate(engine, database, tenant_id)
return engine

async def get_session(
Expand Down Expand Up @@ -115,3 +118,41 @@ async def get_db(tenant_id: Annotated[str, Cookie]):
yield await self.get_session(database=database, tenant_id=tenant_id, retrying=retrying)

return get_db

def postcreate_decorator(self, raw_db: str | List[str], postcreate_store: str) -> Callable:
postcreate_store = getattr(self, postcreate_store)

def decorator(func: Callable) -> None:
if isinstance(raw_db, list):
for db in raw_db:
postcreate_auto = postcreate_store.get(db, [])
postcreate_auto.append(func)
postcreate_store[db] = postcreate_auto
else:
postcreate_auto = postcreate_store.get(raw_db, [])
postcreate_auto.append(func)
postcreate_store[raw_db] = postcreate_auto

return decorator

def register_postcreate(self, raw_db: str | List[str]) -> Callable:
return self.postcreate_decorator(raw_db, "_postcreate_auto")

def register_postcreate_manual(self, raw_db: str | List[str]) -> Callable:
return self.postcreate_decorator(raw_db, "_postcreate_manual")

async def run_postcreate(self, engine: AsyncEngine, raw_db: str, tenant_id: Union[str, None] = None) -> None:
session = AsyncSession(bind=engine, future=True, expire_on_commit=False)
async with session.begin():
for postcreate_func in self._postcreate_auto.get(raw_db, []):
result = postcreate_func(tenant_id)
if isinstance(result, list):
for statement in result:
await session.execute(statement)
else:
await session.execute(result)
for postcreate_func in self._postcreate_manual.get(raw_db, []):
await postcreate_func(session, tenant_id)
await session.commit()
await session.close()
logging.info(f"Postcreate for {raw_db} completed")
52 changes: 49 additions & 3 deletions sql_db_utils/session_management.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Callable, Union
from typing import Annotated, Callable, List, Union

from sqlalchemy import Engine, MetaData, NullPool, create_engine, text
from sqlalchemy.exc import OperationalError
Expand All @@ -13,11 +13,17 @@


class SQLSessionManager:
__slots__ = ("_db_engines", "database_uri")
__slots__ = ("_db_engines", "database_uri", "_postcreate_auto", "_postcreate_manual")

def __init__(self, database_uri: Union[str, None] = None) -> None:
self._db_engines = {}
self.database_uri = database_uri or PostgresConfig.POSTGRES_URI
self._postcreate_auto: dict = {}
self._postcreate_manual: dict = {}

def __del__(self) -> None:
for engine in self._db_engines.values():
engine.dispose()

def _get_fully_qualified_db(self, database: str, tenant_id: Union[str, None] = None) -> str:
return f"{tenant_id}__{database}" if tenant_id else database
Expand Down Expand Up @@ -78,6 +84,7 @@ def _get_engine(
create_default_psql_dependencies(
metadata=metadata or DeclarativeBaseClassFactory(database).metadata, engine_obj=engine
)
self.run_postcreate(engine, database, tenant_id)
return engine

def get_session(
Expand All @@ -96,6 +103,7 @@ def get_session(
return Session(
bind=self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata),
future=True,
expire_on_commit=False,
)

def get_engine_obj(
Expand All @@ -106,7 +114,45 @@ def get_engine_obj(
def get_db_factory(self, database: str, retrying: bool = False) -> Callable:
from fastapi import Cookie

async def get_db(tenant_id: Annotated[str, Cookie]):
def get_db(tenant_id: Annotated[str, Cookie]):
yield self.get_session(database=database, tenant_id=tenant_id, retrying=retrying)

return get_db

def postcreate_decorator(self, raw_db: str | List[str], postcreate_store: str) -> Callable:
postcreate_store = getattr(self, postcreate_store)

def decorator(func: Callable) -> None:
if isinstance(raw_db, list):
for db in raw_db:
postcreate_auto = postcreate_store.get(db, [])
postcreate_auto.append(func)
postcreate_store[db] = postcreate_auto
else:
postcreate_auto = postcreate_store.get(raw_db, [])
postcreate_auto.append(func)
postcreate_store[raw_db] = postcreate_auto

return decorator

def register_postcreate(self, raw_db: str | List[str]) -> Callable:
return self.postcreate_decorator(raw_db, "_postcreate_auto")

def register_postcreate_manual(self, raw_db: str | List[str]) -> Callable:
return self.postcreate_decorator(raw_db, "_postcreate_manual")

def run_postcreate(self, engine: Engine, raw_db: str, tenant_id: Union[str, None] = None) -> None:
session = Session(bind=engine, future=True, expire_on_commit=False)
with session.begin():
for postcreate_func in self._postcreate_auto.get(raw_db, []):
result = postcreate_func(tenant_id)
if isinstance(result, list):
for statement in result:
session.execute(statement)
else:
session.execute(result)
for postcreate_func in self._postcreate_manual.get(raw_db, []):
postcreate_func(session, tenant_id)
session.commit()
session.close()
logging.info(f"Postcreate for {raw_db} completed")
154 changes: 154 additions & 0 deletions sql_db_utils/sql_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from typing import Any, List, Optional

from sqlalchemy.ext import compiler
from sqlalchemy.schema import CreateColumn, DDLElement


class CreateExtension(DDLElement):
def __init__(self, name: str):
self.name = name


@compiler.compiles(CreateExtension)
def compile_create_extension(element: CreateExtension, _compiler: Any, **__kwargs__) -> str:
return f"CREATE EXTENSION IF NOT EXISTS {element.name};"


class CreateServer(DDLElement):
def __init__(self, server_name: str, remote_db_name: str, remote_host: str, remote_port: int):
self.server_name = server_name
self.remote_db_name = remote_db_name
self.remote_host = remote_host
self.remote_port = remote_port


@compiler.compiles(CreateServer)
def compile_create_server(element: CreateServer, _compiler: Any, **__kwargs__) -> str:
return f"""
CREATE SERVER IF NOT EXISTS {element.server_name}
FOREIGN DATA WRAPPER postgres_fdw
OPTIONS (
dbname '{element.remote_db_name}',
host '{element.remote_host}',
port '{element.remote_port}'
);
"""


class DropServer(DDLElement):
def __init__(self, server_name: str):
self.server_name = server_name


@compiler.compiles(DropServer)
def compile_drop_server(element: DropServer, _compiler: Any, **__kwargs__) -> str:
return f"DROP SERVER IF EXISTS {element.server_name} CASCADE;"


class CreateUserMapping(DDLElement):
def __init__(self, role: str, server_name: str, remote_role: str, remote_password: str):
self.role = role
self.server_name = server_name
self.remote_role = remote_role
self.remote_password = remote_password


@compiler.compiles(CreateUserMapping)
def compile_create_user_mapping(element: CreateUserMapping, compiler: Any, **kw: Any) -> str:
return f"""
CREATE USER MAPPING FOR {element.role}
SERVER {element.server_name}
OPTIONS (user '{element.remote_role}', password '{element.remote_password}');
"""


class DropUserMapping(DDLElement):
def __init__(self, role: str, server_name: str):
self.role = role
self.server_name = server_name


@compiler.compiles(DropUserMapping)
def compile_drop_user_mapping(element: DropUserMapping, compiler: Any, **kw: Any) -> str:
return f"DROP USER MAPPING IF EXISTS FOR {element.role} SERVER {element.server_name};"


class CreateForeignTable(DDLElement):
def __init__(
self,
table_name: str,
columns: List[Any],
server_name: str,
remote_schema_name: str,
remote_table_name: str,
local_schema_name: Optional[str] = None,
):
self.local_schema_name = local_schema_name or "public"
self.table_name = table_name
self.columns = columns
self.server_name = server_name
self.remote_schema_name = remote_schema_name
self.remote_table_name = remote_table_name


@compiler.compiles(CreateForeignTable)
def compile_create_foreign_table(element: CreateForeignTable, compiler: Any, **kw: Any) -> str:
columns = [compiler.process(CreateColumn(column), **kw) for column in element.columns]
return f"""
CREATE FOREIGN TABLE {element.local_schema_name}.{element.table_name}
({", ".join(columns)})
SERVER {element.server_name}
OPTIONS(schema_name '{element.remote_schema_name}', table_name '{element.remote_table_name}');
"""


class DropForeignTable(DDLElement):
def __init__(self, name: str):
self.name = name


@compiler.compiles(DropForeignTable)
def compile_drop_foreign_table(element: DropForeignTable, compiler: Any, **kw: Any) -> str:
return f"DROP FOREIGN TABLE IF EXISTS {element.name};"


class CreatePrefixedIdFunction(DDLElement):
def __init__(self, function_name: str):
self.function_name = function_name


@compiler.compiles(CreatePrefixedIdFunction)
def compile_create_prefixed_id_function(element: CreatePrefixedIdFunction, _compiler: Any, **__kwargs__) -> str:
return f"""
CREATE OR REPLACE FUNCTION {element.function_name}(prefix TEXT, seq_name TEXT)
RETURNS TEXT
AS $$
DECLARE
next_val INTEGER;
BEGIN
next_val := nextval(seq_name);
RETURN prefix || next_val;
END;
$$ LANGUAGE plpgsql;
"""


class CreateSuffixedIdFunction(DDLElement):
def __init__(self, function_name: str):
self.function_name = function_name


@compiler.compiles(CreateSuffixedIdFunction)
def compile_create_suffixed_id_function(element: CreateSuffixedIdFunction, _compiler: Any, **__kwargs__) -> str:
return f"""
CREATE OR REPLACE FUNCTION {element.function_name}(seq_name TEXT, suffix TEXT)
RETURNS TEXT
AS $$
DECLARE
next_val INTEGER;
BEGIN
next_val := nextval(seq_name);
RETURN next_val || suffix;
END;
$$ LANGUAGE plpgsql;
"""