Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions app/api/data_jobs_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask import Blueprint, jsonify, request

from app.services.data_jobs.service import DataJobService
from app.services.wide_table_status import get_wide_table_status


data_jobs_bp = Blueprint("data_jobs", __name__, url_prefix="/api/data-jobs")
Expand All @@ -22,6 +23,16 @@ def submit_job():
if not job_type:
return jsonify({"success": False, "error": "missing job_type"}), 400

# 大宽表构建任务需通过 18:00 校验
if job_type == "wide_table_builder":
from flask import current_app
status = get_wide_table_status(current_app.config.get("DATA_DIR"))
if not status["past_cutoff"]:
return jsonify({
"success": False,
"error": "当前时间未过 18:00,数据源可能尚未下载完毕,请稍后再试",
}), 400

params = payload.get("params", {})
try:
run = get_data_job_service().submit(job_type, params)
Expand Down Expand Up @@ -76,3 +87,47 @@ def retry_run(run_id: int):
return jsonify({"success": False, "error": str(exc)}), 500

return jsonify({"success": True, "run_id": run.id, "status": run.status})


@data_jobs_bp.route("/wide-table/status", methods=["GET"])
def wide_table_status():
    """Return the wide-table status as a JSON envelope.

    The status object is whatever ``get_wide_table_status`` produces for the
    configured ``DATA_DIR``; it is passed through to the client unchanged
    under the ``status`` key.

    Returns:
        ``{"success": True, "status": ...}`` on success, or
        ``{"success": False, "error": str(exc)}`` with HTTP 500 when anything
        inside the handler raises.
    """
    try:
        from flask import current_app

        payload = {
            "success": True,
            "status": get_wide_table_status(current_app.config.get("DATA_DIR")),
        }
        return jsonify(payload)
    except Exception as exc:
        failure = {"success": False, "error": str(exc)}
        return jsonify(failure), 500


@data_jobs_bp.route("/wide-table/build", methods=["POST"])
def build_wide_table():
    """Submit a wide-table build job, enforcing the 18:00 cutoff server-side.

    The cutoff is re-checked here (not only in the UI) so that callers hitting
    the API directly cannot bypass it.

    Returns:
        * 400 with an explanatory message when ``past_cutoff`` is falsy.
        * 400 with ``str(exc)`` when the check or the submission raises
          ``KeyError`` / ``ValueError``.
        * 500 with ``str(exc)`` for any other exception.
        * Otherwise a success envelope with ``run_id``, ``job_type`` and
          ``status`` of the submitted run.
    """
    try:
        from flask import current_app

        status = get_wide_table_status(current_app.config.get("DATA_DIR"))
        if not status["past_cutoff"]:
            return jsonify({
                "success": False,
                "error": "当前时间未过 18:00,数据源可能尚未下载完毕,请稍后再试",
            }), 400

        run = get_data_job_service().submit("wide_table_builder", {})

        # Only a run that already reports success has new data to expose.
        # Per the original note: in inline mode this takes effect immediately,
        # in celery mode the caches are cleared once the task completes.
        if run.status == "success":
            _invalidate_wide_table_caches()

        return jsonify({
            "success": True,
            "run_id": run.id,
            "job_type": run.job_type,
            "status": run.status,
        })
    except (KeyError, ValueError) as exc:
        return jsonify({"success": False, "error": str(exc)}), 400
    except Exception as exc:
        return jsonify({"success": False, "error": str(exc)}), 500


def _invalidate_wide_table_caches():
    """Best-effort invalidation of the caches that may hold pre-build data.

    The build has already succeeded when this runs, so a failure here must not
    turn the response into an error (which would hide ``run_id`` and invite a
    redundant re-submit). Each cache is handled independently so one failure
    does not prevent the other from being cleared; failures are logged.
    """
    from flask import current_app

    try:
        from app.services.data_reader import ParquetDataReader
        ParquetDataReader.invalidate_stock_business_cache()
    except Exception:
        current_app.logger.exception(
            "wide table built, but stock_business cache invalidation failed"
        )

    try:
        from app.services.text2sql_engine import get_text2sql_engine
        get_text2sql_engine().query_executor.invalidate_cache()
    except Exception:
        current_app.logger.exception(
            "wide table built, but text2sql query cache invalidation failed"
        )
12 changes: 12 additions & 0 deletions app/services/data_jobs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self) -> None:
"moneyflow", # 15
"stk_factor", # 17
"cyq_perf", # 18
"wide_table_builder", # 20
}

self._jobs: Dict[str, JobDefinition] = {
Expand Down Expand Up @@ -127,6 +128,17 @@ def __init__(self) -> None:
source_name="derived",
source_mode="derived",
),
"wide_table_builder": JobDefinition(
"wide_table_builder",
"衍生计算",
"app/utils/wide_table_builder.py",
display_name="大宽表构建",
description="合并日线基本指标、技术因子、资金流向和股票基础资料为最新交易日大宽表。",
recommended_order=10,
source_name="derived",
source_mode="derived",
dependencies=["daily_basic", "stk_factor", "moneyflow", "stock_basic"],
),
}

def get_job(self, job_type: str) -> JobDefinition:
Expand Down
5 changes: 5 additions & 0 deletions app/services/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ def get_index_basic(self) -> pd.DataFrame:

_stock_business_cache: Optional[pd.DataFrame] = None

@classmethod
def invalidate_stock_business_cache(cls) -> None:
    """Reset the class-level ``_stock_business_cache`` to ``None``.

    Clearing the cached value is meant to make the next read reload the
    underlying file instead of serving the previously cached frame.
    """
    cls._stock_business_cache = None

def get_stock_business(self, ts_code: Optional[str] = None,
trade_date: Optional[str] = None) -> pd.DataFrame:
"""读取股票业务大宽表(daily_basic + factor + moneyflow 合并)。"""
Expand Down
66 changes: 42 additions & 24 deletions app/services/nlp_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,13 @@ def _preprocess(self, query: str) -> str:
"""预处理查询文本"""
# 去除多余空格
query = re.sub(r'\s+', ' ', query.strip())

# 统一标点符号
query = query.replace(',', ',').replace('。', '.').replace('?', '?')

# 统一数字格式
query = re.sub(r'(\d+)%', r'\1百分比', query)

# 统一数字格式(保留 % 供后续 percentage 模式匹配)
query = re.sub(r'(\d+)元', r'\1', query)

return query


Expand Down Expand Up @@ -199,14 +198,15 @@ def __init__(self):
'price_fields': ['收盘价', '开盘价', '最高价', '最低价', 'close', 'open', 'high', 'low', '价格'],
'volume_fields': ['成交量', '成交额', 'vol', 'amount', 'volume', '交易量', '交易额'],
'ratio_fields': ['涨跌幅', '涨幅', '跌幅', 'pct_change', '换手率', 'turnover_rate', '量比', 'volume_ratio'],
'valuation_fields': ['市盈率', 'PE', 'pe_ttm', '市净率', 'PB', 'pb', 'pe'],
'technical_fields': ['MACD', 'RSI', 'KDJ', '布林带', '均线', 'MA']
'valuation_fields': ['市盈率', 'PE', 'pe_ttm', '市净率', 'PB', 'pb', 'pe', 'ROE', 'ROA', 'roe', 'roa'],
'technical_fields': ['MACD', 'RSI', 'KDJ', '布林带', '均线', 'MA'],
'money_flow_fields': ['资金', '资金流', '净流入', '净流出', '主力资金', '资金净流入', '资金净流出']
}

# 字段到数据库字段的映射
self.field_db_mapping = {
'市盈率': 'pe_ttm',
'PE': 'pe_ttm',
'PE': 'pe_ttm',
'pe': 'pe_ttm',
'市净率': 'pb',
'PB': 'pb',
Expand All @@ -217,7 +217,18 @@ def __init__(self):
'收盘价': 'daily_close',
'涨跌幅': 'factor_pct_change',
'成交量': 'factor_vol',
'成交额': 'amount'
'成交额': 'amount',
'ROE': 'roe',
'ROA': 'roa',
'roe': 'roe',
'roa': 'roa',
'资金': 'net_mf_amount',
'资金流': 'net_mf_amount',
'净流入': 'net_mf_amount',
'资金净流入': 'net_mf_amount',
'净流出': 'net_mf_amount',
'资金净流出': 'net_mf_amount',
'主力资金': 'net_mf_amount',
}

def extract(self, query: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -299,10 +310,11 @@ def _extract_single_condition(self, condition: str) -> Optional[Dict[str, Any]]:
def _is_technical_indicator_condition(self, condition: str) -> bool:
"""判断是否为技术指标条件"""
technical_patterns = [
r'MACD.*金叉', r'MACD.*死叉', r'MACD.*向上', r'MACD.*向下',
r'MACD.*金叉', r'金叉.*MACD', r'MACD.*死叉', r'MACD.*向上', r'MACD.*向下',
r'RSI.*超买', r'RSI.*超卖', r'RSI.*大于', r'RSI.*小于',
r'KDJ.*金叉', r'KDJ.*死叉',
r'均线.*金叉', r'均线.*死叉', r'均线.*多头', r'均线.*空头'
r'均线.*金叉', r'均线.*死叉', r'均线.*多头', r'均线.*空头',
r'金叉.*股票', r'死叉.*股票',
]

for pattern in technical_patterns:
Expand All @@ -314,18 +326,18 @@ def _is_technical_indicator_condition(self, condition: str) -> bool:
def _extract_technical_condition(self, condition: str) -> Dict[str, Any]:
"""提取技术指标条件"""
condition_entity = {}
if 'MACD' in condition:

if 'MACD' in condition or ('金叉' in condition and 'KDJ' not in condition and '均线' not in condition):
condition_entity['field'] = {
'name': 'MACD',
'original': 'MACD',
'category': 'technical_fields',
'db_field': 'macd'
}

if '金叉' in condition:
condition_entity['comparison'] = 'golden_cross'
condition_entity['value'] = 'golden_cross' # 特殊值表示金叉
condition_entity['value'] = 'golden_cross'
elif '死叉' in condition:
condition_entity['comparison'] = 'death_cross'
condition_entity['value'] = 'death_cross'
Expand Down Expand Up @@ -483,37 +495,43 @@ def _extract_fields(self, query: str) -> Dict[str, Any]:
def _extract_sorting(self, query: str) -> Dict[str, Any]:
"""提取排序信息"""
sorting = {}
if re.search(r'排名|排序|排列', query):

if re.search(r'排名|排序|排列|最多|最大|最高|前\d+.*名|前\d+.*只', query):
sorting['sort'] = True

if re.search(r'升序|从小到大|asc', query):
sorting['order'] = 'asc'
elif re.search(r'降序|从大到小|desc', query):
sorting['order'] = 'desc'
else:
sorting['order'] = 'desc' # 默认降序

return sorting

def _extract_limits(self, query: str) -> Dict[str, Any]:
"""提取限制数量"""
limits = {}
# 提取前N名

# 提取前N名/前N只
top_match = re.search(r'前(\d+)(?:名|个|只|支)?', query)
if top_match:
limits['limit'] = int(top_match.group(1))

# 提取top N
top_match = re.search(r'top\s*(\d+)', query, re.IGNORECASE)
if top_match:
limits['limit'] = int(top_match.group(1))


# 提取"N只股票"(无"前"前缀,如"最多的10只股票")
if 'limit' not in limits:
count_match = re.search(r'(\d+)(?:只|支)股票', query)
if count_match:
limits['limit'] = int(count_match.group(1))

# 默认限制
if 'limit' not in limits:
limits['limit'] = 20

return limits


Expand Down
Loading
Loading