Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,21 @@ def __init__(self, *args, **kwargs):
self.retry_delay = kwargs.pop("retry_delay", 1)
super().__init__(*args, **kwargs)

def manual_close(self):
    """
    Return the current thread's connection to the pool by closing it.

    peewee's PooledMySQLDatabase only recycles a connection into the pool
    when close() is invoked; execute_sql() alone does not release it. Call
    this after database work to avoid exhausting the pool under load.

    Failures are swallowed (logged at DEBUG) — cleanup must never raise
    into the caller.
    """
    try:
        if self.is_closed():
            return
        self.close()
    except Exception as e:
        logging.debug(f"manual_close failed: {e}")

def execute_sql(self, sql, params=None, commit=True):
for attempt in range(self.max_retries + 1):
try:
Expand Down
88 changes: 74 additions & 14 deletions rag/utils/ob_redis_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import uuid
from datetime import datetime, timedelta
from decimal import Decimal
from functools import wraps

import trio
from peewee import IntegrityError, ProgrammingError
Expand All @@ -27,6 +28,25 @@ def get_db():
return DATABASE


def release_connection(func):
    """
    Method decorator: give the pooled DB connection back after the call.

    peewee's PooledMySQLDatabase does not return connections to the pool
    after execute_sql(); wrapping a method with this decorator guarantees
    manual_close() runs once the method finishes (normally or by raising),
    preventing pool exhaustion under high concurrency.

    Only acts when ``self.db`` is a RetryingPooledMySQLDatabase; any other
    database object is left untouched.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            result = func(self, *args, **kwargs)
        finally:
            db = self.db
            if isinstance(db, RetryingPooledMySQLDatabase):
                db.manual_close()
        return result
    return wrapper



# 由于这里数据库 message 返回值是str,而 redis_conn stream 接口返回的是 dict,所以 RedisMsg 接口初始化略有不同,因此单独声明一个 RedisMsg
class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message):
Expand Down Expand Up @@ -76,6 +96,7 @@ def __init__(self, db=None):
def register_scripts(self) -> None:
raise NotImplementedError("Not implemented")

@release_connection
def health(self):
try:
self.db.execute_sql("select 1 from dual")
Expand All @@ -86,13 +107,12 @@ def health(self):
def is_alive(self):
    """Liveness check — simply delegates to health()."""
    return self.health()

@release_connection
def exist(self, k):
if not self.db:
return
try:

cursor = self.db.execute_sql('select count(1) from cache where cache_key = %s and expire_time > now()', (k))
Comment thread
whhe marked this conversation as resolved.

ret = cursor.fetchone()
return ret[0] == 1
except Exception as e:
Expand All @@ -101,6 +121,7 @@ def exist(self, k):
else:
logging.warning("RedisDB.exist " + str(k) + " got exception: " + str(e))

@release_connection
def delete_if_equal(self, key: str, expected_value: str) -> bool:
try:
cursor = self.db.execute_sql('delete from cache where cache_key = %s and cache_value = %s and expire_time '
Expand All @@ -114,6 +135,7 @@ def delete_if_equal(self, key: str, expected_value: str) -> bool:
"RedisDB.delete_if_equal " + str(key) + ":" + str(expected_value) + " got exception: " + str(e))
return False

@release_connection
def delete(self, key) -> bool:
try:
self.db.execute_sql('delete from cache where cache_key = %s ', (key))
Comment thread
whhe marked this conversation as resolved.
Expand All @@ -125,7 +147,7 @@ def delete(self, key) -> bool:
logging.warning("RedisDB.delete " + str(key) + " got exception: " + str(e))
return False

def deleteIfExpired(self, key) -> bool:
def _deleteIfExpired(self, key) -> bool:
try:
self.db.execute_sql('delete from cache where cache_key = %s and expire_time < now() ', key)
Comment thread
whhe marked this conversation as resolved.
return True
Expand All @@ -136,6 +158,11 @@ def deleteIfExpired(self, key) -> bool:
logging.warning("RedisDB.delete " + str(key) + " got exception: " + str(e))
return False

@release_connection
def deleteIfExpired(self, key) -> bool:
    """Public entry point: remove `key` if its TTL has passed, then release the pooled connection."""
    removed = self._deleteIfExpired(key)
    return removed

@release_connection
def get(self, k):
if not self.db:
return None
Comment thread
whhe marked this conversation as resolved.
Expand All @@ -150,9 +177,9 @@ def get(self, k):
else:
logging.warning("RedisDB.get " + str(k) + " got exception: " + str(e))

def set_obj(self, k, obj, exp=3600):
def _set_obj(self, k, obj, exp=3600):
try:
self.set_object(k, obj, exp)
self._set_object(k, obj, exp)
return True
except Exception as e:
if is_table_missing_exception(e):
Expand All @@ -161,12 +188,21 @@ def set_obj(self, k, obj, exp=3600):
logging.warning("RedisDB.set_obj " + str(k) + " got exception: " + str(e))
return False

def set_object(self, k, obj, exp=3600):
@release_connection
def set_obj(self, k, obj, exp=3600):
    """Connection-releasing wrapper around _set_obj (store obj under key k with expiry exp)."""
    ok = self._set_obj(k, obj, exp)
    return ok

def _set_object(self, k, obj, exp=3600):
expire_time = datetime.now() + timedelta(seconds=exp)
self.db.execute_sql('replace into cache (cache_key, cache_value, expire_time) values (%s, %s, %s)',
(k, json.dumps(obj, ensure_ascii=False), expire_time))
return True

@release_connection
def set_object(self, k, obj, exp=3600):
    """Connection-releasing wrapper around _set_object (JSON upsert with TTL)."""
    result = self._set_object(k, obj, exp)
    return result

@release_connection
def set(self, k, v, exp=3600):
try:
expire_time = datetime.now() + timedelta(seconds=exp)
Expand All @@ -180,10 +216,11 @@ def set(self, k, v, exp=3600):
logging.warning("RedisDB.set " + str(k) + " got exception: " + str(e))
return False

@release_connection
def setNx(self, k, v, exp=3600):
try:
# 删除过期的kv
self.deleteIfExpired(k)
self._deleteIfExpired(k)
expire_time = datetime.now() + timedelta(seconds=exp)
self.db.execute_sql('insert into cache (cache_key, cache_value, expire_time) values (%s, %s, %s)',
(k, v, expire_time))
Expand All @@ -202,6 +239,7 @@ def transaction(self, key, value, exp=3600):
return self.setNx(key, value, exp)

# zset
@release_connection
def zadd(self, key: str, member: str, score: float):
try:
with self.db.atomic():
Comment thread
whhe marked this conversation as resolved.
Expand All @@ -210,7 +248,7 @@ def zadd(self, key: str, member: str, score: float):
ret = cursor.fetchone()
if ret is None:
mp = {member: score}
return self.set_object(key, mp)
return self._set_object(key, mp)
else:
id = ret[0]
cursor = self.db.execute_sql(
Expand All @@ -228,6 +266,7 @@ def zadd(self, key: str, member: str, score: float):
logging.warning("RedisDB.zadd " + str(key) + " got exception: " + str(e))
return False

@release_connection
def zcount(self, key: str, min, max: float):
try:
cursor = self.db.execute_sql(
Expand All @@ -245,6 +284,7 @@ def zcount(self, key: str, min, max: float):
logging.warning("RedisDB.zcount " + str(key) + " got exception: " + str(e))
return 0

@release_connection
def zpopmin(self, key: str, count: int):
try:
with self.db.atomic() as trx:
Comment thread
whhe marked this conversation as resolved.
Expand All @@ -264,7 +304,7 @@ def zpopmin(self, key: str, count: int):
break
for k, v in ret.items():
del mp[k]
self.set_object(key, mp)
self._set_object(key, mp)
return ret
except Exception as e:
if is_table_missing_exception(e):
Expand All @@ -273,6 +313,7 @@ def zpopmin(self, key: str, count: int):
logging.warning("RedisDB.zpopmin " + str(key) + " got exception: " + str(e))
return None

@release_connection
def sadd(self, key: str, member: str):
try:
with self.db.atomic():
Expand All @@ -281,18 +322,19 @@ def sadd(self, key: str, member: str):
ret = cursor.fetchone()
if ret is None:
st = {member}
return self.set_object(key, list(st))
return self._set_object(key, list(st))
else:
st = set(json.loads(ret[0]))
st.add(member)
return self.set_obj(key, list(st))
return self._set_obj(key, list(st))
except Exception as e:
if is_table_missing_exception(e):
pass
else:
logging.warning("RedisDB.sadd " + str(key) + " got exception: " + str(e))
return False

@release_connection
def srem(self, key: str, member: str):
try:
with self.db.atomic():
Expand All @@ -304,14 +346,15 @@ def srem(self, key: str, member: str):
else:
st = set(json.loads(ret[0]))
st.discard(member)
return self.set_object(key, list(st))
return self._set_object(key, list(st))
except Exception as e:
if is_table_missing_exception(e):
pass
else:
logging.warning("RedisDB.srem " + str(key) + " got exception: " + str(e))
return False

@release_connection
def smembers(self, key: str):
try:
cursor = self.db.execute_sql("select cache_value from cache where cache_key = %s and expire_time > "
Expand All @@ -328,6 +371,7 @@ def smembers(self, key: str):
logging.warning("RedisDB.smembers " + str(key) + " got exception: " + str(e))
return []

@release_connection
def zrangebyscore(self, key: str, min: float, max: float):
try:
cursor = self.db.execute_sql("select cache_value from cache where cache_key = %s and expire_time > now() ",
Expand All @@ -351,6 +395,7 @@ def zrangebyscore(self, key: str, min: float, max: float):
return None

# 以下是redis stream
@release_connection
def queue_product(self, queue, message) -> bool:
"""
向消息队列推送消息,如果消息队列不存在,则创建消息队列
Expand All @@ -370,6 +415,7 @@ def queue_product(self, queue, message) -> bool:
)
return False

@release_connection
def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">"):
"""
消费者拉取消息:
Comment thread
whhe marked this conversation as resolved.
Expand Down Expand Up @@ -463,6 +509,7 @@ def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">"):
)
return None

@release_connection
def get_pending_msg(self, queue, group_name):
"""
获取消费者组 {group_name} 对消息队列 {queue} 已经读取,但是没有 ACK 的消息。
Expand Down Expand Up @@ -544,7 +591,7 @@ def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name
'''
消息重新入队
'''

@release_connection
def requeue_msg(self, queue: str, group_name: str, msg_id: object):
"""
将未 ack 的消息重新入队列
Comment thread
whhe marked this conversation as resolved.
Expand All @@ -570,6 +617,7 @@ def requeue_msg(self, queue: str, group_name: str, msg_id: object):
"RedisDB.requeue_msg " + str(queue) + " got exception: " + str(e)
)

@release_connection
def xack(self, queue: str, group_name: str, msg_id: object):
"""
提交消息 ack
Expand All @@ -578,6 +626,7 @@ def xack(self, queue: str, group_name: str, msg_id: object):
"update message_consumption set ack = true where stream = %s and group_name = %s and "
"message_id = %s", (queue, group_name, msg_id))

@release_connection
def queue_info(self, queue: str, group_name: str) -> dict | None:
"""
获取消息队列,某个消费者组的消费情况。本项目用到的属性有:
Comment thread
whhe marked this conversation as resolved.
Expand All @@ -591,7 +640,7 @@ def queue_info(self, queue: str, group_name: str) -> dict | None:
cursor = self.db.execute_sql(
"select count(1) from message_subscribe where stream = %s and group_name = %s", (queue, group_name))
ret = cursor.fetchone()
if ret == 0:
if ret is None or ret[0] == 0:
return None
else:
cursor = self.db.execute_sql(
Expand Down Expand Up @@ -635,6 +684,13 @@ def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
# blocking_timeout 没用到,预留
self.blocking_timeout = blocking_timeout

def _release_connection(self):
    """
    Hand the current thread's database connection back to the pool.

    No-op unless ``self.db`` is a RetryingPooledMySQLDatabase.
    """
    db = self.db
    if isinstance(db, RetryingPooledMySQLDatabase):
        db.manual_close()

def acquire(self):
"""
获取锁
Expand Down Expand Up @@ -677,6 +733,8 @@ def doAcquire(self):
else:
logging.info(f"lock acquire failed:{self.lock_key}-{self.lock_value}-{e}")
return False
finally:
self._release_connection()

def release(self):
"""
Expand Down Expand Up @@ -705,6 +763,8 @@ def delete_if_equal(self):
else:
logging.warning(f"release lock failed:{self.lock_key}-{self.lock_value}-{e}")
return False
finally:
self._release_connection()


if __name__ == '__main__':
Expand Down
Loading