diff --git a/qracer/cli.py b/qracer/cli.py index 7c3faf3..7686b92 100644 --- a/qracer/cli.py +++ b/qracer/cli.py @@ -1048,9 +1048,12 @@ def serve(check_interval: int) -> None: from qracer.autonomous import AutonomousMonitor from qracer.watchlist import Watchlist + # Watchlist is shared by the autonomous monitor and the Telegram + # /watchlist + /briefing commands, so build it unconditionally. + watchlist = Watchlist(_user_dir() / "watchlist.json") + autonomous_monitor: AutonomousMonitor | None = None if app_cfg.autonomous_enabled: - watchlist = Watchlist(_user_dir() / "watchlist.json") autonomous_monitor = AutonomousMonitor( watchlist, data_registry, @@ -1059,6 +1062,9 @@ def serve(check_interval: int) -> None: cooldown_minutes=app_cfg.alert_cooldown_minutes, ) + sessions_dir = _user_dir() / "sessions" + reports_dir = _user_dir() / "reports" + server = Server( alert_monitor, task_executor, @@ -1066,6 +1072,10 @@ def serve(check_interval: int) -> None: autonomous_monitor=autonomous_monitor, telegram_poller=telegram_poller, tick_interval=1.0, + watchlist=watchlist, + data_registry=data_registry, + sessions_dir=sessions_dir, + reports_dir=reports_dir, ) def _handle_signal(signum: int, _frame: object) -> None: @@ -1085,7 +1095,14 @@ def _handle_signal(signum: int, _frame: object) -> None: f" cooldown={app_cfg.alert_cooldown_minutes}m" ) if telegram_poller is not None: - click.echo(" Telegram bot: receiving commands (try /help in chat)") + authorised = len(telegram_poller.allowed_chat_ids) + if authorised > 1: + click.echo( + f" Telegram bot: receiving commands ({authorised} authorised chats; " + "try /help in chat)" + ) + else: + click.echo(" Telegram bot: receiving commands (try /help in chat)") click.echo(" Press Ctrl+C to stop.\n") try: diff --git a/qracer/notifications/factory.py b/qracer/notifications/factory.py index 5894245..bd79709 100644 --- a/qracer/notifications/factory.py +++ b/qracer/notifications/factory.py @@ -46,6 +46,11 @@ def build_notification_registry( return registry +def _parse_chat_ids(raw: str) -> list[str]: + """Split a comma-separated chat-id list and drop blanks.""" + return [part.strip() for part in raw.split(",") if part.strip()] + + def build_telegram_poller( credentials: dict[str, str], *, @@ -59,11 +64,28 @@ def build_telegram_poller( The default ``timeout=1`` keeps the long-poll short enough to coexist with the 1-second :class:`~qracer.server.Server` tick; standalone callers can pass a larger value (e.g. 30) for true long-polling. + + ``TELEGRAM_ALLOWED_CHAT_IDS`` (comma-separated, optional) authorises + additional chats — e.g. ``"111,222"`` lets two users talk to the bot. + The primary chat (``TELEGRAM_CHAT_ID``) is always authorised and used as + the default reply target. """ bot_token = credentials.get("TELEGRAM_BOT_TOKEN", "") chat_id = credentials.get("TELEGRAM_CHAT_ID", "") if not bot_token or not chat_id: return None - poller = TelegramBotPoller(bot_token=bot_token, chat_id=chat_id, timeout=timeout) - logger.info("Telegram bot command poller initialised") + allowed = _parse_chat_ids(credentials.get("TELEGRAM_ALLOWED_CHAT_IDS", "")) + poller = TelegramBotPoller( + bot_token=bot_token, + chat_id=chat_id, + allowed_chat_ids=allowed or None, + timeout=timeout, + ) + if len(poller.allowed_chat_ids) > 1: + logger.info( + "Telegram bot command poller initialised (authorised chats: %d)", + len(poller.allowed_chat_ids), + ) + else: + logger.info("Telegram bot command poller initialised") return poller diff --git a/qracer/notifications/telegram_poller.py b/qracer/notifications/telegram_poller.py index 9321f22..9cc4a8c 100644 --- a/qracer/notifications/telegram_poller.py +++ b/qracer/notifications/telegram_poller.py @@ -13,10 +13,12 @@ import asyncio import json import logging +import time import urllib.error import urllib.parse import urllib.request -from dataclasses import dataclass +from collections import deque +from dataclasses import dataclass, field from typing import Any logger = logging.getLogger(__name__) @@ -27,6 +29,12 @@ # truncation suffixes still fit. _DEFAULT_MESSAGE_CHAR_LIMIT = 4000 +# Default rate-limit: 20 commands per chat per 60 seconds. Balances +# responsiveness for normal use against runaway loops or abuse on a shared +# chat. +_DEFAULT_RATE_LIMIT_COMMANDS = 20 +_DEFAULT_RATE_LIMIT_WINDOW_S = 60.0 + @dataclass(frozen=True) class BotCommand: @@ -34,16 +42,22 @@ class BotCommand: Example:: - BotCommand.parse("/analyze AAPL") - # → BotCommand(action="analyze", args=["AAPL"], raw_text="/analyze AAPL") + BotCommand.parse("/analyze AAPL", chat_id="12345") + # → BotCommand(action="analyze", args=["AAPL"], + # raw_text="/analyze AAPL", chat_id="12345") + + ``chat_id`` is the sender's chat — callers can use it as the target of + :meth:`TelegramBotPoller.send_reply` so replies go back to whoever asked + (useful when ``allowed_chat_ids`` authorises more than one chat). """ action: str args: list[str] raw_text: str + chat_id: str = "" @classmethod - def parse(cls, text: str) -> BotCommand | None: + def parse(cls, text: str, chat_id: str = "") -> BotCommand | None: """Parse a Telegram message into a :class:`BotCommand`. Returns ``None`` if the text is not a recognised command (i.e. does @@ -62,24 +76,47 @@ def parse(cls, text: str) -> BotCommand | None: action = parts[0].split("@", 1)[0].lower() if not action: return None - return cls(action=action, args=parts[1:], raw_text=text) + return cls(action=action, args=parts[1:], raw_text=text, chat_id=str(chat_id)) + + +@dataclass +class _RateBucket: + """Sliding-window counter for a single chat.""" + + timestamps: deque[float] = field(default_factory=deque) + + def admit(self, now: float, limit: int, window: float) -> bool: + """Return ``True`` if a new command is within limits at ``now``.""" + cutoff = now - window + while self.timestamps and self.timestamps[0] <= cutoff: + self.timestamps.popleft() + if len(self.timestamps) >= limit: + return False + self.timestamps.append(now) + return True class TelegramBotPoller: """Receive bot commands from Telegram via the ``getUpdates`` long-poll API. Tracks the update offset so messages are never returned twice, filters - messages to those originating from the authorised chat, and parses - incoming text into :class:`BotCommand` objects. + messages to those originating from an authorised chat (``chat_id`` plus + any ``allowed_chat_ids``), parses incoming text into :class:`BotCommand` + objects, and enforces a per-chat sliding-window rate limit. - Replies can be sent back to the same chat via :meth:`send_reply`. + Replies can be sent back to any authorised chat via :meth:`send_reply`; + when ``chat_id`` is omitted the primary chat (``self.chat_id``) is used. Usage:: - poller = TelegramBotPoller(bot_token="...", chat_id="123") + poller = TelegramBotPoller( + bot_token="...", + chat_id="123", + allowed_chat_ids=["123", "456"], + ) commands = await poller.poll() for cmd in commands: - await poller.send_reply(f"Got: {cmd.action}") + await poller.send_reply(f"Got: {cmd.action}", chat_id=cmd.chat_id) """ def __init__( @@ -87,8 +124,11 @@ def __init__( bot_token: str, chat_id: str, *, + allowed_chat_ids: list[str] | None = None, timeout: int = 30, message_char_limit: int = _DEFAULT_MESSAGE_CHAR_LIMIT, + rate_limit_commands: int = _DEFAULT_RATE_LIMIT_COMMANDS, + rate_limit_window_seconds: float = _DEFAULT_RATE_LIMIT_WINDOW_S, ) -> None: if not bot_token: raise ValueError("TELEGRAM_BOT_TOKEN is required but was empty") @@ -96,10 +136,28 @@ def __init__( raise ValueError("TELEGRAM_CHAT_ID is required but was empty") self._bot_token = bot_token self._chat_id = str(chat_id) + + # Authorised senders. Always include the primary chat_id; merge any + # extras while preserving insertion order and dropping blanks. + authorised: list[str] = [self._chat_id] + for extra in allowed_chat_ids or []: + extra_str = str(extra).strip() + if extra_str and extra_str not in authorised: + authorised.append(extra_str) + self._allowed_chat_ids: tuple[str, ...] = tuple(authorised) + self._timeout = max(0, int(timeout)) self._message_char_limit = message_char_limit self._offset: int | None = None + if rate_limit_commands < 0: + raise ValueError("rate_limit_commands must be >= 0") + if rate_limit_window_seconds <= 0: + raise ValueError("rate_limit_window_seconds must be > 0") + self._rate_limit_commands = rate_limit_commands + self._rate_limit_window = rate_limit_window_seconds + self._rate_buckets: dict[str, _RateBucket] = {} + @property def offset(self) -> int | None: """Current update offset (``None`` until the first update arrives).""" @@ -107,16 +165,22 @@ def offset(self) -> int | None: @property def chat_id(self) -> str: - """The authorised chat ID this poller filters by.""" + """The primary chat ID — default target for :meth:`send_reply`.""" return self._chat_id + @property + def allowed_chat_ids(self) -> tuple[str, ...]: + """All chat IDs authorised to send commands (primary first).""" + return self._allowed_chat_ids + async def poll(self) -> list[BotCommand]: """Long-poll Telegram for new commands. Returns a list of :class:`BotCommand` parsed from messages that - arrived from the authorised chat. The offset is advanced past - the highest update ID returned, so subsequent calls only return - new messages. + arrived from an authorised chat. The offset is advanced past the + highest update ID returned, so subsequent calls only return new + messages. Commands that exceed the per-chat rate limit are logged + and dropped. Network and API errors are logged and converted to an empty list — the caller is expected to retry on the next tick. @@ -158,6 +222,7 @@ def _poll_sync(self) -> list[BotCommand]: commands: list[BotCommand] = [] max_update_id = -1 + now = time.monotonic() for update in payload.get("result", []): update_id = update.get("update_id") if isinstance(update_id, int) and update_id > max_update_id: @@ -168,39 +233,74 @@ def _poll_sync(self) -> list[BotCommand]: continue chat = message.get("chat") or {} - if str(chat.get("id")) != self._chat_id: - logger.debug("Ignoring message from unauthorised chat %s", chat.get("id")) + sender_chat_id = str(chat.get("id")) + if sender_chat_id not in self._allowed_chat_ids: + logger.debug( + "Ignoring message from unauthorised chat %s", + sender_chat_id, + ) continue text = message.get("text") if not isinstance(text, str): continue - cmd = BotCommand.parse(text) - if cmd is not None: - commands.append(cmd) + cmd = BotCommand.parse(text, chat_id=sender_chat_id) + if cmd is None: + continue + + if not self._admit(sender_chat_id, now): + logger.warning( + "Rate-limited command from chat %s: /%s", + sender_chat_id, + cmd.action, + ) + continue + commands.append(cmd) if max_update_id >= 0: self._offset = max_update_id + 1 return commands - async def send_reply(self, text: str) -> bool: - """Send a plain-text reply to the authorised chat. - - Long replies are truncated to ``message_char_limit`` characters - with a trailing ``"..."``. Returns ``True`` on HTTP 200. + def _admit(self, chat_id: str, now: float) -> bool: + """Return True when this chat is within the sliding-window limit.""" + if self._rate_limit_commands == 0: + return False + bucket = self._rate_buckets.get(chat_id) + if bucket is None: + bucket = _RateBucket() + self._rate_buckets[chat_id] = bucket + return bucket.admit(now, self._rate_limit_commands, self._rate_limit_window) + + async def send_reply(self, text: str, chat_id: str | None = None) -> bool: + """Send a plain-text reply. + + ``chat_id`` defaults to the primary :attr:`chat_id`; pass an explicit + value to reply to a secondary authorised chat (e.g. the sender's + :attr:`BotCommand.chat_id`). Unknown chat IDs fall back to the + primary chat with a warning log. + + Long replies are truncated to ``message_char_limit`` characters with + a trailing ``"..."``. Returns ``True`` on HTTP 200. """ - return await asyncio.to_thread(self._send_reply_sync, text) + target = chat_id if chat_id else self._chat_id + if target not in self._allowed_chat_ids: + logger.warning( + "send_reply called with unauthorised chat %s; falling back to primary", + target, + ) + target = self._chat_id + return await asyncio.to_thread(self._send_reply_sync, text, target) - def _send_reply_sync(self, text: str) -> bool: + def _send_reply_sync(self, text: str, chat_id: str) -> bool: if not text: return False if len(text) > self._message_char_limit: text = text[: self._message_char_limit - 3] + "..." url = f"{_TELEGRAM_API}/bot{self._bot_token}/sendMessage" - payload = {"chat_id": self._chat_id, "text": text} + payload = {"chat_id": chat_id, "text": text} data = json.dumps(payload).encode() req = urllib.request.Request( url, diff --git a/qracer/server.py b/qracer/server.py index c9fcb8f..e3c715d 100644 --- a/qracer/server.py +++ b/qracer/server.py @@ -8,16 +8,22 @@ import asyncio import logging +import re import time +from pathlib import Path from qracer.alert_monitor import AlertMonitor from qracer.alerts import AlertCondition from qracer.autonomous import AutonomousMonitor +from qracer.conversation.quickpath import generate_briefing +from qracer.data.providers import PriceProvider +from qracer.data.registry import DataRegistry from qracer.notifications.providers import Notification, NotificationCategory from qracer.notifications.registry import NotificationRegistry from qracer.notifications.telegram_poller import BotCommand, TelegramBotPoller from qracer.task_executor import TaskExecutor from qracer.tasks import TaskActionType +from qracer.watchlist import Watchlist logger = logging.getLogger(__name__) @@ -41,6 +47,10 @@ def __init__( autonomous_monitor: AutonomousMonitor | None = None, telegram_poller: TelegramBotPoller | None = None, tick_interval: float = 1.0, + watchlist: Watchlist | None = None, + data_registry: DataRegistry | None = None, + sessions_dir: Path | None = None, + reports_dir: Path | None = None, ) -> None: self._alert_monitor = alert_monitor self._task_executor = task_executor @@ -48,6 +58,10 @@ def __init__( self._notifications = notifications or NotificationRegistry() self._telegram_poller = telegram_poller self._tick_interval = tick_interval + self._watchlist = watchlist + self._data_registry = data_registry + self._sessions_dir = sessions_dir + self._reports_dir = reports_dir self._shutdown_event = asyncio.Event() self._started_at: float | None = None @@ -124,14 +138,14 @@ async def _tick(self) -> None: async def _handle_bot_command(self, command: BotCommand) -> None: """Dispatch an incoming bot command and reply with the result.""" try: - reply = self._dispatch_bot_command(command) + reply = await self._dispatch_bot_command(command) except Exception as exc: logger.exception("Bot command handler failed: /%s", command.action) reply = f"Error handling /{command.action}: {exc}" if reply and self._telegram_poller is not None: - await self._telegram_poller.send_reply(reply) + await self._telegram_poller.send_reply(reply, chat_id=command.chat_id or None) - def _dispatch_bot_command(self, command: BotCommand) -> str: + async def _dispatch_bot_command(self, command: BotCommand) -> str: """Route a :class:`BotCommand` to the matching handler. Handlers return the reply text to send back to the user. Long @@ -150,6 +164,12 @@ def _dispatch_bot_command(self, command: BotCommand) -> str: return self._cmd_tasks() if action == "schedule": return self._cmd_schedule(command.args) + if action == "briefing": + return await self._cmd_briefing() + if action == "watchlist": + return await self._cmd_watchlist() + if action == "thesis": + return self._cmd_thesis() if action in {"analyze", "portfolio"}: return ( f"/{action} is not supported in bot mode yet — " @@ -166,6 +186,9 @@ def _cmd_help() -> str: return ( "qracer bot commands:\n" "/status — server status and uptime\n" + "/briefing — session briefing since the last REPL run\n" + "/watchlist — watchlist tickers with current prices\n" + "/thesis — recent saved trade theses\n" "/alerts — list active price alerts\n" "/alert TICKER above|below PRICE — create a price alert\n" "/tasks — list scheduled tasks\n" @@ -243,6 +266,90 @@ def _cmd_schedule(self, args: list[str]) -> str: return f"Invalid schedule: {exc}" return f"Scheduled task {task.id}: {task.describe()}" + async def _cmd_briefing(self) -> str: + """Compose a session-start-style briefing from current state.""" + if self._watchlist is None or self._data_registry is None or self._sessions_dir is None: + return ( + "Briefing unavailable in this mode. " + "Run `qracer repl` on the host for session-start briefings." + ) + try: + briefing = await generate_briefing( + self._watchlist, + self._data_registry, + self._alert_monitor.store, + self._task_executor.store, + self._sessions_dir, + ) + except Exception: + logger.exception("Telegram /briefing generation failed") + return "Briefing failed — see server logs for details." + if not briefing: + return "No briefing: no prior session on file (or nothing new since)." + return briefing + + async def _cmd_watchlist(self) -> str: + """Return watchlist tickers with current prices.""" + if self._watchlist is None: + return "Watchlist unavailable — not configured on this server." + tickers = self._watchlist.tickers + if not tickers: + return "Watchlist is empty. Add tickers from the qracer REPL with 'watch TICKER'." + if self._data_registry is None: + return "Watchlist:\n" + "\n".join(f" {t}" for t in tickers) + lines = [f"Watchlist ({len(tickers)}):"] + for ticker in tickers: + try: + price = await self._data_registry.async_get_with_fallback( + PriceProvider, "get_price", ticker + ) + except Exception: + logger.debug("Price fetch failed for %s", ticker, exc_info=True) + lines.append(f" {ticker}: price unavailable") + continue + if isinstance(price, (int, float)): + lines.append(f" {ticker}: ${price:,.2f}") + else: + lines.append(f" {ticker}: price unavailable") + return "\n".join(lines) + + def _cmd_thesis(self) -> str: + """Summarise the most recent saved trade-thesis report(s). + + Reads the ``reports_dir`` the REPL writes to (via + :class:`~qracer.conversation.report_exporter.ReportExporter`) and + extracts the Trade-Thesis section of each Markdown report. + """ + if self._reports_dir is None or not self._reports_dir.exists(): + return ( + "No saved theses found. Run the qracer REPL and use " + "`save` after a thesis query to make theses visible here." + ) + try: + md_files = sorted( + (p for p in self._reports_dir.glob("*.md") if p.is_file()), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + except OSError: + logger.debug("Failed to list reports dir", exc_info=True) + return "Thesis lookup failed — see server logs for details." + + entries: list[str] = [] + for path in md_files: + if len(entries) >= 3: + break + summary = _extract_thesis_section(path) + if summary is None: + continue + when = time.strftime("%Y-%m-%d %H:%M", time.localtime(path.stat().st_mtime)) + entries.append(f"[{when}] {path.name}\n{summary}") + + if not entries: + return "No saved theses found in reports directory." + header = f"Recent theses ({len(entries)}):" + return "\n\n".join([header, *entries]) + async def _notify(self, category: NotificationCategory, title: str, body: str) -> None: """Send a notification if any channels are registered.""" if not self._notifications.channels: @@ -255,6 +362,37 @@ def shutdown(self) -> None: self._shutdown_event.set() +_THESIS_HEADING = re.compile(r"^##\s+Trade Thesis\s*$", re.MULTILINE) + + +def _extract_thesis_section(path: Path) -> str | None: + """Return the ``## Trade Thesis`` body from a saved Markdown report. + + Returns ``None`` if the file is unreadable or has no thesis section. + The returned text is trimmed to the next ``---`` or ``##`` boundary + and capped at 800 characters to keep Telegram replies compact. + """ + try: + text = path.read_text(encoding="utf-8") + except OSError: + return None + match = _THESIS_HEADING.search(text) + if match is None: + return None + body = text[match.end() :] + + # Stop at the next top-level section or horizontal rule. + stop = len(body) + for marker in ("\n## ", "\n---"): + idx = body.find(marker) + if idx >= 0 and idx < stop: + stop = idx + body = body[:stop].strip() + if len(body) > 800: + body = body[:797] + "..." + return body or None + + def _format_duration(seconds: float) -> str: """Format a duration in seconds as ``"1h 23m 45s"`` (omitting empty units).""" seconds = max(0, int(seconds)) diff --git a/tests/notifications/test_factory.py b/tests/notifications/test_factory.py index d7499c6..ef926da 100644 --- a/tests/notifications/test_factory.py +++ b/tests/notifications/test_factory.py @@ -1,6 +1,9 @@ """Tests for the notification factory.""" -from qracer.notifications.factory import build_notification_registry +from qracer.notifications.factory import ( + build_notification_registry, + build_telegram_poller, +) class TestBuildNotificationRegistry: @@ -27,3 +30,53 @@ def test_empty_string_values_are_skipped(self): creds = {"TELEGRAM_BOT_TOKEN": "", "TELEGRAM_CHAT_ID": ""} reg = build_notification_registry(creds) assert reg.channels == [] + + +class TestBuildTelegramPoller: + def test_returns_none_without_credentials(self): + assert build_telegram_poller({}) is None + + def test_returns_none_without_bot_token(self): + assert build_telegram_poller({"TELEGRAM_CHAT_ID": "1"}) is None + + def test_returns_none_without_chat_id(self): + assert build_telegram_poller({"TELEGRAM_BOT_TOKEN": "tok"}) is None + + def test_primary_chat_authorised_by_default(self): + poller = build_telegram_poller({"TELEGRAM_BOT_TOKEN": "tok", "TELEGRAM_CHAT_ID": "1"}) + assert poller is not None + assert poller.allowed_chat_ids == ("1",) + + def test_allowed_chat_ids_parsed_from_comma_list(self): + poller = build_telegram_poller( + { + "TELEGRAM_BOT_TOKEN": "tok", + "TELEGRAM_CHAT_ID": "1", + "TELEGRAM_ALLOWED_CHAT_IDS": "2, 3 ,4", + } + ) + assert poller is not None + # Whitespace trimmed, primary still present, duplicates deduped. + assert poller.allowed_chat_ids == ("1", "2", "3", "4") + + def test_allowed_chat_ids_blanks_ignored(self): + poller = build_telegram_poller( + { + "TELEGRAM_BOT_TOKEN": "tok", + "TELEGRAM_CHAT_ID": "1", + "TELEGRAM_ALLOWED_CHAT_IDS": ",, ,", + } + ) + assert poller is not None + assert poller.allowed_chat_ids == ("1",) + + def test_allowed_chat_ids_dedupes_primary(self): + poller = build_telegram_poller( + { + "TELEGRAM_BOT_TOKEN": "tok", + "TELEGRAM_CHAT_ID": "1", + "TELEGRAM_ALLOWED_CHAT_IDS": "1,2", + } + ) + assert poller is not None + assert poller.allowed_chat_ids == ("1", "2") diff --git a/tests/notifications/test_telegram_poller.py b/tests/notifications/test_telegram_poller.py index fd0a713..9a6f39f 100644 --- a/tests/notifications/test_telegram_poller.py +++ b/tests/notifications/test_telegram_poller.py @@ -334,3 +334,250 @@ async def test_send_reply_url_error_returns_false(self, poller: TelegramBotPolle with patch(target, side_effect=urllib.error.URLError("offline")): ok = await poller.send_reply("hello") assert ok is False + + +# --------------------------------------------------------------------------- +# allowed_chat_ids — multi-chat auth +# --------------------------------------------------------------------------- + + +class TestAllowedChatIds: + def test_defaults_to_primary_only(self) -> None: + poller = TelegramBotPoller(bot_token="tok", chat_id="999") + assert poller.allowed_chat_ids == ("999",) + + def test_primary_always_authorised(self) -> None: + poller = TelegramBotPoller(bot_token="tok", chat_id="999", allowed_chat_ids=["1"]) + assert "999" in poller.allowed_chat_ids + assert "1" in poller.allowed_chat_ids + + def test_duplicates_deduped(self) -> None: + poller = TelegramBotPoller( + bot_token="tok", chat_id="999", allowed_chat_ids=["999", "1", "1"] + ) + assert poller.allowed_chat_ids == ("999", "1") + + def test_blank_entries_dropped(self) -> None: + poller = TelegramBotPoller(bot_token="tok", chat_id="999", allowed_chat_ids=["", " ", "1"]) + assert poller.allowed_chat_ids == ("999", "1") + + async def test_poll_accepts_secondary_chat(self) -> None: + poller = TelegramBotPoller( + bot_token="tok", + chat_id="999", + allowed_chat_ids=["42"], + timeout=1, + ) + payload = { + "ok": True, + "result": [ + _make_update(1, 999, "/status"), + _make_update(2, 42, "/alerts"), + _make_update(3, 777, "/leak"), # unauthorised + ], + } + target = "qracer.notifications.telegram_poller.urllib.request.urlopen" + with patch(target, return_value=_mock_response(payload)): + commands = await poller.poll() + + actions = [c.action for c in commands] + assert actions == ["status", "alerts"] + chats = [c.chat_id for c in commands] + assert chats == ["999", "42"] + + async def test_send_reply_routes_to_secondary_chat(self) -> None: + poller = TelegramBotPoller( + bot_token="tok", + chat_id="999", + allowed_chat_ids=["42"], + timeout=1, + ) + resp = MagicMock() + resp.status = 200 + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + target = "qracer.notifications.telegram_poller.urllib.request.urlopen" + with patch(target, return_value=resp) as mock: + ok = await poller.send_reply("hi", chat_id="42") + assert ok is True + body = json.loads(mock.call_args[0][0].data) + assert body["chat_id"] == "42" + + async def test_send_reply_unauthorised_chat_falls_back_to_primary(self) -> None: + poller = TelegramBotPoller( + bot_token="tok", + chat_id="999", + allowed_chat_ids=["42"], + timeout=1, + ) + resp = MagicMock() + resp.status = 200 + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + target = "qracer.notifications.telegram_poller.urllib.request.urlopen" + with patch(target, return_value=resp) as mock: + ok = await poller.send_reply("hi", chat_id="8888") + assert ok is True + body = json.loads(mock.call_args[0][0].data) + assert body["chat_id"] == "999" + + async def test_send_reply_default_uses_primary(self) -> None: + poller = TelegramBotPoller( + bot_token="tok", + chat_id="999", + allowed_chat_ids=["42"], + timeout=1, + ) + resp = MagicMock() + resp.status = 200 + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + target = "qracer.notifications.telegram_poller.urllib.request.urlopen" + with patch(target, return_value=resp) as mock: + await poller.send_reply("hi") + body = json.loads(mock.call_args[0][0].data) + assert body["chat_id"] == "999" + + +# --------------------------------------------------------------------------- +# Rate limiting +# --------------------------------------------------------------------------- + + +class TestRateLimit: + def test_rate_limit_commands_must_be_non_negative(self) -> None: + with pytest.raises(ValueError, match="rate_limit_commands"): + TelegramBotPoller(bot_token="tok", chat_id="999", rate_limit_commands=-1) + + def test_rate_limit_window_must_be_positive(self) -> None: + with pytest.raises(ValueError, match="rate_limit_window_seconds"): + TelegramBotPoller(bot_token="tok", chat_id="999", rate_limit_window_seconds=0) + + async def test_poll_drops_commands_over_limit(self) -> None: + poller = TelegramBotPoller( + bot_token="tok", + chat_id="999", + timeout=1, + rate_limit_commands=2, + rate_limit_window_seconds=60.0, + ) + payload = { + "ok": True, + "result": [ + _make_update(1, 999, "/status"), + _make_update(2, 999, "/alerts"), + _make_update(3, 999, "/tasks"), # dropped + _make_update(4, 999, "/status"), # dropped + ], + } + target = "qracer.notifications.telegram_poller.urllib.request.urlopen" + with patch(target, return_value=_mock_response(payload)): + commands = await poller.poll() + + actions = [c.action for c in commands] + assert actions == ["status", "alerts"] + # Offset still advances past the dropped updates so we don't + # re-fetch them on the next poll. + assert poller.offset == 5 + + async def test_rate_limit_is_per_chat(self) -> None: + poller = TelegramBotPoller( + bot_token="tok", + chat_id="999", + allowed_chat_ids=["42"], + timeout=1, + rate_limit_commands=1, + rate_limit_window_seconds=60.0, + ) + payload = { + "ok": True, + "result": [ + _make_update(1, 999, "/status"), + _make_update(2, 42, "/status"), # different chat, own bucket + _make_update(3, 999, "/alerts"), # dropped + _make_update(4, 42, "/alerts"), # dropped + ], + } + target = "qracer.notifications.telegram_poller.urllib.request.urlopen" + with patch(target, return_value=_mock_response(payload)): + commands = await poller.poll() + + chats = [(c.chat_id, c.action) for c in commands] + assert chats == [("999", "status"), ("42", "status")] + + async def test_rate_limit_window_expires(self) -> None: + import qracer.notifications.telegram_poller as mod + + poller = TelegramBotPoller( + bot_token="tok", + chat_id="999", + timeout=1, + rate_limit_commands=1, + rate_limit_window_seconds=60.0, + ) + payload_first = { + "ok": True, + "result": [_make_update(1, 999, "/status")], + } + payload_second = { + "ok": True, + "result": [_make_update(2, 999, "/alerts")], + } + target = "qracer.notifications.telegram_poller.urllib.request.urlopen" + + # First poll at t=0 — admitted. + with ( + patch.object(mod.time, "monotonic", return_value=0.0), + patch(target, return_value=_mock_response(payload_first)), + ): + first = await poller.poll() + assert len(first) == 1 + + # Second poll at t=120s (> window) — the earlier timestamp has aged + # out and the new command is admitted again. + with ( + patch.object(mod.time, "monotonic", return_value=120.0), + patch(target, return_value=_mock_response(payload_second)), + ): + second = await poller.poll() + assert len(second) == 1 + assert second[0].action == "alerts" + + async def test_rate_limit_zero_blocks_everything(self) -> None: + poller = TelegramBotPoller( + bot_token="tok", + chat_id="999", + timeout=1, + rate_limit_commands=0, + rate_limit_window_seconds=60.0, + ) + payload = { + "ok": True, + "result": [_make_update(1, 999, "/status")], + } + target = "qracer.notifications.telegram_poller.urllib.request.urlopen" + with patch(target, return_value=_mock_response(payload)): + commands = await poller.poll() + assert commands == [] + + +# --------------------------------------------------------------------------- +# BotCommand.chat_id plumbing +# --------------------------------------------------------------------------- + + +class TestBotCommandChatId: + def test_parse_default_blank(self) -> None: + cmd = BotCommand.parse("/status") + assert cmd is not None + assert cmd.chat_id == "" + + def test_parse_records_chat_id(self) -> None: + cmd = BotCommand.parse("/status", chat_id="12345") + assert cmd is not None + assert cmd.chat_id == "12345" + + def test_parse_coerces_chat_id_to_str(self) -> None: + cmd = BotCommand.parse("/status", chat_id=12345) # type: ignore[arg-type] + assert cmd is not None + assert cmd.chat_id == "12345" diff --git a/tests/test_server.py b/tests/test_server.py index ded6641..1c95416 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from qracer.alerts import Alert, AlertCondition from qracer.notifications.telegram_poller import BotCommand @@ -233,130 +233,132 @@ def _server(monitor=None, executor=None, **kwargs) -> Server: **kwargs, ) - def test_help(self) -> None: + async def test_help(self) -> None: server = self._server() - out = server._dispatch_bot_command(BotCommand("help", [], "/help")) + out = await server._dispatch_bot_command(BotCommand("help", [], "/help")) assert "/status" in out assert "/alerts" in out assert "/alert" in out assert "/tasks" in out assert "/schedule" in out - def test_start_aliases_help(self) -> None: + async def test_start_aliases_help(self) -> None: server = self._server() - out = server._dispatch_bot_command(BotCommand("start", [], "/start")) + out = await server._dispatch_bot_command(BotCommand("start", [], "/start")) assert "/status" in out - def test_unknown_command(self) -> None: + async def test_unknown_command(self) -> None: server = self._server() - out = server._dispatch_bot_command(BotCommand("nope", [], "/nope")) + out = await server._dispatch_bot_command(BotCommand("nope", [], "/nope")) assert "Unknown command" in out - def test_analyze_returns_not_supported(self) -> None: + async def test_analyze_returns_not_supported(self) -> None: server = self._server() - out = server._dispatch_bot_command(BotCommand("analyze", ["AAPL"], "/analyze AAPL")) + out = await server._dispatch_bot_command(BotCommand("analyze", ["AAPL"], "/analyze AAPL")) assert "not supported" in out.lower() - def test_portfolio_returns_not_supported(self) -> None: + async def test_portfolio_returns_not_supported(self) -> None: server = self._server() - out = server._dispatch_bot_command(BotCommand("portfolio", [], "/portfolio")) + out = await server._dispatch_bot_command(BotCommand("portfolio", [], "/portfolio")) assert "not supported" in out.lower() - def test_status(self) -> None: + async def test_status(self) -> None: notifications = MagicMock() notifications.channels = ["telegram"] server = self._server(notifications=notifications) - out = server._dispatch_bot_command(BotCommand("status", [], "/status")) + out = await server._dispatch_bot_command(BotCommand("status", [], "/status")) assert "uptime" in out assert "telegram" in out assert "autonomous: off" in out - def test_alerts_empty(self) -> None: + async def test_alerts_empty(self) -> None: monitor = _make_monitor() monitor.store.get_active.return_value = [] server = self._server(monitor=monitor) - out = server._dispatch_bot_command(BotCommand("alerts", [], "/alerts")) + out = await server._dispatch_bot_command(BotCommand("alerts", [], "/alerts")) assert out == "No active alerts." - def test_alerts_lists_each(self) -> None: + async def test_alerts_lists_each(self) -> None: monitor = _make_monitor() monitor.store.get_active.return_value = [ _alert("a1", "AAPL", 200), _alert("b2", "MSFT", 410), ] server = self._server(monitor=monitor) - out = server._dispatch_bot_command(BotCommand("alerts", [], "/alerts")) + out = await server._dispatch_bot_command(BotCommand("alerts", [], "/alerts")) assert "a1" in out assert "AAPL" in out assert "b2" in out assert "MSFT" in out - def test_create_alert_validates_args(self) -> None: + async def test_create_alert_validates_args(self) -> None: server = self._server() - out = server._dispatch_bot_command(BotCommand("alert", ["AAPL"], "/alert AAPL")) + out = await server._dispatch_bot_command(BotCommand("alert", ["AAPL"], "/alert AAPL")) assert "Usage" in out - def test_create_alert_rejects_unknown_condition(self) -> None: + async def test_create_alert_rejects_unknown_condition(self) -> None: server = self._server() - out = server._dispatch_bot_command( + out = await server._dispatch_bot_command( BotCommand("alert", ["AAPL", "near", "200"], "/alert AAPL near 200") ) assert "Unknown condition" in out - def test_create_alert_rejects_change_pct(self) -> None: + async def test_create_alert_rejects_change_pct(self) -> None: server = self._server() - out = server._dispatch_bot_command( + out = await server._dispatch_bot_command( BotCommand("alert", ["AAPL", "change_pct", "5"], "/alert AAPL change_pct 5") ) assert "change_pct" in out assert "CLI" in out - def test_create_alert_rejects_invalid_price(self) -> None: + async def test_create_alert_rejects_invalid_price(self) -> None: server = self._server() - out = server._dispatch_bot_command( + out = await server._dispatch_bot_command( BotCommand("alert", ["AAPL", "above", "abc"], "/alert AAPL above abc") ) assert "Invalid price" in out - def test_create_alert_persists(self) -> None: + async def test_create_alert_persists(self) -> None: monitor = _make_monitor() created = _alert("xx", "AAPL", 200) monitor.store.create.return_value = created server = self._server(monitor=monitor) - out = server._dispatch_bot_command( + out = await server._dispatch_bot_command( BotCommand("alert", ["AAPL", "above", "200"], "/alert AAPL above 200") ) monitor.store.create.assert_called_once_with("AAPL", AlertCondition.ABOVE, 200.0) assert "Created alert xx" in out - def test_tasks_empty(self) -> None: + async def test_tasks_empty(self) -> None: executor = _make_executor() executor.store.get_active.return_value = [] server = self._server(executor=executor) - out = server._dispatch_bot_command(BotCommand("tasks", [], "/tasks")) + out = await server._dispatch_bot_command(BotCommand("tasks", [], "/tasks")) assert out == "No scheduled tasks." - def test_tasks_lists_each(self) -> None: + async def test_tasks_lists_each(self) -> None: executor = _make_executor() executor.store.get_active.return_value = [ _task("t1", "AAPL"), _task("t2", "MSFT", schedule="daily 09:30"), ] server = self._server(executor=executor) - out = server._dispatch_bot_command(BotCommand("tasks", [], "/tasks")) + out = await server._dispatch_bot_command(BotCommand("tasks", [], "/tasks")) assert "t1" in out assert "AAPL" in out assert "t2" in out assert "daily 09:30" in out - def test_schedule_validates_args(self) -> None: + async def test_schedule_validates_args(self) -> None: server = self._server() - out = server._dispatch_bot_command(BotCommand("schedule", ["analyze"], "/schedule analyze")) + out = await server._dispatch_bot_command( + BotCommand("schedule", ["analyze"], "/schedule analyze") + ) assert "Usage" in out - def test_schedule_rejects_unknown_action(self) -> None: + async def test_schedule_rejects_unknown_action(self) -> None: server = self._server() - out = server._dispatch_bot_command( + out = await server._dispatch_bot_command( BotCommand( "schedule", ["foo", "AAPL", "every", "1h"], @@ -365,11 +367,11 @@ def test_schedule_rejects_unknown_action(self) -> None: ) assert "Unknown action" in out - def test_schedule_rejects_invalid_spec(self) -> None: + async def test_schedule_rejects_invalid_spec(self) -> None: executor = _make_executor() executor.store.create.side_effect = ValueError("bad spec") server = self._server(executor=executor) - out = server._dispatch_bot_command( + out = await server._dispatch_bot_command( BotCommand( "schedule", ["analyze", "AAPL", "tomorrow"], @@ -379,11 +381,11 @@ def test_schedule_rejects_invalid_spec(self) -> None: assert "Invalid schedule" in out assert "bad spec" in out - def test_schedule_creates_task(self) -> None: + async def test_schedule_creates_task(self) -> None: executor = _make_executor() executor.store.create.return_value = _task("nn", "AAPL") server = self._server(executor=executor) - out = server._dispatch_bot_command( + out = await server._dispatch_bot_command( BotCommand( "schedule", ["analyze", "aapl", "every", "1h"], @@ -396,6 +398,232 @@ def test_schedule_creates_task(self) -> None: assert "Scheduled task nn" in out +class TestNewBotCommands: + """Tests for the /briefing, /watchlist, and /thesis handlers.""" + + @staticmethod + def _server(**kwargs) -> Server: + return Server(_make_monitor(), _make_executor(), **kwargs) + + # ---- /briefing ---- + + async def test_briefing_missing_deps_returns_hint(self) -> None: + server = self._server() + out = await server._dispatch_bot_command(BotCommand("briefing", [], "/briefing")) + assert "Briefing unavailable" in out + + async def test_briefing_no_prior_session_returns_hint(self, tmp_path) -> None: + import qracer.server as server_mod + + watchlist = MagicMock() + data_registry = MagicMock() + server = self._server( + watchlist=watchlist, + data_registry=data_registry, + sessions_dir=tmp_path, + ) + with patch.object(server_mod, "generate_briefing", AsyncMock(return_value=None)): + out = await server._dispatch_bot_command(BotCommand("briefing", [], "/briefing")) + assert "No briefing" in out + + async def test_briefing_returns_briefing_text(self, tmp_path) -> None: + import qracer.server as server_mod + + watchlist = MagicMock() + data_registry = MagicMock() + server = self._server( + watchlist=watchlist, + data_registry=data_registry, + sessions_dir=tmp_path, + ) + with patch.object( + server_mod, + "generate_briefing", + AsyncMock(return_value="Session Briefing\n AAPL: $200"), + ): + out = await server._dispatch_bot_command(BotCommand("briefing", [], "/briefing")) + assert "Session Briefing" in out + assert "AAPL" in out + + async def test_briefing_failure_is_reported(self, tmp_path) -> None: + import qracer.server as server_mod + + watchlist = MagicMock() + data_registry = MagicMock() + server = self._server( + watchlist=watchlist, + data_registry=data_registry, + sessions_dir=tmp_path, + ) + with patch.object( + server_mod, + "generate_briefing", + AsyncMock(side_effect=RuntimeError("boom")), + ): + out = await server._dispatch_bot_command(BotCommand("briefing", [], "/briefing")) + assert "Briefing failed" in out + + # ---- /watchlist ---- + + async def test_watchlist_unconfigured(self) -> None: + server = self._server() + out = await server._dispatch_bot_command(BotCommand("watchlist", [], "/watchlist")) + assert "unavailable" in out.lower() + + async def test_watchlist_empty(self) -> None: + watchlist = MagicMock() + watchlist.tickers = [] + server = self._server(watchlist=watchlist) + out = await server._dispatch_bot_command(BotCommand("watchlist", [], "/watchlist")) + assert "empty" in out.lower() + + async def test_watchlist_no_data_registry_shows_tickers_only(self) -> None: + watchlist = MagicMock() + watchlist.tickers = ["AAPL", "NVDA"] + server = self._server(watchlist=watchlist) + out = await server._dispatch_bot_command(BotCommand("watchlist", [], "/watchlist")) + assert "AAPL" in out + assert "NVDA" in out + assert "$" not in out # no price data + + async def test_watchlist_with_prices(self) -> None: + watchlist = MagicMock() + watchlist.tickers = ["AAPL", "NVDA"] + data_registry = MagicMock() + data_registry.async_get_with_fallback = AsyncMock(side_effect=[200.0, 1250.5]) + server = self._server(watchlist=watchlist, data_registry=data_registry) + out = await server._dispatch_bot_command(BotCommand("watchlist", [], "/watchlist")) + assert "AAPL: $200.00" in out + assert "NVDA: $1,250.50" in out + + async def test_watchlist_handles_price_failures(self) -> None: + watchlist = MagicMock() + watchlist.tickers = ["AAPL", "BAD"] + data_registry = MagicMock() + data_registry.async_get_with_fallback = AsyncMock( + side_effect=[200.0, RuntimeError("no feed")] + ) + server = self._server(watchlist=watchlist, data_registry=data_registry) + out = await server._dispatch_bot_command(BotCommand("watchlist", [], "/watchlist")) + assert "AAPL: $200.00" in out + assert "BAD: price unavailable" in out + + async def test_watchlist_handles_non_numeric_price(self) -> None: + watchlist = MagicMock() + watchlist.tickers = ["FOO"] + data_registry = MagicMock() + data_registry.async_get_with_fallback = AsyncMock(return_value=None) + server = self._server(watchlist=watchlist, data_registry=data_registry) + out = await server._dispatch_bot_command(BotCommand("watchlist", [], "/watchlist")) + assert "FOO: price unavailable" in out + + # ---- /thesis ---- + + async def test_thesis_no_reports_dir(self) -> None: + server = self._server() + out = await server._dispatch_bot_command(BotCommand("thesis", [], "/thesis")) + assert "No saved theses" in out + + async def test_thesis_empty_reports_dir(self, tmp_path) -> None: + server = self._server(reports_dir=tmp_path) + out = await server._dispatch_bot_command(BotCommand("thesis", [], "/thesis")) + assert "No saved theses" in out + + async def test_thesis_skips_reports_without_thesis_section(self, tmp_path) -> None: + (tmp_path / "notes.md").write_text("# Hello\n\nJust a note.\n") + server = self._server(reports_dir=tmp_path) + out = await server._dispatch_bot_command(BotCommand("thesis", [], "/thesis")) + assert "No saved theses" in out + + async def test_thesis_lists_recent_saved_reports(self, tmp_path) -> None: + report = tmp_path / "AAPL-2026-04-15.md" + report.write_text( + "# Analysis Report: AAPL\n\n" + "---\n\n" + "## Trade Thesis\n\n" + "- **Ticker:** AAPL\n" + "- **Entry Zone:** $175.00 – $180.00\n" + "- **Target Price:** $200.00\n" + "- **Stop Loss:** $165.00\n" + "\nLong AAPL on AI earnings.\n\n" + "---\n\n" + "## Data Sources\n\n- news\n" + ) + server = self._server(reports_dir=tmp_path) + out = await server._dispatch_bot_command(BotCommand("thesis", [], "/thesis")) + assert "Recent theses" in out + assert "AAPL-2026-04-15.md" in out + assert "Entry Zone" in out + # The "Data Sources" section should not leak into the thesis body. + assert "Data Sources" not in out + + async def test_thesis_caps_at_three_entries(self, tmp_path) -> None: + import time as _t + + body = "# Report\n\n## Trade Thesis\n\nSome thesis body.\n\n---\n" + for i in range(5): + p = tmp_path / f"T{i}.md" + p.write_text(body) + # Ensure distinct mtimes for deterministic ordering. + mtime = 1_700_000_000 + i + import os + + os.utime(p, (mtime, mtime)) + _ = _t # silence unused-import in case of future refactors + + server = self._server(reports_dir=tmp_path) + out = await server._dispatch_bot_command(BotCommand("thesis", [], "/thesis")) + # Three most-recent (T4, T3, T2), oldest two excluded. + assert "T4.md" in out + assert "T3.md" in out + assert "T2.md" in out + assert "T0.md" not in out + assert "T1.md" not in out + + +class TestChatIdRoutedReplies: + async def test_reply_goes_back_to_sender_chat_id(self) -> None: + monitor = _make_monitor() + monitor.store.get_active.return_value = [] + executor = _make_executor() + poller = _make_poller() + poller.poll = AsyncMock( + return_value=[ + BotCommand( + action="alerts", + args=[], + raw_text="/alerts", + chat_id="42", + ) + ] + ) + server = Server(monitor, executor, telegram_poller=poller) + await server._tick() + + poller.send_reply.assert_awaited_once() + args, kwargs = poller.send_reply.await_args + # Reply text is first positional, chat_id is passed by keyword. + assert kwargs.get("chat_id") == "42" + + async def test_reply_chat_id_is_none_when_command_has_no_chat_id( + self, + ) -> None: + monitor = _make_monitor() + monitor.store.get_active.return_value = [] + executor = _make_executor() + poller = _make_poller() + poller.poll = AsyncMock( + return_value=[BotCommand(action="alerts", args=[], raw_text="/alerts")] + ) + server = Server(monitor, executor, telegram_poller=poller) + await server._tick() + + poller.send_reply.assert_awaited_once() + _, kwargs = poller.send_reply.await_args + # Falls back to the poller's primary chat. + assert kwargs.get("chat_id") is None + + class TestFormatDuration: def test_zero(self) -> None: assert _format_duration(0) == "0s"