Skip to content

Commit f552251

Browse files
authored
Merge pull request #2 from exiorrealty/feat/auto-execute
Feat/auto execute
2 parents ebeaaa2 + 4ad17f0 commit f552251

4 files changed

Lines changed: 247 additions & 6 deletions

File tree

sql_db_utils/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.1"
1+
__version__ = "1.1.0"

sql_db_utils/asyncio/session_management.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Annotated, Any, AsyncGenerator, Union
2+
from typing import Annotated, Any, AsyncGenerator, Callable, List, Union
33

44
from sqlalchemy import Engine, MetaData, NullPool, text
55
from sqlalchemy.exc import OperationalError
@@ -13,11 +13,13 @@
1313

1414

1515
class SQLSessionManager:
16-
__slots__ = ("_db_engines", "database_uri")
16+
__slots__ = ("_db_engines", "database_uri", "_postcreate_auto", "_postcreate_manual")
1717

1818
def __init__(self, database_uri: Union[str, None] = None) -> None:
1919
self._db_engines = {}
2020
self.database_uri = database_uri or PostgresConfig.POSTGRES_URI
21+
self._postcreate_auto: dict = {}
22+
self._postcreate_manual: dict = {}
2123

2224
def __del__(self) -> None:
2325
for engine in self._db_engines.values():
@@ -82,6 +84,7 @@ async def _get_engine(
8284
await create_default_psql_dependencies(
8385
metadata=metadata or DeclarativeBaseClassFactory(database).metadata, engine_obj=engine
8486
)
87+
await self.run_postcreate(engine, database, tenant_id)
8588
return engine
8689

8790
async def get_session(
@@ -115,3 +118,41 @@ async def get_db(tenant_id: Annotated[str, Cookie]):
115118
yield await self.get_session(database=database, tenant_id=tenant_id, retrying=retrying)
116119

117120
return get_db
121+
122+
def postcreate_decorator(self, raw_db: str | List[str], postcreate_store: str) -> Callable:
123+
postcreate_store = getattr(self, postcreate_store)
124+
125+
def decorator(func: Callable) -> None:
126+
if isinstance(raw_db, list):
127+
for db in raw_db:
128+
postcreate_auto = postcreate_store.get(db, [])
129+
postcreate_auto.append(func)
130+
postcreate_store[db] = postcreate_auto
131+
else:
132+
postcreate_auto = postcreate_store.get(raw_db, [])
133+
postcreate_auto.append(func)
134+
postcreate_store[raw_db] = postcreate_auto
135+
136+
return decorator
137+
138+
def register_postcreate(self, raw_db: str | List[str]) -> Callable:
139+
return self.postcreate_decorator(raw_db, "_postcreate_auto")
140+
141+
def register_postcreate_manual(self, raw_db: str | List[str]) -> Callable:
142+
return self.postcreate_decorator(raw_db, "_postcreate_manual")
143+
144+
async def run_postcreate(self, engine: AsyncEngine, raw_db: str, tenant_id: Union[str, None] = None) -> None:
145+
session = AsyncSession(bind=engine, future=True, expire_on_commit=False)
146+
async with session.begin():
147+
for postcreate_func in self._postcreate_auto.get(raw_db, []):
148+
result = postcreate_func(tenant_id)
149+
if isinstance(result, list):
150+
for statement in result:
151+
await session.execute(statement)
152+
else:
153+
await session.execute(result)
154+
for postcreate_func in self._postcreate_manual.get(raw_db, []):
155+
await postcreate_func(session, tenant_id)
156+
await session.commit()
157+
await session.close()
158+
logging.info(f"Postcreate for {raw_db} completed")

sql_db_utils/session_management.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Annotated, Callable, Union
2+
from typing import Annotated, Callable, List, Union
33

44
from sqlalchemy import Engine, MetaData, NullPool, create_engine, text
55
from sqlalchemy.exc import OperationalError
@@ -13,11 +13,17 @@
1313

1414

1515
class SQLSessionManager:
16-
__slots__ = ("_db_engines", "database_uri")
16+
__slots__ = ("_db_engines", "database_uri", "_postcreate_auto", "_postcreate_manual")
1717

1818
def __init__(self, database_uri: Union[str, None] = None) -> None:
1919
self._db_engines = {}
2020
self.database_uri = database_uri or PostgresConfig.POSTGRES_URI
21+
self._postcreate_auto: dict = {}
22+
self._postcreate_manual: dict = {}
23+
24+
def __del__(self) -> None:
25+
for engine in self._db_engines.values():
26+
engine.dispose()
2127

2228
def _get_fully_qualified_db(self, database: str, tenant_id: Union[str, None] = None) -> str:
2329
return f"{tenant_id}__{database}" if tenant_id else database
@@ -78,6 +84,7 @@ def _get_engine(
7884
create_default_psql_dependencies(
7985
metadata=metadata or DeclarativeBaseClassFactory(database).metadata, engine_obj=engine
8086
)
87+
self.run_postcreate(engine, database, tenant_id)
8188
return engine
8289

8390
def get_session(
@@ -96,6 +103,7 @@ def get_session(
96103
return Session(
97104
bind=self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata),
98105
future=True,
106+
expire_on_commit=False,
99107
)
100108

101109
def get_engine_obj(
@@ -106,7 +114,45 @@ def get_engine_obj(
106114
def get_db_factory(self, database: str, retrying: bool = False) -> Callable:
107115
from fastapi import Cookie
108116

109-
async def get_db(tenant_id: Annotated[str, Cookie]):
117+
def get_db(tenant_id: Annotated[str, Cookie]):
110118
yield self.get_session(database=database, tenant_id=tenant_id, retrying=retrying)
111119

112120
return get_db
121+
122+
def postcreate_decorator(self, raw_db: str | List[str], postcreate_store: str) -> Callable:
123+
postcreate_store = getattr(self, postcreate_store)
124+
125+
def decorator(func: Callable) -> None:
126+
if isinstance(raw_db, list):
127+
for db in raw_db:
128+
postcreate_auto = postcreate_store.get(db, [])
129+
postcreate_auto.append(func)
130+
postcreate_store[db] = postcreate_auto
131+
else:
132+
postcreate_auto = postcreate_store.get(raw_db, [])
133+
postcreate_auto.append(func)
134+
postcreate_store[raw_db] = postcreate_auto
135+
136+
return decorator
137+
138+
def register_postcreate(self, raw_db: str | List[str]) -> Callable:
139+
return self.postcreate_decorator(raw_db, "_postcreate_auto")
140+
141+
def register_postcreate_manual(self, raw_db: str | List[str]) -> Callable:
142+
return self.postcreate_decorator(raw_db, "_postcreate_manual")
143+
144+
def run_postcreate(self, engine: Engine, raw_db: str, tenant_id: Union[str, None] = None) -> None:
145+
session = Session(bind=engine, future=True, expire_on_commit=False)
146+
with session.begin():
147+
for postcreate_func in self._postcreate_auto.get(raw_db, []):
148+
result = postcreate_func(tenant_id)
149+
if isinstance(result, list):
150+
for statement in result:
151+
session.execute(statement)
152+
else:
153+
session.execute(result)
154+
for postcreate_func in self._postcreate_manual.get(raw_db, []):
155+
postcreate_func(session, tenant_id)
156+
session.commit()
157+
session.close()
158+
logging.info(f"Postcreate for {raw_db} completed")

sql_db_utils/sql_extras.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from typing import Any, List, Optional
2+
3+
from sqlalchemy.ext import compiler
4+
from sqlalchemy.schema import CreateColumn, DDLElement
5+
6+
7+
class CreateExtension(DDLElement):
8+
def __init__(self, name: str):
9+
self.name = name
10+
11+
12+
@compiler.compiles(CreateExtension)
13+
def compile_create_extension(element: CreateExtension, _compiler: Any, **__kwargs__) -> str:
14+
return f"CREATE EXTENSION IF NOT EXISTS {element.name};"
15+
16+
17+
class CreateServer(DDLElement):
18+
def __init__(self, server_name: str, remote_db_name: str, remote_host: str, remote_port: int):
19+
self.server_name = server_name
20+
self.remote_db_name = remote_db_name
21+
self.remote_host = remote_host
22+
self.remote_port = remote_port
23+
24+
25+
@compiler.compiles(CreateServer)
26+
def compile_create_server(element: CreateServer, _compiler: Any, **__kwargs__) -> str:
27+
return f"""
28+
CREATE SERVER IF NOT EXISTS {element.server_name}
29+
FOREIGN DATA WRAPPER postgres_fdw
30+
OPTIONS (
31+
dbname '{element.remote_db_name}',
32+
host '{element.remote_host}',
33+
port '{element.remote_port}'
34+
);
35+
"""
36+
37+
38+
class DropServer(DDLElement):
39+
def __init__(self, server_name: str):
40+
self.server_name = server_name
41+
42+
43+
@compiler.compiles(DropServer)
44+
def compile_drop_server(element: DropServer, _compiler: Any, **__kwargs__) -> str:
45+
return f"DROP SERVER IF EXISTS {element.server_name} CASCADE;"
46+
47+
48+
class CreateUserMapping(DDLElement):
49+
def __init__(self, role: str, server_name: str, remote_role: str, remote_password: str):
50+
self.role = role
51+
self.server_name = server_name
52+
self.remote_role = remote_role
53+
self.remote_password = remote_password
54+
55+
56+
@compiler.compiles(CreateUserMapping)
57+
def compile_create_user_mapping(element: CreateUserMapping, compiler: Any, **kw: Any) -> str:
58+
return f"""
59+
CREATE USER MAPPING FOR {element.role}
60+
SERVER {element.server_name}
61+
OPTIONS (user '{element.remote_role}', password '{element.remote_password}');
62+
"""
63+
64+
65+
class DropUserMapping(DDLElement):
66+
def __init__(self, role: str, server_name: str):
67+
self.role = role
68+
self.server_name = server_name
69+
70+
71+
@compiler.compiles(DropUserMapping)
72+
def compile_drop_user_mapping(element: DropUserMapping, compiler: Any, **kw: Any) -> str:
73+
return f"DROP USER MAPPING IF EXISTS FOR {element.role} SERVER {element.server_name};"
74+
75+
76+
class CreateForeignTable(DDLElement):
77+
def __init__(
78+
self,
79+
table_name: str,
80+
columns: List[Any],
81+
server_name: str,
82+
remote_schema_name: str,
83+
remote_table_name: str,
84+
local_schema_name: Optional[str] = None,
85+
):
86+
self.local_schema_name = local_schema_name or "public"
87+
self.table_name = table_name
88+
self.columns = columns
89+
self.server_name = server_name
90+
self.remote_schema_name = remote_schema_name
91+
self.remote_table_name = remote_table_name
92+
93+
94+
@compiler.compiles(CreateForeignTable)
95+
def compile_create_foreign_table(element: CreateForeignTable, compiler: Any, **kw: Any) -> str:
96+
columns = [compiler.process(CreateColumn(column), **kw) for column in element.columns]
97+
return f"""
98+
CREATE FOREIGN TABLE {element.local_schema_name}.{element.table_name}
99+
({", ".join(columns)})
100+
SERVER {element.server_name}
101+
OPTIONS(schema_name '{element.remote_schema_name}', table_name '{element.remote_table_name}');
102+
"""
103+
104+
105+
class DropForeignTable(DDLElement):
106+
def __init__(self, name: str):
107+
self.name = name
108+
109+
110+
@compiler.compiles(DropForeignTable)
111+
def compile_drop_foreign_table(element: DropForeignTable, compiler: Any, **kw: Any) -> str:
112+
return f"DROP FOREIGN TABLE IF EXISTS {element.name};"
113+
114+
115+
class CreatePrefixedIdFunction(DDLElement):
116+
def __init__(self, function_name: str):
117+
self.function_name = function_name
118+
119+
120+
@compiler.compiles(CreatePrefixedIdFunction)
121+
def compile_create_prefixed_id_function(element: CreatePrefixedIdFunction, _compiler: Any, **__kwargs__) -> str:
122+
return f"""
123+
CREATE OR REPLACE FUNCTION {element.function_name}(prefix TEXT, seq_name TEXT)
124+
RETURNS TEXT
125+
AS $$
126+
DECLARE
127+
next_val INTEGER;
128+
BEGIN
129+
next_val := nextval(seq_name);
130+
RETURN prefix || next_val;
131+
END;
132+
$$ LANGUAGE plpgsql;
133+
"""
134+
135+
136+
class CreateSuffixedIdFunction(DDLElement):
137+
def __init__(self, function_name: str):
138+
self.function_name = function_name
139+
140+
141+
@compiler.compiles(CreateSuffixedIdFunction)
142+
def compile_create_suffixed_id_function(element: CreateSuffixedIdFunction, _compiler: Any, **__kwargs__) -> str:
143+
return f"""
144+
CREATE OR REPLACE FUNCTION {element.function_name}(seq_name TEXT, suffix TEXT)
145+
RETURNS TEXT
146+
AS $$
147+
DECLARE
148+
next_val INTEGER;
149+
BEGIN
150+
next_val := nextval(seq_name);
151+
RETURN next_val || suffix;
152+
END;
153+
$$ LANGUAGE plpgsql;
154+
"""

0 commit comments

Comments
 (0)