Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion app/api/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flask import Blueprint, request, jsonify
from app.services.websocket_push_service import push_service
from app.websocket.websocket_events import get_connection_stats
from datetime import datetime
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -192,7 +193,7 @@ def test_connection():
test_data = {
'type': 'test',
'message': 'WebSocket连接测试',
'timestamp': push_service.get_push_status()['connection_stats']
'timestamp': datetime.now().isoformat()
}

push_service.trigger_immediate_push('monitor', test_data)
Expand Down
130 changes: 73 additions & 57 deletions app/services/websocket_push_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
import json
import pandas as pd

from app.extensions import db
from app.models.realtime_indicator import RealtimeIndicator
Expand Down Expand Up @@ -43,7 +44,6 @@ def __init__(self):
self.event_store = ParquetEventStore()

self.is_running = False
self.push_thread = None
self.push_interval = 30 # 推送间隔(秒)

# 推送配置
Expand All @@ -65,43 +65,42 @@ def start_push_service(self):
if self.is_running:
logger.warning("推送服务已在运行")
return

self.is_running = True
self.push_thread = threading.Thread(target=self._push_loop, daemon=True)
self.push_thread.start()
from app.extensions import socketio
socketio.start_background_task(target=self._push_loop)
logger.info("WebSocket推送服务已启动")

def stop_push_service(self):
"""停止推送服务"""
self.is_running = False
if self.push_thread:
self.push_thread.join(timeout=5)
logger.info("WebSocket推送服务已停止")

def _push_loop(self):
"""推送循环"""
from app.extensions import socketio as _sio
while self.is_running:
try:
current_time = datetime.now()

# 检查各类数据是否需要推送
for data_type, config in self.push_config.items():
if not config['enabled']:
continue

last_push = self.last_push_times.get(data_type)
if (not last_push or
if (not last_push or
(current_time - last_push).total_seconds() >= config['interval']):

self._push_data_type(data_type)
self.last_push_times[data_type] = current_time

# 等待下一次检查
time.sleep(10) # 每10秒检查一次
_sio.sleep(10) # 使用 socketio.sleep 保证 eventlet 兼容

except Exception as e:
logger.error(f"推送循环错误: {e}")
time.sleep(30) # 出错后等待30秒再继续
_sio.sleep(30)

def _push_data_type(self, data_type: str):
"""推送指定类型的数据"""
Expand All @@ -127,32 +126,40 @@ def _push_data_type(self, data_type: str):
def _push_market_data(self):
"""推送市场数据"""
try:
# 获取活跃股票列表
active_stocks = self.data_manager.get_active_stocks()

for stock in active_stocks[:20]: # 限制推送数量
ts_code = stock['ts_code']

# 获取最新数据
latest_data = self.data_manager.get_latest_data(ts_code, '1min', 1)
if latest_data:
market_data = {
'ts_code': ts_code,
'datetime': latest_data[0]['datetime'],
'open': latest_data[0]['open'],
'high': latest_data[0]['high'],
'low': latest_data[0]['low'],
'close': latest_data[0]['close'],
'volume': latest_data[0]['volume'],
'amount': latest_data[0]['amount'],
'change_pct': self._calculate_change_pct(latest_data[0])
}

broadcast_market_data(ts_code, market_data)
broadcast_market_data('all', market_data) # 广播到全局房间

logger.debug(f"推送市场数据完成,股票数量: {len(active_stocks)}")

# 获取有分钟数据的股票列表
active_stocks = self.data_manager.get_available_minute_stocks()

pushed_count = 0
for ts_code in active_stocks[:20]: # 限制推送数量
# 尝试各周期,优先1min,fallback到更粗粒度
latest_data = pd.DataFrame()
for period in ['1min', '5min', '15min', '30min', '60min']:
latest_data = self.data_manager.get_minute_latest_data(ts_code, period, 2)
if not latest_data.empty:
break

if latest_data.empty:
continue

row = latest_data.iloc[0]
market_data = {
'ts_code': ts_code,
'datetime': str(row.get('datetime', '')),
'open': float(row.get('open', 0)),
'high': float(row.get('high', 0)),
'low': float(row.get('low', 0)),
'close': float(row.get('close', 0)),
'volume': float(row.get('volume', 0)),
'amount': float(row.get('amount', 0)),
'change_pct': self._calculate_change_pct(latest_data)
}

broadcast_market_data(ts_code, market_data)
broadcast_market_data('all', market_data) # 广播到全局房间
pushed_count += 1

logger.info(f"推送市场数据完成,股票数量: {pushed_count}/{len(active_stocks)}")

except Exception as e:
logger.error(f"推送市场数据失败: {e}")

Expand Down Expand Up @@ -240,18 +247,22 @@ def _push_monitor_data(self):
"""推送监控数据"""
try:
# 获取监控数据
anomaly_result = self.monitor_service.detect_anomalies(
change_threshold=5.0, volume_threshold=3.0
)
anomaly_list = anomaly_result.get('data', {}).get('anomalies', []) \
if isinstance(anomaly_result, dict) and anomaly_result.get('success') else []

monitor_data = {
'market_overview': self.monitor_service.get_market_overview(),
'top_movers': self.monitor_service.get_top_movers(limit=10),
'anomalies': self.monitor_service.detect_anomalies(
change_threshold=5.0, volume_threshold=3.0
),
'sentiment': self.monitor_service.calculate_market_sentiment(period_hours=1)
'market_overview': self.data_manager.get_market_overview(),
'top_movers': anomaly_list,
'anomalies': anomaly_list,
'sentiment': self.monitor_service.get_market_sentiment(period_hours=1)
}

broadcast_monitor_data(monitor_data)
logger.debug("推送监控数据完成")
logger.info("推送监控数据完成")

except Exception as e:
logger.error(f"推送监控数据失败: {e}")

Expand Down Expand Up @@ -331,17 +342,22 @@ def _get_news_payload(self) -> List[Dict[str, Any]]:
"""获取可推送的新闻数据。默认不生成模拟新闻。"""
return []

def _calculate_change_pct(self, current_data: Dict) -> float:
"""计算涨跌幅"""
def _calculate_change_pct(self, latest_data) -> float:
"""计算涨跌幅,传入最近2条DataFrame记录"""
try:
# 获取前一个交易日收盘价(简化处理)
prev_close = current_data.get('open', current_data['close'])
current_close = current_data['close']

if latest_data.empty:
return 0.0
current_close = float(latest_data.iloc[0].get('close', 0))
if len(latest_data) >= 2:
prev_close = float(latest_data.iloc[1].get('close', 0))
else:
# 只有1条数据时用开盘价作为近似
prev_close = float(latest_data.iloc[0].get('open', current_close))

if prev_close and prev_close != 0:
return round(((current_close - prev_close) / prev_close) * 100, 2)
return 0.0

except Exception:
return 0.0

Expand Down
4 changes: 2 additions & 2 deletions app/templates/realtime_analysis/websocket_management.html
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ <h5 class="mb-0"><i class="bi bi-journal-text"></i> 消息日志</h5>
}

try {
socket = io('http://127.0.0.1:5001', {
socket = io(window.location.origin, {
transports: ['websocket', 'polling']
});

Expand Down Expand Up @@ -656,7 +656,7 @@ <h6 class="mb-0">${getTypeName(type)}</h6>
<small class="text-muted">${timestamp}</small>
</div>
<div class="small text-muted">
${JSON.stringify(data, null, 2).substring(0, 200)}...
${JSON.stringify(data, null, 2).substring(0, 200)}${JSON.stringify(data).length > 200 ? '...' : ''}
</div>
`;

Expand Down
145 changes: 145 additions & 0 deletions tests/services/test_websocket_push_service_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""WebSocket推送服务生命周期与合约测试"""
from unittest.mock import patch

import pandas as pd

from app.services.websocket_push_service import WebSocketPushService


# ---------- lifecycle ----------

def test_stop_push_service_sets_is_running_false():
"""stop 应将 is_running 置 False"""
service = WebSocketPushService()
service.is_running = True
service.stop_push_service()
assert service.is_running is False


def test_start_push_service_refuses_double_start():
"""运行中再次 start 应 warning 并跳过"""
service = WebSocketPushService()
service.is_running = True
with patch("app.extensions.socketio") as mock_sio:
service.start_push_service()
mock_sio.start_background_task.assert_not_called()


# ---------- _push_market_data ----------

def test_push_market_data_uses_available_stocks_and_fallback_period():
"""应使用 get_available_minute_stocks + get_minute_latest_data"""
service = WebSocketPushService()

fake_df = pd.DataFrame([
{"datetime": "2026-01-01 10:00", "open": 10.0, "high": 10.5,
"low": 9.8, "close": 10.2, "volume": 1000, "amount": 10000},
{"datetime": "2026-01-01 09:59", "open": 10.1, "high": 10.3,
"low": 9.9, "close": 10.0, "volume": 900, "amount": 9000},
])

with patch.object(service.data_manager, "get_available_minute_stocks",
return_value=["000001.SZ"]), \
patch.object(service.data_manager, "get_minute_latest_data",
return_value=fake_df), \
patch("app.services.websocket_push_service.broadcast_market_data") as bc:

service._push_market_data()

assert bc.call_count == 2 # per-stock + 'all'
# First call: specific stock
assert bc.call_args_list[0].args[0] == "000001.SZ"
# Second call: broadcast to 'all'
assert bc.call_args_list[1].args[0] == "all"


def test_push_market_data_skips_empty_data():
"""get_minute_latest_data 返回空 DataFrame 时跳过该股票"""
service = WebSocketPushService()

with patch.object(service.data_manager, "get_available_minute_stocks",
return_value=["000001.SZ"]), \
patch.object(service.data_manager, "get_minute_latest_data",
return_value=pd.DataFrame()), \
patch("app.services.websocket_push_service.broadcast_market_data") as bc:

service._push_market_data()

bc.assert_not_called()


# ---------- _push_monitor_data ----------

def test_push_monitor_data_unpacks_anomaly_list():
"""detect_anomalies 返回的包装 dict 应解包为 anomalies 列表"""
service = WebSocketPushService()

anomaly_response = {
"success": True,
"data": {
"anomalies": [{"ts_code": "000001.SZ", "anomaly_types": ["急涨"]}],
"total_count": 1,
},
}

with patch.object(service.data_manager, "get_market_overview",
return_value={"total_stocks": 100}), \
patch.object(service.monitor_service, "detect_anomalies",
return_value=anomaly_response), \
patch.object(service.monitor_service, "get_market_sentiment",
return_value={"sentiment": "neutral"}), \
patch("app.services.websocket_push_service.broadcast_monitor_data") as bc:

service._push_monitor_data()

payload = bc.call_args.args[0]
# top_movers 和 anomalies 都应该是 list,不是包装 dict
assert isinstance(payload["top_movers"], list)
assert isinstance(payload["anomalies"], list)
assert payload["top_movers"][0]["ts_code"] == "000001.SZ"


def test_push_monitor_data_handles_detect_failure():
"""detect_anomalies 返回 success=False 时应降级为空列表"""
service = WebSocketPushService()

with patch.object(service.data_manager, "get_market_overview",
return_value={}), \
patch.object(service.monitor_service, "detect_anomalies",
return_value={"success": False, "message": "error"}), \
patch.object(service.monitor_service, "get_market_sentiment",
return_value={}), \
patch("app.services.websocket_push_service.broadcast_monitor_data") as bc:

service._push_monitor_data()

payload = bc.call_args.args[0]
assert payload["top_movers"] == []
assert payload["anomalies"] == []


# ---------- _calculate_change_pct ----------

def test_calculate_change_pct_with_two_rows():
service = WebSocketPushService()
df = pd.DataFrame([
{"close": 11.0},
{"close": 10.0},
])
pct = service._calculate_change_pct(df)
assert pct == 10.0


def test_calculate_change_pct_single_row_fallback():
service = WebSocketPushService()
df = pd.DataFrame([
{"close": 11.0, "open": 10.0},
])
pct = service._calculate_change_pct(df)
assert pct == 10.0


def test_calculate_change_pct_empty_df():
service = WebSocketPushService()
pct = service._calculate_change_pct(pd.DataFrame())
assert pct == 0.0
Loading