Skip to content

Commit ebeaaa2

Browse files
authored
Merge pull request #1 from exiorrealty/feat/tenant-integration
Add tenant-id integration and improve asyncio handling
2 parents 90034e9 + a8c462c commit ebeaaa2

8 files changed

Lines changed: 164 additions & 275 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ repos:
66
- id: trailing-whitespace
77
- id: requirements-txt-fixer
88
- repo: https://github.com/charliermarsh/ruff-pre-commit
9-
rev: v0.9.7
9+
rev: v0.11.0
1010
hooks:
1111
- id: ruff
1212
args:

sql_db_utils/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0"
1+
__version__ = "1.0.1"

sql_db_utils/asyncio/declarative_utils.py

Lines changed: 57 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ class DeclarativeUtils:
2121
"""
2222

2323
async def __new__(
24-
cls, raw_database: str, project_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False
24+
cls, raw_database: str, tenant_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False
2525
) -> None:
2626
obj = super().__new__(cls)
27-
obj.__init__(raw_database, project_id, session_manager, schema, raw_db)
27+
obj.__init__(raw_database, tenant_id, session_manager, schema, raw_db)
2828
await obj._get_declarative_module()
2929
return obj
3030

3131
def __init__(
32-
self, raw_database: str, project_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False
32+
self, raw_database: str, tenant_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False
3333
) -> None:
3434
self.raw_database: str = raw_database
35-
self.project_id: str = project_id
35+
self.tenant_id: str = tenant_id
3636
self.session_manager: SQLSessionManager = session_manager
3737
self.raw_db = raw_db
3838
self.schema = schema
@@ -43,29 +43,27 @@ async def _pre_check(self):
4343
await self._get_declarative_module()
4444

4545
async def _prepare_declarative_file(self, refresh: bool = False):
46-
declarative_project_directory = PathConfig.DECLARATIVES_PATH / self.project_id
47-
if not declarative_project_directory.exists():
48-
declarative_project_directory.mkdir(parents=True)
49-
project_init_file = declarative_project_directory / "__init__.py"
50-
if not project_init_file.exists():
51-
with open(project_init_file, "w") as f:
46+
declarative_tenant_directory = PathConfig.DECLARATIVES_PATH / self.tenant_id
47+
if not declarative_tenant_directory.exists():
48+
declarative_tenant_directory.mkdir(parents=True)
49+
tenant_init_file = declarative_tenant_directory / "__init__.py"
50+
if not tenant_init_file.exists():
51+
with open(tenant_init_file, "w") as f:
5252
f.write("")
53-
declarative_file = declarative_project_directory / f"async_{self.raw_database}_{self.schema}.py"
53+
declarative_file = declarative_tenant_directory / f"async_{self.raw_database}_{self.schema}.py"
5454
if declarative_file.exists() and ModuleConfig.DEFER_GEN_REFRESH and not refresh:
55-
return f"{self.project_id}.async_{self.raw_database}_{self.schema}"
55+
return f"{self.tenant_id}.async_{self.raw_database}_{self.schema}"
5656
try:
5757
logging.debug(f"Attempting to create declarative file: {declarative_file}")
5858
from sql_db_utils.asyncio.codegen import UTDeclarativeGenerator
5959

60-
session = await self.session_manager.get_session(
61-
self.raw_database, None if self.raw_db else self.project_id
62-
)
60+
session = await self.session_manager.get_session(self.raw_database, None if self.raw_db else self.tenant_id)
6361
meta = MetaData()
6462
async with session.bind.begin() as conn:
6563
await conn.run_sync(meta.reflect, schema=self.schema)
6664
with open(declarative_file, "w", encoding="utf-8") as f:
6765
generator = UTDeclarativeGenerator(
68-
raw_database=self.raw_database if self.raw_db else f"{self.project_id}__{self.raw_database}",
66+
raw_database=self.raw_database if self.raw_db else f"{self.tenant_id}__{self.raw_database}",
6967
metadata=meta,
7068
bind=session.bind,
7169
options=set(),
@@ -79,7 +77,7 @@ async def _prepare_declarative_file(self, refresh: bool = False):
7977
except Exception as e:
8078
logging.error(f"Error creating declarative file: {e}")
8179
return False
82-
return f"{self.project_id}.async_{self.raw_database}_{self.schema}"
80+
return f"{self.tenant_id}.async_{self.raw_database}_{self.schema}"
8381

8482
async def _get_declarative_module(self): # NOSONAR
8583
if declarative_module_path := await self._prepare_declarative_file():
@@ -97,8 +95,20 @@ async def _get_declarative_module(self): # NOSONAR
9795
try:
9896
import asyncio
9997

100-
loop = asyncio.get_event_loop()
101-
loop.stop()
98+
logging.warning("Emergency shutdown required - gracefully canceling tasks")
99+
loop = asyncio.get_running_loop()
100+
tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()]
101+
logging.debug(f"Canceling {len(tasks)} pending tasks")
102+
103+
for task in tasks:
104+
task.cancel()
105+
106+
# Wait for all tasks to complete with cancellation
107+
if tasks:
108+
await asyncio.gather(*tasks, return_exceptions=True)
109+
110+
logging.info("Tasks gracefully canceled, exiting")
111+
sys.exit(1)
102112
except ImportError:
103113
logging.error("Not asyncio module, stopping using sys.exit")
104114
sys.exit(1)
@@ -169,105 +179,54 @@ def get_declarative_utils_factory(
169179
self,
170180
raw_database: str,
171181
session_manager: SQLSessionManager,
172-
security_enabled: bool = True,
173182
):
174-
if security_enabled:
175-
try:
176-
from ut_security_util import MetaInfoSchema
177-
178-
async def get_declarative_utils(
179-
meta: MetaInfoSchema,
180-
schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA,
181-
) -> DeclarativeUtils:
182-
global declarative_utils
183-
if declarative_util := declarative_utils.get(f"{raw_database}_{meta.project_id}_{schema}"):
184-
await declarative_util._pre_check()
185-
return declarative_util
186-
else:
187-
declarative_util = await DeclarativeUtils(
188-
raw_database, meta.project_id, session_manager, schema
189-
)
190-
declarative_utils[f"{raw_database}_{meta.project_id}_{schema}"] = declarative_util
191-
return declarative_util
192-
193-
return get_declarative_utils
194-
except ImportError:
195-
logging.error("ut_security_util not installed, please install it to use security features")
196-
raise
197-
else:
198-
199-
async def get_declarative_utils(
200-
project_id: Annotated[str, Cookie], schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA
201-
) -> DeclarativeUtils:
202-
global declarative_utils
203-
if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"):
204-
await declarative_util._pre_check()
205-
return declarative_util
206-
else:
207-
declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema)
208-
declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util
209-
return declarative_util
210-
211-
return get_declarative_utils
183+
async def get_declarative_utils(
184+
tenant_id: Annotated[str, Cookie], schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA
185+
) -> DeclarativeUtils:
186+
global declarative_utils
187+
if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"):
188+
await declarative_util._pre_check()
189+
return declarative_util
190+
else:
191+
declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema)
192+
declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util
193+
return declarative_util
194+
195+
return get_declarative_utils
212196

213197
def get_schema_mandated_declarative_utils_factory(
214198
self,
215199
raw_database: str,
216200
session_manager: SQLSessionManager,
217201
schema: str,
218-
security_enabled: bool = True,
219202
):
220-
if security_enabled:
221-
try:
222-
from ut_security_util import MetaInfoSchema
223-
224-
async def get_declarative_utils(
225-
meta: MetaInfoSchema,
226-
) -> DeclarativeUtils:
227-
global declarative_utils
228-
if declarative_util := declarative_utils.get(f"{raw_database}_{meta.project_id}_{schema}"):
229-
await declarative_util._pre_check()
230-
return declarative_util
231-
else:
232-
declarative_util = await DeclarativeUtils(
233-
raw_database, meta.project_id, session_manager, schema
234-
)
235-
declarative_utils[f"{raw_database}_{meta.project_id}_{schema}"] = declarative_util
236-
return declarative_util
237-
238-
return get_declarative_utils
239-
except ImportError:
240-
logging.error("ut_security_util not installed, please install it to use security features")
241-
raise
242-
else:
243-
244-
async def get_declarative_utils(project_id: Annotated[str, Cookie]) -> DeclarativeUtils:
245-
global declarative_utils
246-
if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"):
247-
await declarative_util._pre_check()
248-
return declarative_util
249-
else:
250-
declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema)
251-
declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util
252-
return declarative_util
253-
254-
return get_declarative_utils
203+
async def get_declarative_utils(tenant_id: Annotated[str, Cookie]) -> DeclarativeUtils:
204+
global declarative_utils
205+
if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"):
206+
await declarative_util._pre_check()
207+
return declarative_util
208+
else:
209+
declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema)
210+
declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util
211+
return declarative_util
212+
213+
return get_declarative_utils
255214

256215
async def get_declarative_utils(
257216
self,
258217
raw_database: str,
259-
project_id: str,
218+
tenant_id: str,
260219
session_manager: SQLSessionManager,
261220
schema: str = PostgresConfig.PG_DEFAULT_SCHEMA,
262221
raw_db: bool = False,
263222
) -> DeclarativeUtils:
264223
global declarative_utils
265-
if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"):
224+
if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"):
266225
await declarative_util._pre_check()
267226
return declarative_util
268227
else:
269-
declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema, raw_db)
270-
declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util
228+
declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema, raw_db)
229+
declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util
271230
return declarative_util
272231

273232

sql_db_utils/asyncio/session_management.py

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
2-
from typing import Any, AsyncGenerator, Union
2+
from typing import Annotated, Any, AsyncGenerator, Union
33

4-
from redis import Redis
54
from sqlalchemy import Engine, MetaData, NullPool, text
65
from sqlalchemy.exc import OperationalError
76
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
@@ -14,19 +13,18 @@
1413

1514

1615
class SQLSessionManager:
17-
def __init__(self, redis_project_db: Union[Redis, None] = None, database_uri: Union[str, None] = None) -> None:
16+
__slots__ = ("_db_engines", "database_uri")
17+
18+
def __init__(self, database_uri: Union[str, None] = None) -> None:
1819
self._db_engines = {}
19-
if not redis_project_db:
20-
from sql_db_utils.redis_connections import project_db as redis_project_db
21-
self.redis_project_source_db = redis_project_db
2220
self.database_uri = database_uri or PostgresConfig.POSTGRES_URI
2321

2422
def __del__(self) -> None:
2523
for engine in self._db_engines.values():
2624
engine.dispose()
2725

28-
def _get_fully_qualified_db(self, database: str, project_id: Union[str, None] = None) -> str:
29-
return f"{project_id}__{database}" if project_id else database
26+
def _get_fully_qualified_db(self, database: str, tenant_id: Union[str, None] = None) -> str:
27+
return f"{tenant_id}__{database}" if tenant_id else database
3028

3129
async def _ensure_engine_connection(self, _engine_obj: Engine):
3230
for _ in range(PostgresConfig.PG_MAX_RETRY):
@@ -43,9 +41,9 @@ async def _ensure_engine_connection(self, _engine_obj: Engine):
4341
logging.error("Server connection failed")
4442

4543
async def _get_engine(
46-
self, database: str, project_id: Union[str, None] = None, metadata: Union[MetaData, None] = None
44+
self, database: str, tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None
4745
) -> AsyncSession:
48-
qualified_db_name = self._get_fully_qualified_db(database=database, project_id=project_id)
46+
qualified_db_name = self._get_fully_qualified_db(database=database, tenant_id=tenant_id)
4947
if not (engine := self._db_engines.get(qualified_db_name)):
5048
logging.debug(f"Creating engine for database: {qualified_db_name}")
5149
if PostgresConfig.PG_ENABLE_POOLING:
@@ -89,47 +87,31 @@ async def _get_engine(
8987
async def get_session(
9088
self,
9189
database: str,
92-
project_id: Union[str, None] = None,
90+
tenant_id: Union[str, None] = None,
9391
metadata: Union[MetaData, None] = None,
9492
retrying: bool = False,
9593
) -> AsyncSession:
9694
if PostgresConfig.PG_RETRY_QUERY or retrying:
9795
return AsyncSession(
98-
bind=self._get_engine(database=database, project_id=project_id, metadata=metadata),
96+
bind=self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata),
9997
future=True,
10098
query_cls=RetryingQuery,
10199
)
102100
return AsyncSession(
103-
bind=await self._get_engine(database=database, project_id=project_id, metadata=metadata),
101+
bind=await self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata),
104102
expire_on_commit=False,
105103
future=True,
106104
)
107105

108106
async def get_engine_obj(
109-
self, database: str, project_id: Union[str, None] = None, metadata: Union[MetaData, None] = None
107+
self, database: str, tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None
110108
) -> AsyncEngine:
111-
return await self._get_engine(database=database, project_id=project_id, metadata=metadata)
112-
113-
def get_db_factory(
114-
self, database: str, security_enabled: bool = True, retrying: bool = False
115-
) -> AsyncGenerator[AsyncSession, Any]:
116-
if security_enabled:
117-
try:
118-
from ut_security_util import MetaInfoSchema
119-
120-
async def get_db(meta: MetaInfoSchema):
121-
yield await self.get_session(database=database, project_id=meta.project_id, retrying=retrying)
109+
return await self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata)
122110

123-
return get_db
124-
except ImportError:
125-
logging.error("ut_security_util not installed, please install it to use security features")
126-
raise
127-
else:
128-
from fastapi import Request
111+
def get_db_factory(self, database: str, retrying: bool = False) -> AsyncGenerator[AsyncSession, Any]:
112+
from fastapi import Cookie
129113

130-
async def get_db(request: Request):
131-
cookies = request.cookies
132-
project_id = cookies.get("project_id")
133-
yield await self.get_session(database=database, project_id=project_id, retrying=retrying)
114+
async def get_db(tenant_id: Annotated[str, Cookie]):
115+
yield await self.get_session(database=database, tenant_id=tenant_id, retrying=retrying)
134116

135-
return get_db
117+
return get_db

sql_db_utils/config.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,6 @@ def validate_my_field(cls, value):
4242
return value
4343

4444

45-
class _RedisConfig(BaseSettings):
46-
REDIS_URI: str
47-
REDIS_PROJECT_TAGS_DB: int = 18
48-
49-
5045
class _BasePathConf(BaseSettings):
5146
BASE_PATH: str = "/code/data"
5247

@@ -64,7 +59,6 @@ def validate_paths(self) -> Self:
6459

6560
ModuleConfig = _ModuleConfig()
6661
PostgresConfig = _PostgresConfig()
67-
RedisConfig = _RedisConfig()
6862
PathConfig = _PathConf()
6963

70-
__all__ = ["ModuleConfig", "PostgresConfig", "RedisConfig", "PathConfig"]
64+
__all__ = ["ModuleConfig", "PostgresConfig", "PathConfig"]

sql_db_utils/constants.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from enum import StrEnum
22

3-
ENUMPARENT = (StrEnum,)
43

5-
6-
class QueryType(*ENUMPARENT):
4+
class QueryType(StrEnum):
75
"""
86
An enumeration representing the different types of queries that can be executed.
97
@@ -18,7 +16,7 @@ class QueryType(*ENUMPARENT):
1816
POLAR = "polars"
1917

2018

21-
class AGGridDateTrim(*ENUMPARENT):
19+
class AGGridDateTrim(StrEnum):
2220
"""
2321
An enumeration representing the different date trimming options
2422

0 commit comments

Comments
 (0)