diff --git a/app/api/data_jobs_api.py b/app/api/data_jobs_api.py index 99437fbb9..96c66ee3c 100644 --- a/app/api/data_jobs_api.py +++ b/app/api/data_jobs_api.py @@ -1,6 +1,7 @@ from flask import Blueprint, jsonify, request from app.services.data_jobs.service import DataJobService +from app.services.wide_table_status import get_wide_table_status data_jobs_bp = Blueprint("data_jobs", __name__, url_prefix="/api/data-jobs") @@ -22,6 +23,16 @@ def submit_job(): if not job_type: return jsonify({"success": False, "error": "missing job_type"}), 400 + # 大宽表构建任务需通过 18:00 校验 + if job_type == "wide_table_builder": + from flask import current_app + status = get_wide_table_status(current_app.config.get("DATA_DIR")) + if not status["past_cutoff"]: + return jsonify({ + "success": False, + "error": "当前时间未过 18:00,数据源可能尚未下载完毕,请稍后再试", + }), 400 + params = payload.get("params", {}) try: run = get_data_job_service().submit(job_type, params) @@ -76,3 +87,47 @@ def retry_run(run_id: int): return jsonify({"success": False, "error": str(exc)}), 500 return jsonify({"success": True, "run_id": run.id, "status": run.status}) + + +@data_jobs_bp.route("/wide-table/status", methods=["GET"]) +def wide_table_status(): + """返回大宽表状态:是否存在、日期、是否需要更新、是否过了 18:00。""" + try: + from flask import current_app + data_dir = current_app.config.get("DATA_DIR") + status = get_wide_table_status(data_dir) + return jsonify({"success": True, "status": status}) + except Exception as exc: + return jsonify({"success": False, "error": str(exc)}), 500 + + +@data_jobs_bp.route("/wide-table/build", methods=["POST"]) +def build_wide_table(): + """提交大宽表构建任务。需通过 18:00 校验。""" + try: + # 后端强制 18:00 校验,防止绕过 UI 直接调 API + from flask import current_app + status = get_wide_table_status(current_app.config.get("DATA_DIR")) + if not status["past_cutoff"]: + return jsonify({ + "success": False, + "error": "当前时间未过 18:00,数据源可能尚未下载完毕,请稍后再试", + }), 400 + + run = get_data_job_service().submit("wide_table_builder", {}) + # 构建成功后清缓存(inline 模式立即生效,celery 模式在 task 完成后清) + if run.status == "success": + from app.services.data_reader import ParquetDataReader + ParquetDataReader.invalidate_stock_business_cache() + from app.services.text2sql_engine import get_text2sql_engine + get_text2sql_engine().query_executor.invalidate_cache() + return jsonify({ + "success": True, + "run_id": run.id, + "job_type": run.job_type, + "status": run.status, + }) + except (KeyError, ValueError) as exc: + return jsonify({"success": False, "error": str(exc)}), 400 + except Exception as exc: + return jsonify({"success": False, "error": str(exc)}), 500 diff --git a/app/services/data_jobs/registry.py b/app/services/data_jobs/registry.py index 2292fa450..a6db27c4e 100644 --- a/app/services/data_jobs/registry.py +++ b/app/services/data_jobs/registry.py @@ -18,6 +18,7 @@ def __init__(self) -> None: "moneyflow", # 15 "stk_factor", # 17 "cyq_perf", # 18 + "wide_table_builder", # 20 } self._jobs: Dict[str, JobDefinition] = { @@ -127,6 +128,17 @@ def __init__(self) -> None: source_name="derived", source_mode="derived", ), + "wide_table_builder": JobDefinition( + "wide_table_builder", + "衍生计算", + "app/utils/wide_table_builder.py", + display_name="大宽表构建", + description="合并日线基本指标、技术因子、资金流向和股票基础资料为最新交易日大宽表。", + recommended_order=10, + source_name="derived", + source_mode="derived", + dependencies=["daily_basic", "stk_factor", "moneyflow", "stock_basic"], + ), } def get_job(self, job_type: str) -> JobDefinition: diff --git a/app/services/data_reader.py b/app/services/data_reader.py index c33b1bad5..4ac18a32d 100644 --- a/app/services/data_reader.py +++ b/app/services/data_reader.py @@ -257,6 +257,11 @@ def get_index_basic(self) -> pd.DataFrame: _stock_business_cache: Optional[pd.DataFrame] = None + @classmethod + def invalidate_stock_business_cache(cls): + """清除 stock_business 缓存,使下次读取重新加载文件。""" + cls._stock_business_cache = None + def get_stock_business(self, ts_code: Optional[str] = None, trade_date: Optional[str] = None) -> pd.DataFrame: """读取股票业务大宽表(daily_basic + factor + moneyflow 合并)。""" diff --git a/app/services/nlp_processor.py b/app/services/nlp_processor.py index e2d3f2f10..1bc677714 100644 --- a/app/services/nlp_processor.py +++ b/app/services/nlp_processor.py @@ -73,14 +73,13 @@ def _preprocess(self, query: str) -> str: """预处理查询文本""" # 去除多余空格 query = re.sub(r'\s+', ' ', query.strip()) - + # 统一标点符号 query = query.replace(',', ',').replace('。', '.').replace('?', '?') - - # 统一数字格式 - query = re.sub(r'(\d+)%', r'\1百分比', query) + + # 统一数字格式(保留 % 供后续 percentage 模式匹配) query = re.sub(r'(\d+)元', r'\1', query) - + return query @@ -199,14 +198,15 @@ def __init__(self): 'price_fields': ['收盘价', '开盘价', '最高价', '最低价', 'close', 'open', 'high', 'low', '价格'], 'volume_fields': ['成交量', '成交额', 'vol', 'amount', 'volume', '交易量', '交易额'], 'ratio_fields': ['涨跌幅', '涨幅', '跌幅', 'pct_change', '换手率', 'turnover_rate', '量比', 'volume_ratio'], - 'valuation_fields': ['市盈率', 'PE', 'pe_ttm', '市净率', 'PB', 'pb', 'pe'], - 'technical_fields': ['MACD', 'RSI', 'KDJ', '布林带', '均线', 'MA'] + 'valuation_fields': ['市盈率', 'PE', 'pe_ttm', '市净率', 'PB', 'pb', 'pe', 'ROE', 'ROA', 'roe', 'roa'], + 'technical_fields': ['MACD', 'RSI', 'KDJ', '布林带', '均线', 'MA'], + 'money_flow_fields': ['资金', '资金流', '净流入', '净流出', '主力资金', '资金净流入', '资金净流出'] } # 字段到数据库字段的映射 self.field_db_mapping = { '市盈率': 'pe_ttm', - 'PE': 'pe_ttm', + 'PE': 'pe_ttm', 'pe': 'pe_ttm', '市净率': 'pb', 'PB': 'pb', @@ -217,7 +217,18 @@ def __init__(self): '收盘价': 'daily_close', '涨跌幅': 'factor_pct_change', '成交量': 'factor_vol', - '成交额': 'amount' + '成交额': 'amount', + 'ROE': 'roe', + 'ROA': 'roa', + 'roe': 'roe', + 'roa': 'roa', + '资金': 'net_mf_amount', + '资金流': 'net_mf_amount', + '净流入': 'net_mf_amount', + '资金净流入': 'net_mf_amount', + '净流出': 'net_mf_amount', + '资金净流出': 'net_mf_amount', + '主力资金': 'net_mf_amount', } def extract(self, query: str) -> Dict[str, Any]: @@ -299,10 +310,11 @@ def _extract_single_condition(self, condition: str) -> Optional[Dict[str, Any]]: def _is_technical_indicator_condition(self, condition: str) -> bool: """判断是否为技术指标条件""" technical_patterns = [ - r'MACD.*金叉', r'MACD.*死叉', r'MACD.*向上', r'MACD.*向下', + r'MACD.*金叉', r'金叉.*MACD', r'MACD.*死叉', r'MACD.*向上', r'MACD.*向下', r'RSI.*超买', r'RSI.*超卖', r'RSI.*大于', r'RSI.*小于', r'KDJ.*金叉', r'KDJ.*死叉', - r'均线.*金叉', r'均线.*死叉', r'均线.*多头', r'均线.*空头' + r'均线.*金叉', r'均线.*死叉', r'均线.*多头', r'均线.*空头', + r'金叉.*股票', r'死叉.*股票', ] for pattern in technical_patterns: @@ -314,18 +326,18 @@ def _is_technical_indicator_condition(self, condition: str) -> bool: def _extract_technical_condition(self, condition: str) -> Dict[str, Any]: """提取技术指标条件""" condition_entity = {} - - if 'MACD' in condition: + + if 'MACD' in condition or ('金叉' in condition and 'KDJ' not in condition and '均线' not in condition): condition_entity['field'] = { 'name': 'MACD', 'original': 'MACD', 'category': 'technical_fields', 'db_field': 'macd' } - + if '金叉' in condition: condition_entity['comparison'] = 'golden_cross' - condition_entity['value'] = 'golden_cross' # 特殊值表示金叉 + condition_entity['value'] = 'golden_cross' elif '死叉' in condition: condition_entity['comparison'] = 'death_cross' condition_entity['value'] = 'death_cross' @@ -483,37 +495,43 @@ def _extract_fields(self, query: str) -> Dict[str, Any]: def _extract_sorting(self, query: str) -> Dict[str, Any]: """提取排序信息""" sorting = {} - - if re.search(r'排名|排序|排列', query): + + if re.search(r'排名|排序|排列|最多|最大|最高|前\d+.*名|前\d+.*只', query): sorting['sort'] = True - + if re.search(r'升序|从小到大|asc', query): sorting['order'] = 'asc' elif re.search(r'降序|从大到小|desc', query): sorting['order'] = 'desc' else: sorting['order'] = 'desc' # 默认降序 - + return sorting def _extract_limits(self, query: str) -> Dict[str, Any]: """提取限制数量""" limits = {} - - # 提取前N名 + + # 提取前N名/前N只 top_match = re.search(r'前(\d+)(?:名|个|只|支)?', query) if top_match: limits['limit'] = int(top_match.group(1)) - + # 提取top N top_match = re.search(r'top\s*(\d+)', query, re.IGNORECASE) if top_match: limits['limit'] = int(top_match.group(1)) - + + # 提取"N只股票"(无"前"前缀,如"最多的10只股票") + if 'limit' not in limits: + count_match = re.search(r'(\d+)(?:只|支)股票', query) + if count_match: + limits['limit'] = int(count_match.group(1)) + # 默认限制 if 'limit' not in limits: limits['limit'] = 20 - + return limits diff --git a/app/services/sql_generator.py b/app/services/sql_generator.py index 3fc74e1cc..3ff4f62f9 100644 --- a/app/services/sql_generator.py +++ b/app/services/sql_generator.py @@ -21,16 +21,31 @@ def generate_sql(self, intent_result: Dict[str, Any]) -> Dict[str, Any]: try: intent = intent_result['intent']['name'] entities = intent_result['entities'] - - # 检查是否有多条件查询 - if 'conditions' in entities and len(entities['conditions']) > 1: - # 处理多条件查询 - sql = self._build_multi_condition_sql(entities) + + # 判断是否需要多表 JOIN 的查询 + has_conditions = 'conditions' in entities and len(entities['conditions']) > 0 + has_technical = False + has_money_flow = False + if has_conditions: + for cond in entities['conditions']: + cat = cond.get('field', {}).get('category', '') + name = cond.get('field', {}).get('name', '') + if cat == 'technical_fields' or name in ('MACD', 'KDJ', 'RSI', '均线'): + has_technical = True + if cat == 'money_flow_fields' or name in ('资金', '资金流', '净流入', '资金净流入'): + has_money_flow = True + + # 技术指标 / 资金流向 / 多条件 → 统一走 multi-condition builder + if has_technical or has_money_flow: + sql = self._build_multi_condition_sql(entities, intent) + template_used = 'multi_condition_dynamic' + elif has_conditions and len(entities['conditions']) > 1: + sql = self._build_multi_condition_sql(entities, intent) template_used = 'multi_condition_dynamic' else: # 1. 尝试使用模板生成 template_result = self.template_manager.generate_from_template(intent, entities) - + if template_result['success']: sql = template_result['sql'] template_used = template_result['template_id'] @@ -38,11 +53,11 @@ def generate_sql(self, intent_result: Dict[str, Any]) -> Dict[str, Any]: # 2. 使用动态构建器生成 sql = self.query_builder.build_dynamic_sql(intent, entities) template_used = None - + # 3. SQL优化和验证 optimized_sql = self._optimize_sql(sql) validation_result = self._validate_sql(optimized_sql) - + return { 'success': validation_result['valid'], 'sql': optimized_sql, @@ -50,7 +65,7 @@ def generate_sql(self, intent_result: Dict[str, Any]) -> Dict[str, Any]: 'error': validation_result.get('error'), 'explanation': self._generate_explanation(intent, entities) } - + except Exception as e: return { 'success': False, @@ -129,103 +144,151 @@ def _generate_explanation(self, intent: str, entities: Dict[str, Any]) -> str: return base_explanation - def _build_multi_condition_sql(self, entities: Dict[str, Any]) -> str: - """构建多条件查询SQL""" - conditions = entities['conditions'] - - # 分析需要的表 - tables_needed = set(['stock_business']) # 主表 + def _build_multi_condition_sql(self, entities: Dict[str, Any], + intent: str = '') -> str: + """构建多条件 / 多表 JOIN 查询 SQL + 支持技术指标(stock_factor)、资金流向(stock_moneyflow)等跨表查询。 + """ + conditions = entities.get('conditions', []) + + # ---- 1. 分类条件 & 确定需要的表 ---- + tables_needed = {'stock_business'} technical_conditions = [] business_conditions = [] - + money_flow_conditions = [] + for condition in conditions: field_info = condition.get('field', {}) field_category = field_info.get('category', '') - - if field_category == 'technical_fields' or 'MACD' in field_info.get('name', ''): + field_name = field_info.get('name', '') + + if field_category == 'technical_fields' or field_name in ('MACD', 'KDJ', 'RSI', '均线'): tables_needed.add('stock_factor') technical_conditions.append(condition) + elif field_category == 'money_flow_fields' or field_name in ('资金', '资金流', '净流入', '资金净流入'): + tables_needed.add('stock_moneyflow') + money_flow_conditions.append(condition) else: business_conditions.append(condition) - - # 构建SELECT子句 + + # 资金流向意图但没有条件时,仍然需要 JOIN + if intent == 'money_flow' and 'stock_moneyflow' not in tables_needed: + tables_needed.add('stock_moneyflow') + money_flow_conditions.append({ + 'field': {'name': '资金净流入', 'db_field': 'net_mf_amount', + 'category': 'money_flow_fields'}, + 'comparison': 'greater_than', 'value': 0 + }) + + # ---- 2. SELECT 子句 ---- select_fields = ['sb.ts_code', 'sb.stock_name'] - - # 添加查询条件中涉及的字段 + for condition in conditions: field_info = condition.get('field', {}) db_field = field_info.get('db_field', field_info.get('name', '')) - - if field_info.get('category') == 'technical_fields': + + if field_info.get('category') == 'technical_fields' or \ + field_info.get('name') in ('MACD', 'KDJ', 'RSI', '均线'): if 'MACD' in field_info.get('name', ''): - select_fields.extend(['sf.macd_dif', 'sf.macd_dea', 'sf.macd']) + for f in ('sf.macd_dif', 'sf.macd_dea', 'sf.macd'): + if f not in select_fields: + select_fields.append(f) else: - select_fields.append(f'sf.{db_field}') + f = f'sf.{db_field}' + if f not in select_fields: + select_fields.append(f) + elif field_info.get('category') == 'money_flow_fields' or \ + field_info.get('name') in ('资金', '资金流', '净流入', '资金净流入'): + f = f'sm.{db_field}' + if f not in select_fields: + select_fields.append(f) else: - if db_field not in ['ts_code', 'stock_name']: - select_fields.append(f'sb.{db_field}') - - # 去重 + if db_field not in ('ts_code', 'stock_name'): + f = f'sb.{db_field}' + if f not in select_fields: + select_fields.append(f) + + # 资金流向意图自动补 net_mf_amount + if 'stock_moneyflow' in tables_needed and not any( + 'net_mf_amount' in f for f in select_fields): + select_fields.append('sm.net_mf_amount') + select_fields = list(dict.fromkeys(select_fields)) - - # 构建FROM子句 + + # ---- 3. FROM / JOIN 子句 ---- from_clause = "FROM stock_business sb" if 'stock_factor' in tables_needed: from_clause += "\nJOIN stock_factor sf ON sb.ts_code = sf.ts_code" - - # 构建WHERE子句 + if 'stock_moneyflow' in tables_needed: + from_clause += "\nJOIN stock_moneyflow sm ON sb.ts_code = sm.ts_code" + + # ---- 4. WHERE 子句 ---- where_conditions = [] - - # 处理业务条件 + for condition in business_conditions: - condition_sql = self._build_single_condition_sql(condition, 'sb') - if condition_sql: - where_conditions.append(condition_sql) - - # 处理技术指标条件 + sql = self._build_single_condition_sql(condition, 'sb') + if sql: + where_conditions.append(sql) + for condition in technical_conditions: - condition_sql = self._build_technical_condition_sql(condition) - if condition_sql: - where_conditions.append(condition_sql) - - # 确保取最新数据 - if 'stock_factor' in tables_needed: - where_conditions.append("""sf.trade_date = ( - SELECT MAX(trade_date) FROM stock_factor WHERE ts_code = sb.ts_code - )""") - - # 基本过滤条件 + sql = self._build_technical_condition_sql(condition) + if sql: + where_conditions.append(sql) + + for condition in money_flow_conditions: + field_info = condition.get('field', {}) + db_field = field_info.get('db_field', 'net_mf_amount') + comparison = condition.get('comparison') + # 没有显式比较操作时,默认过滤 net_mf_amount > 0 + value = condition.get('value', 0) if comparison else 0 + if comparison == 'less_than': + where_conditions.append(f"sm.{db_field} < {value}") + else: + where_conditions.append(f"sm.{db_field} > {value}") + where_conditions.append("sb.ts_code IS NOT NULL") - - # 组装SQL + + # ---- 5. 组装 SQL ---- sql_parts = [ f"SELECT {', '.join(select_fields)}", from_clause, f"WHERE {' AND '.join(where_conditions)}" ] - - # 添加排序和限制 + + # ---- 6. ORDER BY ---- limit = entities.get('limit', 20) - - # 根据查询类型确定排序 + is_ranking = entities.get('sort') or intent in ('ranking', 'money_flow') + if technical_conditions: - # 如果有技术指标条件,按技术指标排序 - for condition in technical_conditions: - if 'MACD' in condition.get('field', {}).get('name', ''): + for cond in technical_conditions: + name = cond.get('field', {}).get('name', '') + if 'MACD' in name: sql_parts.append("ORDER BY sf.macd DESC") break + elif 'RSI' in name: + sql_parts.append("ORDER BY sf.rsi_6 DESC") + break + elif money_flow_conditions: + sql_parts.append("ORDER BY sm.net_mf_amount DESC") + elif is_ranking: + # ranking: 按第一个业务字段 DESC + for cond in business_conditions: + db_field = cond.get('field', {}).get('db_field', '') + if db_field: + sql_parts.append(f"ORDER BY sb.{db_field} DESC") + break else: - # 否则按第一个数值字段排序 - for condition in business_conditions: - field_info = condition.get('field', {}) + for cond in business_conditions: + field_info = cond.get('field', {}) db_field = field_info.get('db_field', '') - if db_field and condition.get('comparison') in ['greater_than', 'less_than']: - order = 'DESC' if condition.get('comparison') == 'greater_than' else 'ASC' + comp = cond.get('comparison') + if db_field and comp in ('greater_than', 'less_than'): + order = 'DESC' if comp == 'greater_than' else 'ASC' sql_parts.append(f"ORDER BY sb.{db_field} {order}") break - + sql_parts.append(f"LIMIT {limit}") - + return '\n'.join(sql_parts) def _build_single_condition_sql(self, condition: Dict[str, Any], table_alias: str = '') -> Optional[str]: @@ -593,15 +656,15 @@ def _determine_main_table(self, intent: str, entities: Dict[str, Any]) -> str: def _build_select_clause(self, entities: Dict[str, Any], main_table: str) -> str: """构建SELECT子句""" select_fields = [] - - # 基础字段 + + # 基础字段 — 始终包含 stock_name(通过 JOIN 或直接取) if main_table == 'stock_business': select_fields.extend(['ts_code', 'stock_name']) elif main_table == 'stock_factor': select_fields.extend(['ts_code', 'trade_date']) elif main_table == 'stock_moneyflow': select_fields.extend(['ts_code', 'trade_date']) - + # 根据新的条件结构添加字段 if 'conditions' in entities: for condition in entities['conditions']: @@ -610,7 +673,7 @@ def _build_select_clause(self, entities: Dict[str, Any], main_table: str) -> str db_field = field_info.get('db_field', field_info['name']) if db_field not in select_fields: select_fields.append(db_field) - + # 兼容旧的实体结构 if 'fields' in entities: for field in entities['fields']: @@ -620,24 +683,27 @@ def _build_select_clause(self, entities: Dict[str, Any], main_table: str) -> str if mapping['table'] == main_table: if mapping['field'] not in select_fields: select_fields.append(mapping['field']) - + # 如果没有特定字段,添加常用字段 if len(select_fields) <= 2: # 只有基础字段 if main_table == 'stock_business': select_fields.extend(['daily_close', 'factor_pct_change']) - + # 去重并格式化 - select_fields = list(dict.fromkeys(select_fields)) # 去重保持顺序 - + select_fields = list(dict.fromkeys(select_fields)) + return f"SELECT {', '.join(select_fields)}" def _build_from_clause(self, main_table: str, entities: Dict[str, Any]) -> str: """构建FROM子句""" from_clause = f"FROM {main_table}" - - # 检查是否需要JOIN其他表 + + # 如果主表不是 stock_business,自动 JOIN 以获取 stock_name join_tables = set() - + if main_table != 'stock_business': + join_tables.add('stock_business') + + # 检查是否需要JOIN其他表 if 'fields' in entities: for field in entities['fields']: field_name = field['name'] @@ -648,7 +714,9 @@ def _build_from_clause(self, main_table: str, entities: Dict[str, Any]) -> str: # 添加JOIN子句 for join_table in join_tables: - if join_table == 'stock_factor' and main_table == 'stock_business': + if join_table == 'stock_business' and main_table != 'stock_business': + from_clause += f"\nJOIN stock_business sb ON {main_table}.ts_code = sb.ts_code" + elif join_table == 'stock_factor' and main_table == 'stock_business': from_clause += f"\nJOIN {join_table} sf ON {main_table}.ts_code = sf.ts_code" elif join_table == 'stock_moneyflow' and main_table == 'stock_business': from_clause += f"\nJOIN {join_table} sm ON {main_table}.ts_code = sm.ts_code" @@ -748,21 +816,29 @@ def _build_where_clause_old(self, entities: Dict[str, Any]) -> str: def _build_order_clause(self, entities: Dict[str, Any]) -> str: """构建ORDER BY子句""" - if 'sort' in entities: - order = entities.get('order', 'desc').upper() - - # 确定排序字段 - if 'fields' in entities: - field = entities['fields'][0] # 使用第一个字段排序 - field_name = field['name'] - if field_name in self.field_mappings: - mapping = self.field_mappings[field_name] - return f"ORDER BY {mapping['field']} {order}" - - # 默认排序 - return "ORDER BY daily_close DESC" - - return "" + if 'sort' not in entities: + return "" + + order = entities.get('order', 'desc').upper() + + # 从 conditions 中找排序字段 + if 'conditions' in entities: + for cond in entities['conditions']: + field_info = cond.get('field', {}) + db_field = field_info.get('db_field', '') + if db_field: + return f"ORDER BY {db_field} {order}" + + # 从 fields 中找排序字段 + if 'fields' in entities: + field = entities['fields'][0] + field_name = field['name'] + if field_name in self.field_mappings: + mapping = self.field_mappings[field_name] + return f"ORDER BY {mapping['field']} {order}" + + # 默认排序 + return "ORDER BY daily_close DESC" def _build_limit_clause(self, entities: Dict[str, Any]) -> str: """构建LIMIT子句""" diff --git a/app/services/text2sql_engine.py b/app/services/text2sql_engine.py index e8e14c2c4..917f851ac 100644 --- a/app/services/text2sql_engine.py +++ b/app/services/text2sql_engine.py @@ -4,9 +4,11 @@ """ import logging +import os +import re import time import traceback -from typing import Dict, List, Any, Optional +from typing import Dict, List, Any, Optional, Set logger = logging.getLogger(__name__) from flask import request @@ -241,40 +243,97 @@ def _try_llm_enhancement(self, user_query: str, intent_result: Dict[str, Any]) - class QueryExecutor: - """查询执行器""" - + """查询执行器 + 自动将 Parquet 数据按正确列名映射加载到 SQLite 临时表, + 使生成的 SQL 可以直接在 SQLite 上执行。 + """ + def __init__(self): - self.max_result_count = 1000 # 最大结果数量限制 - + self.max_result_count = 1000 + self._loaded_tables: Set[str] = set() + # {parquet_abs_path: mtime_at_load} — 用于检测文件是否被重建 + self._file_mtimes: Dict[str, float] = {} + + def invalidate_cache(self): + """清除已加载的临时表记录,下次查询时重新从 Parquet 加载。""" + self._loaded_tables.clear() + self._file_mtimes.clear() + + # ---- Parquet → SQLite 列名映射 ---- + # key = SQL 模板中使用的列名, value = Parquet 文件中的实际列名 + TABLE_COLUMNS = { + 'stock_business': { + 'ts_code': 'ts_code', 'stock_name': 'stock_name', + 'trade_date': 'trade_date', 'daily_close': 'close', + 'factor_pct_change': 'factor_pct_change', + 'vol': 'factor_vol', 'factor_vol': 'factor_vol', + 'amount': 'factor_amount', 'factor_amount': 'factor_amount', + 'pe_ttm': 'pe_ttm', 'pb': 'pb', 'pe': 'pe', + 'turnover_rate': 'turnover_rate', + 'total_mv': 'total_mv', 'circ_mv': 'circ_mv', + }, + 'stock_factor': { + 'ts_code': 'ts_code', 'trade_date': 'trade_date', + 'macd': 'factor_macd', 'macd_dif': 'factor_macd_dif', + 'macd_dea': 'factor_macd_dea', + 'rsi_6': 'factor_rsi_6', 'rsi_12': 'factor_rsi_12', + 'rsi_24': 'factor_rsi_24', + 'kdj_k': 'factor_kdj_k', 'kdj_d': 'factor_kdj_d', + 'kdj_j': 'factor_kdj_j', + }, + 'stock_moneyflow': { + 'ts_code': 'ts_code', 'trade_date': 'trade_date', + 'net_mf_amount': 'moneyflow_net_amount', + 'net_mf_vol': 'moneyflow_net_vol', + }, + 'stock_ma_data': { + 'ts_code': 'ts_code', + 'ma5': 'ma5', 'ma10': 'ma10', 'ma20': 'ma20', + 'ma30': 'ma30', 'ma60': 'ma60', 'ma120': 'ma120', + }, + } + + # 每张虚拟表对应的 Parquet 源文件 + TABLE_SOURCES = { + 'stock_business': 'stock_business.parquet', + 'stock_factor': 'stock_business.parquet', # 宽表已含因子数据 + 'stock_moneyflow': 'stock_business.parquet', # 宽表已含资金流数据 + 'stock_ma_data': 'stock_ma_data.parquet', + } + + # ---- public ---- + def execute(self, sql: str) -> Dict[str, Any]: """执行SQL查询""" try: if not sql: return {'success': False, 'error': 'SQL为空'} - + + # 确保 SQL 引用的数据表已从 Parquet 加载到 SQLite + self._ensure_data_tables(sql) + # 执行查询 result = db.session.execute(text(sql)) - + # 获取列名 columns = list(result.keys()) - + # 获取数据 rows = result.fetchall() - + # 检查结果数量限制 if len(rows) > self.max_result_count: return { 'success': False, 'error': f'查询结果过多({len(rows)}条),请添加更多筛选条件' } - + # 转换为字典列表 data = [] for row in rows: row_dict = {} for i, column in enumerate(columns): value = row[i] - # 处理特殊数据类型 if value is not None: if isinstance(value, (int, float)): row_dict[column] = value @@ -283,17 +342,17 @@ def execute(self, sql: str) -> Dict[str, Any]: else: row_dict[column] = None data.append(row_dict) - + return { 'success': True, 'data': data, 'columns': columns, 'row_count': len(data) } - + except Exception as e: error_msg = str(e) - + # 处理常见的数据库错误 if 'no such table' in error_msg.lower(): error_msg = '数据表不存在,请检查数据库配置' @@ -301,7 +360,7 @@ def execute(self, sql: str) -> Dict[str, Any]: error_msg = '字段不存在,请检查查询条件' elif 'syntax error' in error_msg.lower(): error_msg = 'SQL语法错误' - + return { 'success': False, 'error': error_msg, @@ -310,6 +369,126 @@ def execute(self, sql: str) -> Dict[str, Any]: 'row_count': 0 } + # ---- private: Parquet → SQLite 桥接 ---- + + def _ensure_data_tables(self, sql: str): + """检查 SQL 引用的表,若缺失或源文件已变更则重新加载。 + 同一 Parquet 文件可能对应多个虚拟表(如 stock_business.parquet + 同时是 stock_business / stock_factor / stock_moneyflow 的源), + 文件变更时需重置所有共享该文件的已加载虚拟表。 + """ + tables = self._extract_table_names(sql) + stale_tables = set() + for tbl in tables: + if tbl not in self.TABLE_COLUMNS: + continue + if tbl not in self._loaded_tables: + stale_tables.add(tbl) + elif self._is_parquet_stale(tbl): + stale_tables.add(tbl) + # 文件变更 → 重置所有共享同一 Parquet 源的已加载虚拟表 + stale_file = self.TABLE_SOURCES.get(tbl) + for loaded_tbl in list(self._loaded_tables): + if self.TABLE_SOURCES.get(loaded_tbl) == stale_file: + stale_tables.add(loaded_tbl) + + for tbl in stale_tables: + self._load_parquet_to_sqlite(tbl) + + def _is_parquet_stale(self, table_name: str) -> bool: + """检测 Parquet 源文件是否在上次加载后被修改(跨进程安全)。""" + try: + from flask import current_app + import os + data_dir = current_app.config.get('DATA_DIR', 'data') + if not os.path.isabs(data_dir): + data_dir = os.path.join(current_app.root_path, '..', data_dir) + parquet_file = self.TABLE_SOURCES.get(table_name) + if not parquet_file: + return False + parquet_path = os.path.join(data_dir, parquet_file) + if not os.path.exists(parquet_path): + return False + current_mtime = os.path.getmtime(parquet_path) + # 按 (table_name, parquet_path) 追踪,每个虚拟表独立记录 + cached_mtime = self._file_mtimes.get((table_name, parquet_path), 0) + return current_mtime > cached_mtime + except Exception: + return False + + @staticmethod + def _extract_table_names(sql: str) -> Set[str]: + """从 FROM / JOIN 子句提取表名""" + return set(re.findall(r'(?:FROM|JOIN)\s+(\w+)', sql, re.IGNORECASE)) + + def _load_parquet_to_sqlite(self, table_name: str): + """从 Parquet 文件加载数据到 SQLite 临时表""" + try: + import pandas as pd + from flask import current_app + + parquet_file = self.TABLE_SOURCES[table_name] + data_dir = current_app.config.get('DATA_DIR', 'data') + if not os.path.isabs(data_dir): + data_dir = os.path.join(current_app.root_path, '..', data_dir) + parquet_path = os.path.join(data_dir, parquet_file) + + if not os.path.exists(parquet_path): + logger.warning(f"Parquet 文件不存在: {parquet_path}") + return + + df = pd.read_parquet(parquet_path) + + # 只取最新交易日数据(约 5K 行),保持 SQLite 轻量 + if 'trade_date' in df.columns: + latest_date = df['trade_date'].max() + df = df[df['trade_date'] == latest_date] + + # 按 TABLE_COLUMNS 映射选取并重命名列 + col_map = self.TABLE_COLUMNS[table_name] + available = {sql_col: pq_col + for sql_col, pq_col in col_map.items() + if pq_col in df.columns} + if not available: + return + + # 去重 Parquet 列:多个 SQL 名可能映射到同一个 Parquet 列 + unique_pq_cols = list(dict.fromkeys(available.values())) + df = df[unique_pq_cols].copy() + + # 建立重命名映射:每个 Parquet 列 → 第一个出现的 SQL 列名 + pq_to_first_sql = {} + for sql_col, pq_col in available.items(): + if pq_col not in pq_to_first_sql: + pq_to_first_sql[pq_col] = sql_col + df.rename(columns=pq_to_first_sql, inplace=True) + + # 补充别名列(同一 Parquet 列的其它 SQL 名) + for sql_col, pq_col in available.items(): + primary_sql = pq_to_first_sql[pq_col] + if sql_col != primary_sql and sql_col not in df.columns: + df[sql_col] = df[primary_sql] + + # NaN → None(SQLite 不支持 NaN) + df = df.where(pd.notnull(df), None) + + # 写入 SQLite(通过 raw_connection 使用底层 sqlite3 连接, + # 兼容 SQLAlchemy 2.0 + Pandas 3.x) + raw_conn = db.engine.raw_connection() + try: + df.to_sql(table_name, raw_conn, if_exists='replace', index=False) + finally: + raw_conn.close() + + self._loaded_tables.add(table_name) + # 按 (虚拟表名, 文件路径) 记录 mtime,同一文件的不同虚拟表独立追踪 + self._file_mtimes[(table_name, parquet_path)] = os.path.getmtime(parquet_path) + logger.info(f"Loaded {len(df)} rows into '{table_name}' from {parquet_file}") + + except Exception as e: + logger.error(f"Failed to load '{table_name}' from Parquet: {e}") + db.session.rollback() + class ResultFormatter: """结果格式化器""" diff --git a/app/services/wide_table_status.py b/app/services/wide_table_status.py new file mode 100644 index 000000000..d322c9e60 --- /dev/null +++ b/app/services/wide_table_status.py @@ -0,0 +1,159 @@ +""" +大宽表状态检查 +- 判断宽表是否存在 / 是否过时 +- 6PM 校验:最新交易日 18:00 后才允许更新 +- 供 API 端点和 startup 调用 +""" + +import logging +import os +from datetime import datetime +from pathlib import Path +from typing import Dict, Optional, Tuple + +import pandas as pd + +logger = logging.getLogger(__name__) + +# 数据源 → 分区目录映射 +SOURCE_TABLES = { + "daily_basic": "daily_basic/daily", + "stk_factor": "stk_factor/daily", + "moneyflow": "moneyflow/daily", +} + +WIDE_TABLE_FILE = "stock_business.parquet" +TRADE_CALENDAR_FILE = "stock_trade_calendar.parquet" +CUTOFF_HOUR = 18 # 下午 6 点 + + +def _resolve_data_dir(data_dir: Optional[str] = None) -> Path: + if data_dir: + return Path(data_dir) + return Path( + os.getenv("DATA_DIR", os.path.join(os.path.dirname(__file__), "..", "..", "data")) + ) + + +def _get_latest_partition_date(partition_dir: Path) -> Optional[str]: + """从 Hive 分区目录提取最新日期 (YYYY-MM-DD)。""" + if not partition_dir.exists(): + return None + max_date = None + for year_dir in partition_dir.iterdir(): + if not year_dir.is_dir() or not year_dir.name.startswith("year="): + continue + y = year_dir.name.split("=")[1] + for month_dir in year_dir.iterdir(): + if not month_dir.is_dir() or not month_dir.name.startswith("month="): + continue + m = month_dir.name.split("=")[1] + for day_dir in month_dir.iterdir(): + if not day_dir.is_dir() or not day_dir.name.startswith("day="): + continue + d = day_dir.name.split("=")[1] + date_str = f"{y}-{m}-{d}" + if max_date is None or date_str > max_date: + max_date = date_str + return max_date + + +def _is_past_cutoff(data_dir: Path) -> bool: + """判断当前时间是否已过最新交易日的 18:00。""" + cal_path = data_dir / TRADE_CALENDAR_FILE + if not cal_path.exists(): + # 没有日历,默认允许 + return True + + try: + cal = pd.read_parquet(cal_path) + except Exception: + return True + + open_days = cal[cal["is_open"] == 1]["cal_date"].astype(str) + if open_days.empty: + return True + + today_str = datetime.now().strftime("%Y%m%d") + # 找 <= 今天 的最新交易日 + latest_open = None + for d in sorted(open_days, reverse=True): + if d <= today_str: + latest_open = d + break + + if latest_open is None: + return False + + now = datetime.now() + if latest_open < today_str: + # 最新交易日是过去的日子 → 数据应已就绪 + return True + else: + # 今天就是最新交易日 → 检查是否过了 18:00 + return now.hour >= CUTOFF_HOUR + + +def get_wide_table_status(data_dir: Optional[str] = None) -> Dict[str, object]: + """返回宽表完整状态。""" + root = _resolve_data_dir(data_dir) + wide_path = root / WIDE_TABLE_FILE + + # 1. 宽表自身状态 + exists = wide_path.exists() and wide_path.stat().st_size > 0 + wide_table_date = None + if exists: + try: + df = pd.read_parquet(wide_path, columns=["trade_date"]) + if not df.empty: + td = pd.to_datetime(df["trade_date"]) + wide_table_date = td.max().strftime("%Y-%m-%d") + except Exception: + pass + + # 2. 各数据源最新日期 + source_dates = {} + for name, rel_path in SOURCE_TABLES.items(): + source_dates[name] = _get_latest_partition_date(root / rel_path) + + # 3. 6PM 校验 + past_cutoff = _is_past_cutoff(root) + + # 4. 是否需要更新 + should_update = False + reason = "" + + if not exists: + should_update = True + reason = "宽表文件不存在" + elif wide_table_date: + # 任一数据源日期 > 宽表日期 + newer_sources = [ + f"{k}({v})" for k, v in source_dates.items() + if v and v > wide_table_date + ] + if newer_sources: + should_update = True + reason = f"数据源更新: {', '.join(newer_sources)}" + else: + reason = "宽表已是最新" + + return { + "exists": exists, + "wide_table_date": wide_table_date, + "source_dates": source_dates, + "should_update": should_update, + "reason": reason, + "past_cutoff": past_cutoff, + } + + +def should_update_wide_table(data_dir: Optional[str] = None) -> Tuple[bool, str]: + """简化版,供 startup 使用。""" + status = get_wide_table_status(data_dir) + if not status["exists"]: + return True, "宽表文件不存在,请在数据中心页面构建" + if status["should_update"]: + cutoff_note = "" if status["past_cutoff"] else "(需等待 18:00 后)" + return True, f"数据源有更新{cutoff_note}: {status['reason']}" + return False, f"宽表已是最新 ({status['wide_table_date']})" diff --git a/app/static/js/data_jobs.js b/app/static/js/data_jobs.js index 1c25cef87..c4af90d40 100644 --- a/app/static/js/data_jobs.js +++ b/app/static/js/data_jobs.js @@ -375,5 +375,101 @@ updateProgressView(null); loadJobTypes(); loadRunHistory(); + + // 大宽表按钮绑定 + var buildBtn = document.getElementById("buildWideTableBtn"); + var refreshWideBtn = document.getElementById("refreshWideTableStatusBtn"); + if (buildBtn) buildBtn.addEventListener("click", submitBuildWideTable); + if (refreshWideBtn) refreshWideBtn.addEventListener("click", loadWideTableStatus); + loadWideTableStatus(); }); + + // ---- 大宽表状态与构建 ---- + + async function loadWideTableStatus() { + var badge = document.getElementById("wideTableStatusBadge"); + var dateEl = document.getElementById("wideTableDate"); + var sourceDatesEl = document.getElementById("wideTableSourceDates"); + var reasonEl = document.getElementById("wideTableUpdateReason"); + var buildBtn = document.getElementById("buildWideTableBtn"); + if (!badge) return; + + badge.className = "badge bg-secondary"; + badge.textContent = "检查中..."; + + try { + var resp = await fetch("/api/data-jobs/wide-table/status"); + var data = await resp.json(); + if (!data.success) { + badge.className = "badge bg-danger"; + badge.textContent = "查询失败"; + return; + } + + var s = data.status; + dateEl.textContent = s.wide_table_date || "文件不存在"; + + var parts = []; + if (s.source_dates) { + for (var table in s.source_dates) { + parts.push(table + ": " + (s.source_dates[table] || "无")); + } + } + sourceDatesEl.textContent = parts.join(" | "); + + if (!s.exists) { + badge.className = "badge bg-danger"; + badge.textContent = "缺失"; + buildBtn.disabled = !s.past_cutoff; + reasonEl.textContent = "宽表文件不存在" + (s.past_cutoff ? ",可以构建" : ",需等待 18:00 后"); + } else if (s.should_update) { + badge.className = "badge bg-warning text-dark"; + badge.textContent = "需更新"; + buildBtn.disabled = !s.past_cutoff; + reasonEl.textContent = s.reason + (s.past_cutoff ? "" : "(需等待 18:00 后)"); + } else { + badge.className = "badge bg-success"; + badge.textContent = "正常"; + buildBtn.disabled = true; + reasonEl.textContent = "宽表已是最新"; + } + } catch (err) { + badge.className = "badge bg-danger"; + badge.textContent = "网络错误"; + } + } + + async function submitBuildWideTable() { + var buildBtn = document.getElementById("buildWideTableBtn"); + var resultBox = document.getElementById("wideTableBuildResult"); + + buildBtn.disabled = true; + resultBox.style.display = "block"; + resultBox.className = "alert alert-info mt-3"; + resultBox.textContent = "正在提交大宽表构建任务..."; + + try { + var resp = await fetch("/api/data-jobs/wide-table/build", { method: "POST" }); + var data = await resp.json(); + + if (data.success) { + resultBox.className = "alert alert-success mt-3"; + resultBox.textContent = "构建任务已提交 (run_id=" + data.run_id + "),请查看下方日频数据中心的进度。"; + currentRunId = data.run_id; + await fetchRunStatus(currentRunId); + startPolling(currentRunId); + loadRunHistory(); + } else { + resultBox.className = "alert alert-danger mt-3"; + resultBox.textContent = "构建失败: " + data.error; + buildBtn.disabled = false; + } + } catch (err) { + resultBox.className = "alert alert-danger mt-3"; + resultBox.textContent = "网络错误: " + err.message; + buildBtn.disabled = false; + } + + setTimeout(loadWideTableStatus, 5000); + } })(); diff --git a/app/tasks/data_jobs_tasks.py b/app/tasks/data_jobs_tasks.py index 0bc168abf..6644eb431 100644 --- a/app/tasks/data_jobs_tasks.py +++ b/app/tasks/data_jobs_tasks.py @@ -58,6 +58,14 @@ def run_data_job(run_id: int): if completed.returncode == 0: store.update_run_status(run, "success", progress=100.0, progress_message="任务执行完成") + # 大宽表构建成功后清除缓存,使后续请求读取新数据 + if run.job_type == "wide_table_builder": + from app.services.data_reader import ParquetDataReader + ParquetDataReader.invalidate_stock_business_cache() + # 同时清除 text2sql 的 SQLite 临时表缓存 + from app.services.text2sql_engine import get_text2sql_engine + engine = get_text2sql_engine() + engine.query_executor.invalidate_cache() store.save_run(run) return {"run_id": run_id, "status": "success"} diff --git a/app/templates/data_management/index.html b/app/templates/data_management/index.html index ba5abe30b..f6eaba894 100644 --- a/app/templates/data_management/index.html +++ b/app/templates/data_management/index.html @@ -182,6 +182,46 @@
正常
+ +
+
+
+
+
大宽表 (stock_business)
+ 检查中... +
+
+
+
+
+ 宽表日期:- +
+
+ 数据源日期: + - +
+
+ 更新状态:- +
+
+ 宽表合并了日线基本指标、技术因子、资金流向和股票基础资料,仅保留最新交易日数据(约 5K 行)。 +
+
+
+ + +
+
+ +
+
+
+
+
diff --git a/app/utils/wide_table_builder.py b/app/utils/wide_table_builder.py new file mode 100644 index 000000000..7292e2be5 --- /dev/null +++ b/app/utils/wide_table_builder.py @@ -0,0 +1,148 @@ +""" +大宽表构建脚本 +从 daily_basic、stk_factor、moneyflow、stock_basic 合并最新交易日数据, +输出 stock_business.parquet(仅保留最新一天)。 + +Usage: + python app/utils/wide_table_builder.py + +Registered as a derived job in DataJobService. +""" + +import sys +import os + +# 支持两种运行方式: +# 1. 作为 data_jobs 子进程(PYTHONPATH 已含项目根目录) +# 2. 从 Flask app context 直接 import +_project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if _project_root not in sys.path: + sys.path.insert(0, _project_root) + +import pandas as pd + +from app.services.data_reader import ParquetDataReader +from app.utils.parquet_writer import save_single_parquet + + +# stk_factor 列 → 宽表列名映射 +STK_FACTOR_RENAME = { + "open": "factor_open", + "high": "factor_high", + "low": "factor_low", + "pre_close": "factor_pre_close", + "change": "factor_change", + "pct_change": "factor_pct_change", + "vol": "factor_vol", + "amount": "factor_amount", + "macd_dif": "factor_macd_dif", + "macd_dea": "factor_macd_dea", + "macd": "factor_macd", + "kdj_k": "factor_kdj_k", + "kdj_d": "factor_kdj_d", + "kdj_j": "factor_kdj_j", + "rsi_6": "factor_rsi_6", + "rsi_12": "factor_rsi_12", + "rsi_24": "factor_rsi_24", + "boll_upper": "factor_boll_upper", + "boll_mid": "factor_boll_mid", + "boll_lower": "factor_boll_lower", + "cci": "factor_cci", +} + +# stk_factor 中需要保留但不改名的列 +STK_FACTOR_KEEP = ["ts_code", "adj_factor"] + +# stk_factor 中需要丢弃的列(与 daily_basic 重复 或 不需要) +STK_FACTOR_DROP = { + "close", "trade_date", + # 复权列不需要 + "open_hfq", "open_qfq", "close_hfq", "close_qfq", + "high_hfq", "high_qfq", "low_hfq", "low_qfq", + "pre_close_hfq", "pre_close_qfq", +} + + +def build_wide_table() -> pd.DataFrame: + """读取各源表最新分区,合并为大宽表,返回 DataFrame。""" + reader = ParquetDataReader() + + # 1. daily_basic(主表) + db = reader._read_latest_partition("daily_basic") + if db is None or db.empty: + print("[wide_table_builder] daily_basic 无数据,跳过") + return pd.DataFrame() + db = db.copy() + print(f"[wide_table_builder] daily_basic: {len(db)} 行") + + # 2. stk_factor + sf = reader._read_latest_partition("stk_factor") + if sf is not None and not sf.empty: + sf = sf.copy() + # 丢弃不需要的列 + drop_cols = [c for c in STK_FACTOR_DROP if c in sf.columns] + sf.drop(columns=drop_cols, inplace=True, errors="ignore") + # 重命名 + rename_map = {k: v for k, v in STK_FACTOR_RENAME.items() if k in sf.columns} + sf.rename(columns=rename_map, inplace=True) + # 只保留 ts_code + 重命名后的因子列 + adj_factor + keep = [c for c in STK_FACTOR_KEEP if c in sf.columns] + \ + [v for v in STK_FACTOR_RENAME.values() if v in sf.columns] + sf = sf[keep] + print(f"[wide_table_builder] stk_factor: {len(sf)} 行, 列={len(sf.columns)}") + # left join + result = db.merge(sf, on="ts_code", how="left") + else: + print("[wide_table_builder] stk_factor 无数据,跳过") + result = db + + # 3. moneyflow + mf = reader._read_latest_partition("moneyflow") + if mf is not None and not mf.empty: + mf_cols = [c for c in ["ts_code", "net_mf_amount", "net_mf_vol"] if c in mf.columns] + mf = mf[mf_cols].copy() + mf.rename(columns={ + "net_mf_amount": "moneyflow_net_amount", + "net_mf_vol": "moneyflow_net_vol", + }, inplace=True) + print(f"[wide_table_builder] moneyflow: {len(mf)} 行") + result = result.merge(mf, on="ts_code", how="left") + else: + print("[wide_table_builder] moneyflow 无数据,跳过") + + # 4. stock_basic(静态表) + try: + basic = reader.get_stock_basic() + if basic is not None and not basic.empty: + # 取 ts_code, name, symbol + name_col = "name" if "name" in basic.columns else "stock_name" + sub = basic[["ts_code", name_col, "symbol"]].copy() if "symbol" in basic.columns else \ + basic[["ts_code", name_col]].copy() + sub.rename(columns={name_col: "stock_name"}, inplace=True) + result = result.merge(sub, on="ts_code", how="left") + print(f"[wide_table_builder] stock_basic: {len(sub)} 行") + except Exception as e: + print(f"[wide_table_builder] stock_basic 读取失败: {e}") + + # 5. 派生 year/month/day + if "trade_date" in result.columns: + td = pd.to_datetime(result["trade_date"]) + result["year"] = td.dt.year + result["month"] = td.dt.month + result["day"] = td.dt.day + + print(f"[wide_table_builder] 合并完成: {len(result)} 行, {len(result.columns)} 列") + return result + + +def main(): + df = build_wide_table() + if df is not None and not df.empty: + save_single_parquet(df, "stock_business.parquet") + print(f"[wide_table_builder] 写完成: {len(df)} 行, {len(df.columns)} 列") + else: + print("[wide_table_builder] 无数据可写入") + + +if __name__ == "__main__": + main() diff --git a/run.py b/run.py index 85cf9d536..0fbcd2c7d 100644 --- a/run.py +++ b/run.py @@ -3,10 +3,8 @@ import os from app import create_app -from app.extensions import db from app.extensions import socketio -from sqlalchemy import inspect -from startup_runtime import build_health_report, build_health_summary_lines, build_startup_report +from startup_runtime import build_health_report, build_health_summary_lines, build_startup_report, inspect_parquet_data_assets # 创建Flask应用实例 app = create_app(os.getenv('FLASK_ENV', 'default')) @@ -14,26 +12,14 @@ def inspect_runtime_health(flask_app): with flask_app.app_context(): - existing_tables = set() - non_empty_tables = set() - connected = False - - try: - inspector = inspect(db.engine) - existing_tables = set(inspector.get_table_names()) - connected = True - for table in existing_tables & {"stock_basic", "stock_trade_calendar", "data_job_run"}: - count = db.session.execute(db.text(f"SELECT COUNT(*) FROM {table}")).scalar() - if count and int(count) > 0: - non_empty_tables.add(table) - except Exception: - connected = False + data_dir = flask_app.config.get("DATA_DIR") + connected, existing_assets, non_empty_assets = inspect_parquet_data_assets(data_dir) return build_health_report( flask_app.config, connected=connected, - existing_tables=existing_tables, - non_empty_tables=non_empty_tables, + existing_tables=existing_assets, + non_empty_tables=non_empty_assets, ) if __name__ == '__main__': @@ -43,6 +29,13 @@ def inspect_runtime_health(flask_app): for line in build_health_summary_lines(inspect_runtime_health(app)): print(line) + # 大宽表状态检查 + with app.app_context(): + from app.services.wide_table_status import should_update_wide_table + need_update, reason = should_update_wide_table(app.config.get("DATA_DIR")) + tag = "⚠️" if need_update else "✅" + print(f" {tag} 大宽表: {reason}") + # 开发环境下运行,使用SocketIO socketio.run( app, diff --git a/startup_runtime.py b/startup_runtime.py index e0a9d09d6..c6803fee3 100644 --- a/startup_runtime.py +++ b/startup_runtime.py @@ -9,6 +9,7 @@ "stock_trade_calendar.parquet", "daily_history/daily", "daily_basic/daily", + "stock_business.parquet", ) diff --git a/tests/data_jobs/test_registry.py b/tests/data_jobs/test_registry.py index 655f031a8..fd91cee7d 100644 --- a/tests/data_jobs/test_registry.py +++ b/tests/data_jobs/test_registry.py @@ -21,7 +21,7 @@ def test_registry_visible_jobs_follow_whitelist(): jobs = registry.list_visible_jobs() job_types = [job.job_type for job in jobs] - assert len(job_types) == 8 + assert len(job_types) == 9 assert job_types == [ "trade_calendar", "stock_basic", @@ -31,4 +31,5 @@ def test_registry_visible_jobs_follow_whitelist(): "moneyflow", "stk_factor", "cyq_perf", + "wide_table_builder", ]