diff --git a/sql_db_utils/__version__.py b/sql_db_utils/__version__.py index 5c4105c..6849410 100644 --- a/sql_db_utils/__version__.py +++ b/sql_db_utils/__version__.py @@ -1 +1 @@ -__version__ = "1.0.1" +__version__ = "1.1.0" diff --git a/sql_db_utils/asyncio/session_management.py b/sql_db_utils/asyncio/session_management.py index 6154a8f..f10137f 100644 --- a/sql_db_utils/asyncio/session_management.py +++ b/sql_db_utils/asyncio/session_management.py @@ -1,5 +1,5 @@ import logging -from typing import Annotated, Any, AsyncGenerator, Union +from typing import Annotated, Any, AsyncGenerator, Callable, List, Union from sqlalchemy import Engine, MetaData, NullPool, text from sqlalchemy.exc import OperationalError @@ -13,11 +13,13 @@ class SQLSessionManager: - __slots__ = ("_db_engines", "database_uri") + __slots__ = ("_db_engines", "database_uri", "_postcreate_auto", "_postcreate_manual") def __init__(self, database_uri: Union[str, None] = None) -> None: self._db_engines = {} self.database_uri = database_uri or PostgresConfig.POSTGRES_URI + self._postcreate_auto: dict = {} + self._postcreate_manual: dict = {} def __del__(self) -> None: for engine in self._db_engines.values(): @@ -82,6 +84,7 @@ async def _get_engine( await create_default_psql_dependencies( metadata=metadata or DeclarativeBaseClassFactory(database).metadata, engine_obj=engine ) + await self.run_postcreate(engine, database, tenant_id) return engine async def get_session( @@ -115,3 +118,41 @@ async def get_db(tenant_id: Annotated[str, Cookie]): yield await self.get_session(database=database, tenant_id=tenant_id, retrying=retrying) return get_db + + def postcreate_decorator(self, raw_db: str | List[str], postcreate_store: str) -> Callable: + postcreate_store = getattr(self, postcreate_store) + + def decorator(func: Callable) -> None: + if isinstance(raw_db, list): + for db in raw_db: + postcreate_auto = postcreate_store.get(db, []) + postcreate_auto.append(func) + postcreate_store[db] = postcreate_auto + else: + postcreate_auto = postcreate_store.get(raw_db, []) + postcreate_auto.append(func) + postcreate_store[raw_db] = postcreate_auto + + return decorator + + def register_postcreate(self, raw_db: str | List[str]) -> Callable: + return self.postcreate_decorator(raw_db, "_postcreate_auto") + + def register_postcreate_manual(self, raw_db: str | List[str]) -> Callable: + return self.postcreate_decorator(raw_db, "_postcreate_manual") + + async def run_postcreate(self, engine: AsyncEngine, raw_db: str, tenant_id: Union[str, None] = None) -> None: + session = AsyncSession(bind=engine, future=True, expire_on_commit=False) + async with session.begin(): + for postcreate_func in self._postcreate_auto.get(raw_db, []): + result = postcreate_func(tenant_id) + if isinstance(result, list): + for statement in result: + await session.execute(statement) + else: + await session.execute(result) + for postcreate_func in self._postcreate_manual.get(raw_db, []): + await postcreate_func(session, tenant_id) + await session.commit() + await session.close() + logging.info(f"Postcreate for {raw_db} completed") diff --git a/sql_db_utils/session_management.py b/sql_db_utils/session_management.py index 1b75988..4db93d3 100644 --- a/sql_db_utils/session_management.py +++ b/sql_db_utils/session_management.py @@ -1,5 +1,5 @@ import logging -from typing import Annotated, Callable, Union +from typing import Annotated, Callable, List, Union from sqlalchemy import Engine, MetaData, NullPool, create_engine, text from sqlalchemy.exc import OperationalError @@ -13,11 +13,17 @@ class SQLSessionManager: - __slots__ = ("_db_engines", "database_uri") + __slots__ = ("_db_engines", "database_uri", "_postcreate_auto", "_postcreate_manual") def __init__(self, database_uri: Union[str, None] = None) -> None: self._db_engines = {} self.database_uri = database_uri or PostgresConfig.POSTGRES_URI + self._postcreate_auto: dict = {} + self._postcreate_manual: dict = {} + + def __del__(self) -> None: + for engine in self._db_engines.values(): + engine.dispose() def _get_fully_qualified_db(self, database: str, tenant_id: Union[str, None] = None) -> str: return f"{tenant_id}__{database}" if tenant_id else database @@ -78,6 +84,7 @@ def _get_engine( create_default_psql_dependencies( metadata=metadata or DeclarativeBaseClassFactory(database).metadata, engine_obj=engine ) + self.run_postcreate(engine, database, tenant_id) return engine def get_session( @@ -96,6 +103,7 @@ def get_session( return Session( bind=self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata), future=True, + expire_on_commit=False, ) def get_engine_obj( @@ -106,7 +114,45 @@ def get_engine_obj( def get_db_factory(self, database: str, retrying: bool = False) -> Callable: from fastapi import Cookie - async def get_db(tenant_id: Annotated[str, Cookie]): + def get_db(tenant_id: Annotated[str, Cookie]): yield self.get_session(database=database, tenant_id=tenant_id, retrying=retrying) return get_db + + def postcreate_decorator(self, raw_db: str | List[str], postcreate_store: str) -> Callable: + postcreate_store = getattr(self, postcreate_store) + + def decorator(func: Callable) -> None: + if isinstance(raw_db, list): + for db in raw_db: + postcreate_auto = postcreate_store.get(db, []) + postcreate_auto.append(func) + postcreate_store[db] = postcreate_auto + else: + postcreate_auto = postcreate_store.get(raw_db, []) + postcreate_auto.append(func) + postcreate_store[raw_db] = postcreate_auto + + return decorator + + def register_postcreate(self, raw_db: str | List[str]) -> Callable: + return self.postcreate_decorator(raw_db, "_postcreate_auto") + + def register_postcreate_manual(self, raw_db: str | List[str]) -> Callable: + return self.postcreate_decorator(raw_db, "_postcreate_manual") + + def run_postcreate(self, engine: Engine, raw_db: str, tenant_id: Union[str, None] = None) -> None: + session = Session(bind=engine, future=True, expire_on_commit=False) + with session.begin(): + for postcreate_func in self._postcreate_auto.get(raw_db, []): + result = postcreate_func(tenant_id) + if isinstance(result, list): + for statement in result: + session.execute(statement) + else: + session.execute(result) + for postcreate_func in self._postcreate_manual.get(raw_db, []): + postcreate_func(session, tenant_id) + session.commit() + session.close() + logging.info(f"Postcreate for {raw_db} completed") diff --git a/sql_db_utils/sql_extras.py b/sql_db_utils/sql_extras.py new file mode 100644 index 0000000..42f5465 --- /dev/null +++ b/sql_db_utils/sql_extras.py @@ -0,0 +1,154 @@ +from typing import Any, List, Optional + +from sqlalchemy.ext import compiler +from sqlalchemy.schema import CreateColumn, DDLElement + + +class CreateExtension(DDLElement): + def __init__(self, name: str): + self.name = name + + +@compiler.compiles(CreateExtension) +def compile_create_extension(element: CreateExtension, _compiler: Any, **__kwargs__) -> str: + return f"CREATE EXTENSION IF NOT EXISTS {element.name};" + + +class CreateServer(DDLElement): + def __init__(self, server_name: str, remote_db_name: str, remote_host: str, remote_port: int): + self.server_name = server_name + self.remote_db_name = remote_db_name + self.remote_host = remote_host + self.remote_port = remote_port + + +@compiler.compiles(CreateServer) +def compile_create_server(element: CreateServer, _compiler: Any, **__kwargs__) -> str: + return f""" + CREATE SERVER IF NOT EXISTS {element.server_name} + FOREIGN DATA WRAPPER postgres_fdw + OPTIONS ( + dbname '{element.remote_db_name}', + host '{element.remote_host}', + port '{element.remote_port}' + ); + """ + + +class DropServer(DDLElement): + def __init__(self, server_name: str): + self.server_name = server_name + + +@compiler.compiles(DropServer) +def compile_drop_server(element: DropServer, _compiler: Any, **__kwargs__) -> str: + return f"DROP SERVER IF EXISTS {element.server_name} CASCADE;" + + +class CreateUserMapping(DDLElement): + def __init__(self, role: str, server_name: str, remote_role: str, remote_password: str): + self.role = role + self.server_name = server_name + self.remote_role = remote_role + self.remote_password = remote_password + + +@compiler.compiles(CreateUserMapping) +def compile_create_user_mapping(element: CreateUserMapping, compiler: Any, **kw: Any) -> str: + return f""" + CREATE USER MAPPING FOR {element.role} + SERVER {element.server_name} + OPTIONS (user '{element.remote_role}', password '{element.remote_password}'); + """ + + +class DropUserMapping(DDLElement): + def __init__(self, role: str, server_name: str): + self.role = role + self.server_name = server_name + + +@compiler.compiles(DropUserMapping) +def compile_drop_user_mapping(element: DropUserMapping, compiler: Any, **kw: Any) -> str: + return f"DROP USER MAPPING IF EXISTS FOR {element.role} SERVER {element.server_name};" + + +class CreateForeignTable(DDLElement): + def __init__( + self, + table_name: str, + columns: List[Any], + server_name: str, + remote_schema_name: str, + remote_table_name: str, + local_schema_name: Optional[str] = None, + ): + self.local_schema_name = local_schema_name or "public" + self.table_name = table_name + self.columns = columns + self.server_name = server_name + self.remote_schema_name = remote_schema_name + self.remote_table_name = remote_table_name + + +@compiler.compiles(CreateForeignTable) +def compile_create_foreign_table(element: CreateForeignTable, compiler: Any, **kw: Any) -> str: + columns = [compiler.process(CreateColumn(column), **kw) for column in element.columns] + return f""" + CREATE FOREIGN TABLE {element.local_schema_name}.{element.table_name} + ({", ".join(columns)}) + SERVER {element.server_name} + OPTIONS(schema_name '{element.remote_schema_name}', table_name '{element.remote_table_name}'); + """ + + +class DropForeignTable(DDLElement): + def __init__(self, name: str): + self.name = name + + +@compiler.compiles(DropForeignTable) +def compile_drop_foreign_table(element: DropForeignTable, compiler: Any, **kw: Any) -> str: + return f"DROP FOREIGN TABLE IF EXISTS {element.name};" + + +class CreatePrefixedIdFunction(DDLElement): + def __init__(self, function_name: str): + self.function_name = function_name + + +@compiler.compiles(CreatePrefixedIdFunction) +def compile_create_prefixed_id_function(element: CreatePrefixedIdFunction, _compiler: Any, **__kwargs__) -> str: + return f""" + CREATE OR REPLACE FUNCTION {element.function_name}(prefix TEXT, seq_name TEXT) + RETURNS TEXT + AS $$ + DECLARE + next_val INTEGER; + BEGIN + next_val := nextval(seq_name); + RETURN prefix || next_val; + END; + $$ LANGUAGE plpgsql; + """ + + +class CreateSuffixedIdFunction(DDLElement): + def __init__(self, function_name: str): + self.function_name = function_name + + +@compiler.compiles(CreateSuffixedIdFunction) +def compile_create_suffixed_id_function(element: CreateSuffixedIdFunction, _compiler: Any, **__kwargs__) -> str: + return f""" + CREATE OR REPLACE FUNCTION {element.function_name}(seq_name TEXT, suffix TEXT) + RETURNS TEXT + AS $$ + DECLARE + next_val INTEGER; + BEGIN + next_val := nextval(seq_name); + RETURN next_val || suffix; + END; + $$ LANGUAGE plpgsql; + """