diff --git a/examples/attachments_from_user.py b/examples/attachments_from_user.py new file mode 100644 index 00000000..fed0532c --- /dev/null +++ b/examples/attachments_from_user.py @@ -0,0 +1,141 @@ +import logging +import os + +from maxo import Bot, Ctx, Dispatcher +from maxo.enums import AttachmentType +from maxo.routing.filters import BaseFilter +from maxo.routing.updates import MessageCreated +from maxo.utils.facades import MessageCreatedFacade +from maxo.utils.long_polling import LongPolling + +bot = Bot(os.environ["TOKEN"]) +dp = Dispatcher() + + +class AttachmentFilter(BaseFilter[MessageCreated]): + def __init__(self, attachment_type: AttachmentType) -> None: + self._attachment_type = attachment_type + + async def __call__(self, update: MessageCreated, ctx: Ctx) -> bool: + for attachment in update.message.body.attachments or []: + if attachment.type == self._attachment_type: + return True + + # ruff: noqa: SIM103 + if self._attachment_type == AttachmentType.TEXT and update.message.body.text: + return True + + return False + + +@dp.message_created(AttachmentFilter(AttachmentType.AUDIO)) +async def audio_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил голосовое сообщение") + await facade.bot.send_message( + chat_id=facade.chat_id, + attachments=[update.message.body.audio], + ) + + +@dp.message_created(AttachmentFilter(AttachmentType.CONTACT)) +async def contact_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил сообщение с контактом") + await facade.bot.send_message( + chat_id=facade.chat_id, + attachments=[update.message.body.contact], + ) + + +@dp.message_created(AttachmentFilter(AttachmentType.FILE)) +async def file_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил сообщение с файлами") + await facade.bot.send_message( + chat_id=facade.chat_id, + attachments=[update.message.body.file], + ) + + +@dp.message_created(AttachmentFilter(AttachmentType.IMAGE)) +async def image_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил сообщение с изображениями") + await facade.bot.send_message( + chat_id=facade.chat_id, + attachments=update.message.body.photo, + ) + + +@dp.message_created(AttachmentFilter(AttachmentType.LOCATION)) +async def location_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил сообщение с геопозицией") + await facade.bot.send_message( + chat_id=facade.chat_id, + attachments=[update.message.body.location], + ) + + +@dp.message_created(AttachmentFilter(AttachmentType.SHARE)) +async def share_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил сообщение с предпросмотром ссылки") + await facade.bot.send_message( + chat_id=facade.chat_id, + attachments=[update.message.body.share], + ) + + +@dp.message_created(AttachmentFilter(AttachmentType.STICKER)) +async def sticker_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил сообщение с стикером") + await facade.bot.send_message( + chat_id=facade.chat_id, + attachments=[update.message.body.sticker], + ) + + +@dp.message_created(AttachmentFilter(AttachmentType.VIDEO)) +async def video_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил сообщение с видео") + await facade.bot.send_message( + chat_id=facade.chat_id, + attachments=[update.message.body.video], + ) + + +@dp.message_created(AttachmentFilter(AttachmentType.TEXT)) +async def text_handler( + update: MessageCreated, + facade: MessageCreatedFacade, +) -> None: + await facade.answer_text("Получил простое текстовое сообщение") + + +def main() -> None: + logging.basicConfig(level=logging.DEBUG) + LongPolling(dp).run(bot) + + +if __name__ == "__main__": + main() diff --git a/examples/bot_middlewares.py b/examples/bot_middlewares.py new file mode 100644 index 00000000..f3839ded --- /dev/null +++ b/examples/bot_middlewares.py @@ -0,0 +1,102 @@ +import asyncio +import logging +import os +from collections.abc import Sequence + +from unihttp.http.request import HTTPRequest +from unihttp.http.response import HTTPResponse +from unihttp.middlewares.base import AsyncHandler, AsyncMiddleware + +from maxo import Bot +from maxo.backoff import Backoff, BackoffConfig +from maxo.errors import MaxBotNotFoundError + +logger = logging.getLogger(__name__) + +_DEFAULT_BACKOFF_CONFIG = BackoffConfig( + min_delay=1.0, + max_delay=5.0, + factor=1.3, + jitter=0.1, +) + + +class LoggingMiddleware(AsyncMiddleware): + async def handle( + self, + request: HTTPRequest, + next_handler: AsyncHandler, + ) -> HTTPResponse: + logger.info("Request: %s", request) + response = await next_handler(request) + logger.info("Response: %s", response) + return response + + +class RetryMiddleware(AsyncMiddleware): + def __init__( + self, + retries: int = 3, + backoff_config: BackoffConfig = _DEFAULT_BACKOFF_CONFIG, + status_codes: Sequence[int] | None = None, + exceptions: Sequence[type[Exception]] | None = None, + ) -> None: + self._retries = retries + self._backoff_config = backoff_config + self._status_codes = status_codes or (500, 502, 503, 504) + self._exceptions = exceptions or () + + async def handle( + self, + request: HTTPRequest, + next_handler: AsyncHandler, + ) -> HTTPResponse: + attempt = 0 + backoff = Backoff(self._backoff_config) + while True: + try: + response = await next_handler(request) + if ( + response.status_code in self._status_codes + and attempt < self._retries + ): + logger.warning( + "Bad status code %d: %s", + response.status_code, + response, + ) + backoff.next() + await backoff.sleep() + attempt += 1 + continue + except Exception as e: + if ( + self._exceptions + and isinstance(e, tuple(self._exceptions)) + and attempt < self._retries + ): + logger.warning("Bad exception %s", e, exc_info=e) + backoff.next() + await backoff.sleep() + attempt += 1 + continue + raise + + return response + + +async def main() -> None: + bot = Bot( + token=os.environ["TOKEN"], + middleware=[ + LoggingMiddleware(), + RetryMiddleware(exceptions=[MaxBotNotFoundError]), + ], + ) + async with bot.context(): + await bot.send_message(chat_id=-1) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + asyncio.run(main()) diff --git a/examples/text_formatting.py b/examples/text_formatting.py new file mode 100644 index 00000000..2ca590cc --- /dev/null +++ b/examples/text_formatting.py @@ -0,0 +1,85 @@ +import logging +import os + +from maxo import Bot, Dispatcher +from maxo.enums import TextFormat +from maxo.routing.filters import Command +from maxo.routing.updates import MessageCreated +from maxo.utils.facades import MessageCreatedFacade +from maxo.utils.formatting import ( + Bold, + Italic, + Link, + Mention, + Monospaced, + Strikethrough, + Text, + Underline, + as_list, + as_marked_list, + as_numbered_list, +) +from maxo.utils.long_polling import LongPolling + +bot = Bot(os.environ["TOKEN"]) +dp = Dispatcher() + + +@dp.message_created(Command("start")) +async def start_handler(update: MessageCreated, facade: MessageCreatedFacade) -> None: + text = Text( + "Привет, это демонстрация возможностей форматирования текста.", + "\n\n", + Bold("Это жирный текст."), + "\n", + Italic("Это курсивный текст."), + "\n", + Underline("Это подчеркнутый текст."), + "\n", + Strikethrough("Это зачеркнутый текст."), + "\n", + Monospaced("Это моноширинный текст."), + "\n", + Link( + "Это ссылка на библиотеку maxo.", + url="https://github.com/K1rL3s/maxo", + ), + "\n", + Mention("Это упоминание пользователя.", user_id=update.message.sender.id), + "\n\n", + "Вы также можете использовать вспомогательные функции для создания списков:", + "\n\n", + as_list( + "Простой список:", + "Элемент 1", + "Элемент 2", + "Элемент 3", + ), + "\n\n", + as_marked_list( + "Маркированный список:", + "Элемент 1", + "Элемент 2", + "Элемент 3", + ), + "\n\n", + as_numbered_list( + "Нумерованный список:", + "Элемент 1", + "Эleмент 2", + "Элемент 3", + start=4, + ), + ) + + await facade.answer_text(text.as_html(), format=TextFormat.HTML) + await facade.answer_text(text.as_markdown(), format=TextFormat.MARKDOWN) + + +def main() -> None: + logging.basicConfig(level=logging.DEBUG) + LongPolling(dp).run(bot) + + +if __name__ == "__main__": + main() diff --git a/examples/tg_max_one_fsm/README.md b/examples/tg_max_one_fsm/README.md new file mode 100644 index 00000000..cd8141e2 --- /dev/null +++ b/examples/tg_max_one_fsm/README.md @@ -0,0 +1,54 @@ +# Telegram + Max: Одна FSM + +Этот пример показывает, как использовать одну FSM для двух ботов: одного для Telegram (используя aiogram) и одного для Max (используя maxo). + +## Как это работает + +1. **Общая база данных**: Оба бота используют одну и ту же базу данных (`db.sqlite`), которая создается в родительской директории примера. Класс `UserRepo` в `user_repo.py` обрабатывает операции с базой данных. +2. **Связывание пользователей**: + - Когда пользователь впервые взаимодействует с любым из ботов, в таблице `users` создается новая запись с уникальным `shared_id` + - У каждого бота есть команда `/start`, которая показывает `shared_id` + - Чтобы связать аккаунты, нужно отправить команду `/link ` в другого бота + - В вашей системе вы можете использовать другой подход связывания +3. **Общее состояние FSM**: + - Состояние хранится в Redis'е + - `SharedFSMContextMiddleware` создает `FSMContext` с общим ключом, основанным на `shared_id` из базы данных + - Это гарантирует, что у пользователя будет одинаковое состояние FSM в обоих ботах + +## Как запустить + +Из директории `examples`: + +1. **Установите зависимости**: + ```bash + pip install aiogram maxo redis aiosqlite magic-filter + ``` +2. **Запустите Redis**: + ```bash + docker compose -f ./tg_max_one_fsm/docker-compose.yml run --remove-orphans -d -p 6379:6379 redis + ``` +3. **Установите переменные окружения**: + ```bash + export TG_TOKEN="tg_token" + export MAX_TOKEN="max_token" + export REDIS_URL="redis://localhost:6379/0" + ``` +4. **Запустите ботов**: + ```bash + python -m tg_max_one_fsm.tg + python -m tg_max_one_fsm.max + ``` + +## Как использовать + +1. **Telegram бот**: + - Отправьте `/start`, чтобы получить ваш `shared_id` и клавиатуру для смены состояний + - Используйте кнопки, чтобы переключаться между `state1` и `state2` + - Отправьте `/state`, чтобы проверить текущее состояние + - Отправьте `/link `, чтобы связать свой аккаунт с другим аккаунтом +2. **Max бот**: + - Отправьте `/start`, чтобы получить ваш `shared_id` и клавиатуру для смены состояний + - Отправьте команду `/link `, полученную от телеграм-бота + - Теперь у вас общее состояние с телеграм-ботом + - Используйте кнопки для смены состояния + - Отправьте `/state`, чтобы проверить текущее состояние. Вы увидите, что оно совпадает с состоянием в телеграм-боте diff --git a/examples/tg_max_one_fsm/__init__.py b/examples/tg_max_one_fsm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/tg_max_one_fsm/docker-compose.yml b/examples/tg_max_one_fsm/docker-compose.yml new file mode 100644 index 00000000..7b786639 --- /dev/null +++ b/examples/tg_max_one_fsm/docker-compose.yml @@ -0,0 +1,12 @@ +services: + redis: + container_name: maxo-redis-fsm + image: redis:7.4.3-alpine3.21 + restart: unless-stopped + healthcheck: + test: [ "CMD", "redis-cli", "ping" ] + interval: 10s + timeout: 5s + retries: 5 + start_period: 5s + command: "redis-server --loglevel warning" diff --git a/examples/tg_max_one_fsm/ids.py b/examples/tg_max_one_fsm/ids.py new file mode 100644 index 00000000..f3d12e73 --- /dev/null +++ b/examples/tg_max_one_fsm/ids.py @@ -0,0 +1,6 @@ +from typing import NewType + +MaxId = NewType("MaxId", int) +TgId = NewType("TgId", int) +DbId = NewType("DbId", int) +SharedId = NewType("SharedId", int) diff --git a/examples/tg_max_one_fsm/max/__init__.py b/examples/tg_max_one_fsm/max/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/tg_max_one_fsm/max/__main__.py b/examples/tg_max_one_fsm/max/__main__.py new file mode 100644 index 00000000..a309de12 --- /dev/null +++ b/examples/tg_max_one_fsm/max/__main__.py @@ -0,0 +1,168 @@ +import asyncio +import logging +import os + +from magic_filter import F + +from maxo import Bot, Dispatcher, Router +from maxo.fsm.context import FSMContext +from maxo.fsm.key_builder import DefaultKeyBuilder +from maxo.fsm.state import State, StatesGroup +from maxo.fsm.storages.redis import RedisStorage +from maxo.integrations.magic_filter import MagicFilter +from maxo.routing.filters import Command, CommandObject, CommandStart +from maxo.routing.updates import MessageCallback, MessageCreated +from maxo.types import CallbackButton +from maxo.utils.facades import MessageCallbackFacade, MessageCreatedFacade +from maxo.utils.long_polling import LongPolling + +from ..ids import SharedId +from ..user_repo import DbUser, UserRepo +from .current_user import CurrentUserMiddleware +from .fsm_context import SharedFSMContextMiddleware + +router = Router() + + +class MyStates(StatesGroup): + state1 = State() + state2 = State() + + +def get_keyboard() -> list[list[CallbackButton]]: + return [ + [ + CallbackButton(text="Перейти в состояние 1", payload="to_state_1"), + CallbackButton(text="Перейти в состояние 2", payload="to_state_2"), + ], + [ + CallbackButton(text="Очистить состояние", payload="clear_state"), + ], + ] + + +@router.message_created(CommandStart()) +async def start_handler( + message: MessageCreated, + facade: MessageCreatedFacade, + fsm_context: FSMContext, + current_user: DbUser, +) -> None: + current_state = await fsm_context.get_state() + + await facade.send_message( + ( + f"Ваш общий ID: {current_user.shared_id}\n\n" + f"Отправьте эту команду боту TG: /link {current_user.shared_id}\n\n" + "Или отправьте эту команду этому боту из другого аккаунта, " + "чтобы связать их: /link " + ), + ) + await facade.send_message( + text=f"Ваше текущее состояние: {current_state}", + keyboard=get_keyboard(), + ) + + +@router.message_created(Command("state")) +async def get_state_handler( + message: MessageCreated, + fsm_context: FSMContext, + facade: MessageCreatedFacade, +) -> None: + current_state = await fsm_context.get_state() + await facade.send_message(text=f"Ваше текущее состояние: {current_state}") + + +@router.message_created(Command("link")) +async def handle_deeplink( + message: MessageCreated, + command: CommandObject, + facade: MessageCreatedFacade, + user_repo: UserRepo, + current_user: DbUser, +) -> None: + try: + shared_id_to_link = SharedId(int(command.args)) + await user_repo.link_accounts( + current_user=current_user, + shared_id_to_link=shared_id_to_link, + ) + await facade.send_message(text="Аккаунты успешно связаны!") + except (IndexError, ValueError): + await facade.send_message(text="Использование: /link ") + + +@router.message_callback(MagicFilter(F.payload == "to_state_1")) +async def to_state_1( + callback: MessageCallback, + fsm_context: FSMContext, + facade: MessageCallbackFacade, +) -> None: + await fsm_context.set_state(MyStates.state1) + current_state = await fsm_context.get_state() + await facade.edit_message( + text=f"Ваше текущее состояние: {current_state}", + keyboard=get_keyboard(), + ) + + +@router.message_callback(MagicFilter(F.payload == "to_state_2")) +async def to_state_2( + callback: MessageCallback, + fsm_context: FSMContext, + facade: MessageCallbackFacade, +) -> None: + await fsm_context.set_state(MyStates.state2) + current_state = await fsm_context.get_state() + await facade.edit_message( + text=f"Ваше текущее состояние: {current_state}", + keyboard=get_keyboard(), + ) + + +@router.message_callback(MagicFilter(F.payload == "clear_state")) +async def clear_state( + callback: MessageCallback, + fsm_context: FSMContext, + facade: MessageCallbackFacade, +) -> None: + await fsm_context.clear() + current_state = await fsm_context.get_state() + await facade.edit_message( + text=f"Состояние очищено. Ваше текущее состояние: {current_state}", + keyboard=get_keyboard(), + ) + + +async def main() -> None: + token = os.environ["MAX_TOKEN"] + redis_url = os.environ["REDIS_URL"] + + user_repo = UserRepo("../db.sqlite") + await user_repo.create_table() + + key_builder = DefaultKeyBuilder(prefix="fsm", separator=":", with_bot_id=False) + storage = RedisStorage.from_url(url=redis_url, key_builder=key_builder) + event_isolation = storage.create_isolation() + dp = Dispatcher( + key_builder=None, # because use custom FSM + storage=None, # because use custom FSM + events_isolation=None, # because use custom FSM + disable_fsm=True, # because use custom FSM + workflow_data={"user_repo": user_repo}, + ) + dp.update.middleware.outer(CurrentUserMiddleware()) + dp.update.middleware.outer( + SharedFSMContextMiddleware(storage=storage, events_isolation=event_isolation), + ) + dp.include(router) + + bot = Bot(token=token) + polling = LongPolling(dp) + await polling.start(bot) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + asyncio.run(main()) diff --git a/examples/tg_max_one_fsm/max/current_user.py b/examples/tg_max_one_fsm/max/current_user.py new file mode 100644 index 00000000..3b01c742 --- /dev/null +++ b/examples/tg_max_one_fsm/max/current_user.py @@ -0,0 +1,30 @@ +from typing import Any + +from maxo import Ctx +from maxo.routing.interfaces import BaseMiddleware, NextMiddleware +from maxo.routing.middlewares.update_context import UPDATE_CONTEXT_KEY +from maxo.routing.signals import MaxoUpdate +from maxo.types import UpdateContext + +from ..ids import MaxId +from ..user_repo import ExternalType, UserRepo + + +class CurrentUserMiddleware(BaseMiddleware[MaxoUpdate[Any]]): + async def __call__( + self, + update: MaxoUpdate, + ctx: Ctx, + next: NextMiddleware[MaxoUpdate[Any]], + ) -> Any: + user_repo: UserRepo = ctx["user_repo"] + update_context: UpdateContext = ctx[UPDATE_CONTEXT_KEY] + + if update_context.user_id: + user = await user_repo.get_or_create_user( + external_id=MaxId(update_context.user_id), + external_type=ExternalType.MAX, + ) + ctx["current_user"] = user + + return await next(ctx) diff --git a/examples/tg_max_one_fsm/max/fsm_context.py b/examples/tg_max_one_fsm/max/fsm_context.py new file mode 100644 index 00000000..bd82036e --- /dev/null +++ b/examples/tg_max_one_fsm/max/fsm_context.py @@ -0,0 +1,55 @@ +from typing import Any + +from maxo.fsm.context import FSMContext +from maxo.fsm.key_builder import StorageKey +from maxo.fsm.storages.base import BaseEventIsolation, BaseStorage +from maxo.routing.ctx import Ctx +from maxo.routing.interfaces.middleware import BaseMiddleware, NextMiddleware +from maxo.routing.signals.update import MaxoUpdate + +from ..user_repo import DbUser + +FSM_STORAGE_KEY = "fsm_storage" +FSM_CONTEXT_KEY = "fsm_context" +FSM_CONTEXT_STATE_KEY = "state" # same as "fsm_context", Подражаение aiogram +RAW_STATE_KEY = "raw_state" + + +class SharedFSMContextMiddleware(BaseMiddleware[MaxoUpdate[Any]]): + __slots__ = ("_events_isolation", "_storage") + + def __init__( + self, + storage: BaseStorage, + events_isolation: BaseEventIsolation, + ) -> None: + self._storage = storage + self._events_isolation = events_isolation + + async def __call__( + self, + update: MaxoUpdate[Any], + ctx: Ctx, + next: NextMiddleware[MaxoUpdate[Any]], + ) -> Any: + ctx[FSM_STORAGE_KEY] = self._storage + + current_user = ctx.get("current_user") + if current_user is None: + return await next(ctx) + + storage_key = self.make_storage_key(user=current_user) + + async with self._events_isolation.lock(key=storage_key): + fsm_context = FSMContext(key=storage_key, storage=self._storage) + ctx[FSM_CONTEXT_KEY] = fsm_context + ctx[FSM_CONTEXT_STATE_KEY] = fsm_context + ctx[RAW_STATE_KEY] = await fsm_context.get_state() + + return await next(ctx) + + def make_storage_key( + self, + user: DbUser, + ) -> StorageKey: + return StorageKey(bot_id=None, chat_id=user.shared_id, user_id=user.shared_id) diff --git a/examples/tg_max_one_fsm/tg/__init__.py b/examples/tg_max_one_fsm/tg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/tg_max_one_fsm/tg/__main__.py b/examples/tg_max_one_fsm/tg/__main__.py new file mode 100644 index 00000000..9513699c --- /dev/null +++ b/examples/tg_max_one_fsm/tg/__main__.py @@ -0,0 +1,145 @@ +import asyncio +import logging +import os + +from aiogram import Bot, Dispatcher, F, Router +from aiogram.filters import Command, CommandObject, CommandStart +from aiogram.fsm.context import FSMContext +from aiogram.fsm.state import State, StatesGroup +from aiogram.fsm.storage.base import DefaultKeyBuilder +from aiogram.fsm.storage.redis import RedisStorage +from aiogram.types import CallbackQuery, InlineKeyboardMarkup, Message +from aiogram.utils.keyboard import InlineKeyboardBuilder + +from ..ids import SharedId +from ..user_repo import DbUser, UserRepo +from .current_user import CurrentUserMiddleware +from .fsm_context import SharedFSMContextMiddleware + +router = Router() + + +class MyStates(StatesGroup): + state1 = State() + state2 = State() + + +def get_keyboard() -> InlineKeyboardMarkup: + builder = InlineKeyboardBuilder() + builder.button(text="Перейти в состояние 1", callback_data="to_state_1") + builder.button(text="Перейти в состояние 2", callback_data="to_state_2") + builder.button(text="Очистить состояние", callback_data="clear_state") + builder.adjust(2) + return builder.as_markup() + + +@router.message(CommandStart()) +async def start_handler( + message: Message, + current_user: DbUser, + state: FSMContext, +) -> None: + current_state = await state.get_state() + + await message.answer( + ( + f"Ваш общий ID: {current_user.shared_id}\n\n" + f"Отправьте эту команду боту Max: /link {current_user.shared_id}\n\n" + "Или отправьте эту команду этому боту из другого аккаунта, " + "чтобы связать их: /link " + ), + ) + await message.answer( + f"Ваше текущее состояние: {current_state}", + reply_markup=get_keyboard(), + ) + + +@router.message(Command("state")) +async def get_state_handler(message: Message, state: FSMContext) -> None: + current_state = await state.get_state() + await message.answer(f"Ваше текущее состояние: {current_state}") + + +@router.message(Command("link")) +async def handle_deeplink( + message: Message, + command: CommandObject, + user_repo: UserRepo, + current_user: DbUser, +) -> None: + try: + shared_id_to_link = SharedId(int(command.args)) + await user_repo.link_accounts( + current_user=current_user, + shared_id_to_link=shared_id_to_link, + ) + await message.answer("Аккаунты успешно связаны!") + except (IndexError, ValueError): + await message.answer("Использование: /link ") + + +@router.callback_query(F.data == "to_state_1") +async def to_state_1(callback: CallbackQuery, state: FSMContext) -> None: + await state.set_state(MyStates.state1) + current_state = await state.get_state() + await callback.message.edit_text( + f"Ваше текущее состояние: {current_state}", + reply_markup=get_keyboard(), + ) + await callback.answer() + + +@router.callback_query(F.data == "to_state_2") +async def to_state_2(callback: CallbackQuery, state: FSMContext) -> None: + await state.set_state(MyStates.state2) + current_state = await state.get_state() + await callback.message.edit_text( + f"Ваше текущее состояние: {current_state}", + reply_markup=get_keyboard(), + ) + await callback.answer() + + +@router.callback_query(F.data == "clear_state") +async def clear_state(callback: CallbackQuery, state: FSMContext) -> None: + await state.clear() + current_state = await state.get_state() + await callback.message.edit_text( + f"Состояние очищено. Ваше текущее состояние: {current_state}", + reply_markup=get_keyboard(), + ) + await callback.answer() + + +async def main() -> None: + token = os.environ["TG_TOKEN"] + redis_url = os.environ["REDIS_URL"] + + user_repo = UserRepo("../db.sqlite") + await user_repo.create_table() + + key_builder = DefaultKeyBuilder(prefix="fsm", separator=":", with_bot_id=False) + storage = RedisStorage.from_url(url=redis_url, key_builder=key_builder) + event_isolation = storage.create_isolation() + dp = Dispatcher( + key_builder=None, # because use custom FSM + storage=None, # because use custom FSM + events_isolation=None, # because use custom FSM + disable_fsm=True, # because use custom FSM + user_repo=user_repo, + ) + dp.update.outer_middleware(CurrentUserMiddleware()) + dp.update.outer_middleware( + SharedFSMContextMiddleware(storage=storage, events_isolation=event_isolation), + ) + dp.include_router(router) + + bot = Bot(token=token) + + await dp.start_polling(bot) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + asyncio.run(main()) diff --git a/examples/tg_max_one_fsm/tg/current_user.py b/examples/tg_max_one_fsm/tg/current_user.py new file mode 100644 index 00000000..415d1861 --- /dev/null +++ b/examples/tg_max_one_fsm/tg/current_user.py @@ -0,0 +1,29 @@ +from collections.abc import Awaitable, Callable +from typing import Any + +from aiogram import BaseMiddleware +from aiogram.dispatcher.middlewares.user_context import EVENT_CONTEXT_KEY, EventContext +from aiogram.types import TelegramObject + +from ..ids import TgId +from ..user_repo import ExternalType, UserRepo + + +class CurrentUserMiddleware(BaseMiddleware): + async def __call__( + self, + handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]], + event: TelegramObject, + data: dict[str, Any], + ) -> Any: + user_repo: UserRepo = data["user_repo"] + event_context: EventContext = data[EVENT_CONTEXT_KEY] + + if event_context.user_id: + user = await user_repo.get_or_create_user( + external_id=TgId(event_context.user_id), + external_type=ExternalType.TG, + ) + data["current_user"] = user + + return await handler(event, data) diff --git a/examples/tg_max_one_fsm/tg/fsm_context.py b/examples/tg_max_one_fsm/tg/fsm_context.py new file mode 100644 index 00000000..7364a879 --- /dev/null +++ b/examples/tg_max_one_fsm/tg/fsm_context.py @@ -0,0 +1,61 @@ +from collections.abc import Awaitable, Callable +from typing import Any + +from aiogram.dispatcher.middlewares.base import BaseMiddleware +from aiogram.fsm.context import FSMContext +from aiogram.fsm.storage.base import ( + DEFAULT_DESTINY, + BaseEventIsolation, + BaseStorage, + StorageKey, +) +from aiogram.types import TelegramObject + +from ..user_repo import DbUser + + +class SharedFSMContextMiddleware(BaseMiddleware): + def __init__( + self, + storage: BaseStorage, + events_isolation: BaseEventIsolation, + ) -> None: + self.storage = storage + self.events_isolation = events_isolation + + async def __call__( + self, + handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]], + event: TelegramObject, + data: dict[str, Any], + ) -> Any: + data["fsm_storage"] = self.storage + + current_user: DbUser = data.get("current_user") + if current_user is None: + return await handler(event, data) + + context = self.get_context(current_user) + async with self.events_isolation.lock(key=context.key): + data.update({"state": context, "raw_state": await context.get_state()}) + return await handler(event, data) + + def get_context( + self, + user: DbUser, + ) -> FSMContext: + return FSMContext( + storage=self.storage, + key=StorageKey( + user_id=user.shared_id, + chat_id=user.shared_id, + bot_id=None, + thread_id=None, + business_connection_id=None, + destiny=DEFAULT_DESTINY, + ), + ) + + async def close(self) -> None: + await self.storage.close() + await self.events_isolation.close() diff --git a/examples/tg_max_one_fsm/user_repo.py b/examples/tg_max_one_fsm/user_repo.py new file mode 100644 index 00000000..6806a4fd --- /dev/null +++ b/examples/tg_max_one_fsm/user_repo.py @@ -0,0 +1,132 @@ +import uuid +from dataclasses import dataclass +from enum import Enum + +import aiosqlite + +from .ids import DbId, MaxId, SharedId, TgId + + +class ExternalType(Enum): + TG = "TG" + MAX = "MAX" + + +@dataclass +class DbUser: + id: DbId + external_id: TgId | MaxId + external_type: ExternalType + shared_id: SharedId + + +class UserRepo: + def __init__(self, db_path: str) -> None: + self.db_path = db_path + + async def create_table(self) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + external_id BIGINT, + external_type TEXT, + shared_id TEXT, + UNIQUE(external_id, external_type) + ) + """, + ) + await db.commit() + + async def get_or_create_user( + self, + external_id: TgId | MaxId, + external_type: ExternalType, + ) -> DbUser: + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + ( + "SELECT id, shared_id FROM users " + "WHERE external_id = ? AND external_type = ?" + ), + (external_id, external_type.value), + ) + row = await cursor.fetchone() + if row: + db_id, shared_id = row + return DbUser( + id=DbId(db_id), + external_id=external_id, + external_type=external_type, + shared_id=SharedId(int(shared_id)), + ) + + shared_id = str(uuid.uuid4().int) + cursor = await db.execute( + ( + "INSERT OR IGNORE INTO users " + "(external_id, external_type, shared_id) " + "VALUES (?, ?, ?)" + ), + (external_id, external_type.value, shared_id), + ) + await db.commit() + + if cursor.lastrowid == 0 or cursor.rowcount == 0: + # Concurrent insert happened, re-fetch + cursor = await db.execute( + ( + "SELECT id, shared_id FROM users " + "WHERE external_id = ? AND external_type = ?" + ), + (external_id, external_type.value), + ) + row = await cursor.fetchone() + db_id, shared_id = row + return DbUser( + id=DbId(db_id), + external_id=external_id, + external_type=external_type, + shared_id=SharedId(int(shared_id)), + ) + + db_id = cursor.lastrowid + return DbUser( + id=DbId(db_id), + external_id=external_id, + external_type=external_type, + shared_id=SharedId(int(shared_id)), + ) + + async def link_accounts( + self, + current_user: DbUser, + shared_id_to_link: SharedId, + ) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "UPDATE users SET shared_id = ? WHERE shared_id = ?", + (str(shared_id_to_link), str(current_user.shared_id)), + ) + await db.commit() + + async def get_user_by_shared_id(self, shared_id: SharedId) -> list[DbUser]: + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + ( + "SELECT id, external_id, external_type, shared_id FROM users " + "WHERE shared_id = ?" + ), + (str(shared_id),), + ) + rows = await cursor.fetchall() + return [ + DbUser( + id=DbId(row[0]), + external_id=row[1], + external_type=ExternalType(row[2]), + shared_id=SharedId(int(row[3])), + ) + for row in rows + ] diff --git a/pyproject.toml b/pyproject.toml index 95e4041f..3db9a073 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ magic_filter = ["magic_filter>=1.0.0,<2.0.0"] dishka = ["dishka>=1.0.0,<2.0.0"] redis = ["redis[hiredis]>=5.0.1,<8.0.0"] +fastapi = ["fastapi>=0.128.0"] [dependency-groups] docs = [ @@ -69,6 +70,8 @@ tests = [ "pytest-repeat==0.9.4", "nox==2026.2.9", "nox-uv==0.7.1", + "pytest-aiohttp==1.0.5", + "httpx==0.27.0", ] lint = [ "mypy==1.19.0", @@ -82,7 +85,7 @@ dev = [ { include-group = "lint" }, { include-group = "tests" }, { include-group = "docs" }, - "maxo[magic_filter,dishka,redis]" + "maxo[magic_filter,dishka,redis,fastapi]" ] [project.urls] @@ -176,6 +179,8 @@ ignore = [ "PLR1704", # https://docs.astral.sh/ruff/rules/redefined-argument-from-local/ "PLW2901", # https://docs.astral.sh/ruff/rules/redefined-loop-name/ "ERA001", # https://docs.astral.sh/ruff/rules/commented-out-code/ # Удалить после починки всего + "UP040", # https://docs.astral.sh/ruff/rules/non-pep695-type-alias/ # Адаптикс с ним не работает + "TC003", # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/ ] [tool.ruff.lint.per-file-ignores] @@ -189,12 +194,13 @@ ignore = [ "FBT003", "SLF001", ] -"examples/**/*.py" = ["T201", "D"] +"examples/**/*.py" = ["T201", "D", "TID252"] "src/maxo/types/*.py" = ["E501", "D", "W291", "W293"] "src/maxo/enums/*.py" = ["E501", "D", "W291", "W293"] "src/maxo/bot/methods/*.py" = ["E501", "D", "W291", "W293"] "src/maxo/routing/updates/*.py" = ["E501", "D"] "src/maxo/dialogs/test_tools/**/*.py" = ["S101"] +"src/maxo/serialization.py" = ["PLW0603"] [tool.ruff.lint.isort] case-sensitive = true @@ -211,6 +217,7 @@ line-length = 88 target-version = ["py312"] include = 'src/.*\.py$|tests/.*\.py$|examples/.*\.py$' -[tool.pytest] +[tool.pytest.ini_options] log_cli = true log_cli_level = "DEBUG" +asyncio_mode = "strict" diff --git a/src/maxo/backoff.py b/src/maxo/backoff.py index a0ce0e4e..479beb9b 100644 --- a/src/maxo/backoff.py +++ b/src/maxo/backoff.py @@ -1,3 +1,4 @@ +import asyncio from dataclasses import dataclass from random import normalvariate @@ -73,3 +74,6 @@ def reset(self) -> None: self._counter = 0 self._current_delay = 0.0 self._next_delay = self.min_delay + + async def sleep(self) -> None: + await asyncio.sleep(self.current_delay) diff --git a/src/maxo/bot/api_client.py b/src/maxo/bot/api_client.py index 2ccfe080..4c2cdd84 100644 --- a/src/maxo/bot/api_client.py +++ b/src/maxo/bot/api_client.py @@ -1,30 +1,19 @@ +import io import json -from collections.abc import Callable -from datetime import UTC, datetime -from typing import Any, Never +import pathlib +from collections.abc import AsyncGenerator, Callable +from typing import Any, BinaryIO, Never -from adaptix import Chain, P, Retort, dumper, loader from aiohttp import ClientSession +from anyio import open_file from unihttp.clients.aiohttp import AiohttpAsyncClient from unihttp.http import HTTPResponse -from unihttp.markers import QueryMarker from unihttp.method import BaseMethod from unihttp.middlewares import AsyncMiddleware -from unihttp.serializers.adaptix import DEFAULT_RETORT, for_marker +from unihttp.serialize import RequestDumper, ResponseLoader from maxo import loggers from maxo.__meta__ import __version__ -from maxo._internal._adaptix.concat_provider import concat_provider -from maxo._internal._adaptix.has_tag_provider import has_tag_provider -from maxo.bot.warming_up import WarmingUpType, warming_up_retort -from maxo.enums import ( - AttachmentRequestType, - AttachmentType, - ButtonType, - MarkupElementType, - UpdateType, -) -from maxo.enums.text_format import TextFormat from maxo.errors import ( MaxBotApiError, MaxBotBadRequestError, @@ -36,144 +25,15 @@ MaxBotUnauthorizedError, MaxBotUnknownServerError, ) -from maxo.omit import Omittable -from maxo.routing.updates import ( - BotAddedToChat, - BotRemovedFromChat, - BotStarted, - BotStopped, - ChatTitleChanged, - DialogCleared, - DialogMuted, - DialogRemoved, - DialogUnmuted, - MessageCallback, - MessageCreated, - MessageEdited, - MessageRemoved, - UserAddedToChat, - UserRemovedFromChat, -) -from maxo.types import ( - Attachments, - AudioAttachment, - AudioAttachmentRequest, - CallbackButton, - ContactAttachment, - ContactAttachmentRequest, - EmphasizedMarkup, - FileAttachment, - FileAttachmentRequest, - InlineKeyboardAttachment, - InlineKeyboardAttachmentRequest, - LinkButton, - LinkMarkup, - LocationAttachment, - LocationAttachmentRequest, - MessageButton, - MonospacedMarkup, - OpenAppButton, - PhotoAttachment, - PhotoAttachmentRequest, - RequestContactButton, - RequestGeoLocationButton, - ShareAttachment, - ShareAttachmentRequest, - StickerAttachment, - StickerAttachmentRequest, - StrikethroughMarkup, - StrongMarkup, - UnderlineMarkup, - UserMentionMarkup, - VideoAttachment, - VideoAttachmentRequest, -) - -_has_tag_providers = concat_provider( - # ---> UpdateType <--- - has_tag_provider(BotAddedToChat, "update_type", UpdateType.BOT_ADDED), - has_tag_provider(BotRemovedFromChat, "update_type", UpdateType.BOT_REMOVED), - has_tag_provider(BotStarted, "update_type", UpdateType.BOT_STARTED), - has_tag_provider(BotStopped, "update_type", UpdateType.BOT_STOPPED), - has_tag_provider(ChatTitleChanged, "update_type", UpdateType.CHAT_TITLE_CHANGED), - has_tag_provider(DialogCleared, "update_type", UpdateType.DIALOG_CLEARED), - has_tag_provider(DialogMuted, "update_type", UpdateType.DIALOG_MUTED), - has_tag_provider(DialogRemoved, "update_type", UpdateType.DIALOG_REMOVED), - has_tag_provider(DialogUnmuted, "update_type", UpdateType.DIALOG_UNMUTED), - has_tag_provider(MessageCallback, "update_type", UpdateType.MESSAGE_CALLBACK), - has_tag_provider(MessageCreated, "update_type", UpdateType.MESSAGE_CREATED), - has_tag_provider(MessageEdited, "update_type", UpdateType.MESSAGE_EDITED), - has_tag_provider(MessageRemoved, "update_type", UpdateType.MESSAGE_REMOVED), - has_tag_provider(UserAddedToChat, "update_type", UpdateType.USER_ADDED), - has_tag_provider(UserRemovedFromChat, "update_type", UpdateType.USER_REMOVED), - # ---> AttachmentType <--- - has_tag_provider(AudioAttachment, "type", AttachmentType.AUDIO), - has_tag_provider(ContactAttachment, "type", AttachmentType.CONTACT), - has_tag_provider(FileAttachment, "type", AttachmentType.FILE), - has_tag_provider(PhotoAttachment, "type", AttachmentType.IMAGE), - has_tag_provider(InlineKeyboardAttachment, "type", AttachmentType.INLINE_KEYBOARD), - has_tag_provider(LocationAttachment, "type", AttachmentType.LOCATION), - has_tag_provider(ShareAttachment, "type", AttachmentType.SHARE), - has_tag_provider(StickerAttachment, "type", AttachmentType.STICKER), - has_tag_provider(VideoAttachment, "type", AttachmentType.VIDEO), - # ---> MarkupElementType <--- - has_tag_provider(EmphasizedMarkup, "type", MarkupElementType.EMPHASIZED), - has_tag_provider(LinkMarkup, "type", MarkupElementType.LINK), - has_tag_provider(MonospacedMarkup, "type", MarkupElementType.MONOSPACED), - has_tag_provider( - StrikethroughMarkup, - "type", - MarkupElementType.STRIKETHROUGH, - ), - has_tag_provider(StrongMarkup, "type", MarkupElementType.STRONG), - has_tag_provider(UnderlineMarkup, "type", MarkupElementType.UNDERLINE), - has_tag_provider(UserMentionMarkup, "type", MarkupElementType.USER_MENTION), - # ---> AttachmentRequestType <--- - has_tag_provider(PhotoAttachmentRequest, "type", AttachmentRequestType.IMAGE), - has_tag_provider(VideoAttachmentRequest, "type", AttachmentRequestType.VIDEO), - has_tag_provider(AudioAttachmentRequest, "type", AttachmentRequestType.AUDIO), - has_tag_provider(FileAttachmentRequest, "type", AttachmentRequestType.FILE), - has_tag_provider(StickerAttachmentRequest, "type", AttachmentRequestType.STICKER), - has_tag_provider(ContactAttachmentRequest, "type", AttachmentRequestType.CONTACT), - has_tag_provider( - InlineKeyboardAttachmentRequest, - "type", - AttachmentRequestType.INLINE_KEYBOARD, - ), - has_tag_provider(LocationAttachmentRequest, "type", AttachmentRequestType.LOCATION), - has_tag_provider(ShareAttachmentRequest, "type", AttachmentRequestType.SHARE), - # ---> KeyboardButtonType <--- - has_tag_provider(CallbackButton, "type", ButtonType.CALLBACK), - has_tag_provider(LinkButton, "type", ButtonType.LINK), - has_tag_provider( - RequestContactButton, - "type", - ButtonType.REQUEST_CONTACT, - ), - has_tag_provider( - RequestGeoLocationButton, - "type", - ButtonType.REQUEST_GEO_LOCATION, - ), - has_tag_provider( - OpenAppButton, - "type", - ButtonType.OPEN_APP, - ), - has_tag_provider( - MessageButton, - "type", - ButtonType.MESSAGE, - ), -) +from maxo.types import AttachmentPayload class MaxApiClient(AiohttpAsyncClient): def __init__( self, token: str, - warming_up: bool, - text_format: TextFormat | None = None, + request_dumper: RequestDumper, + response_loader: ResponseLoader, base_url: str = "https://platform-api.max.ru/", middleware: list[AsyncMiddleware] | None = None, session: ClientSession | None = None, @@ -181,8 +41,6 @@ def __init__( json_loads: Callable[[str | bytes | bytearray], Any] = json.loads, ) -> None: self._token = token - self._warming_up = warming_up - self._text_format = text_format if session is None: session = ClientSession() @@ -192,12 +50,6 @@ def __init__( if "User-Agent" not in session.headers: session.headers["User-Agent"] = f"maxo/{__version__}" - if middleware is None: - middleware = [] - - request_dumper = self._init_method_dumper() - response_loader = self._init_response_loader() - super().__init__( base_url=base_url, request_dumper=request_dumper, @@ -208,55 +60,6 @@ def __init__( json_loads=json_loads, ) - def _init_method_dumper(self) -> Retort: - retort = DEFAULT_RETORT.extend( - recipe=[ - _has_tag_providers, - dumper( - for_marker(QueryMarker, P[None]), - lambda _: "null", - ), - dumper( - for_marker(QueryMarker, P[bool]), - lambda item: int(item), - ), - dumper( - for_marker(QueryMarker, P[list[str]] | P[list[int]]), - lambda seq: ",".join(str(el) for el in seq), - ), - dumper( - P[TextFormat] - | P[TextFormat | None] - | P[Omittable[TextFormat]] - | P[Omittable[TextFormat | None]], - lambda item: item or self._text_format, - ), - dumper( - P[Attachments], - lambda attachment: attachment.to_request(), - chain=Chain.FIRST, - ), - ], - ) - - if self._warming_up: - retort = warming_up_retort(retort, warming_up=WarmingUpType.METHOD) - - return retort - - def _init_response_loader(self) -> Retort: - retort = DEFAULT_RETORT.extend( - recipe=[ - _has_tag_providers, - loader(P[datetime], lambda x: datetime.fromtimestamp(x / 1000, tz=UTC)), - ], - ) - - if self._warming_up: - retort = warming_up_retort(retort, warming_up=WarmingUpType.TYPES) - - return retort - def handle_error(self, response: HTTPResponse, method: BaseMethod[Any]) -> Never: # ruff: noqa: PLR2004 code: str = response.data.get("code") or response.data.get("error_code", "") @@ -295,3 +98,90 @@ def validate_response(self, response: HTTPResponse, method: BaseMethod) -> None: response.status_code, ) response.status_code = 400 + + async def download( + self, + url: str | AttachmentPayload, + destination: BinaryIO | pathlib.Path | str | None = None, + timeout: int = 30, + chunk_size: int = 65536, + seek: bool = True, + ) -> BinaryIO | None: + if isinstance(url, AttachmentPayload): + url = url.url + + return await self._download_file( + url, + destination=destination, + timeout=timeout, + chunk_size=chunk_size, + seek=seek, + ) + + async def _download_file( + self, + url: str, + destination: BinaryIO | pathlib.Path | str | None, + timeout: int, + chunk_size: int, + seek: bool, + ) -> BinaryIO | None: + if destination is None: + destination = io.BytesIO() + + stream = self._stream_content( + url=url, + timeout=timeout, + chunk_size=chunk_size, + raise_for_status=True, + ) + + if isinstance(destination, (str, pathlib.Path)): + await self.__download_file(destination=destination, stream=stream) + return None + return await self.__download_file_binary_io( + destination=destination, + seek=seek, + stream=stream, + ) + + async def _stream_content( + self, + url: str, + headers: dict[str, Any] | None = None, + timeout: int = 30, + chunk_size: int = 65536, + raise_for_status: bool = True, + ) -> AsyncGenerator[bytes, None]: + async with self._session.get( + url, + timeout=timeout, + headers=headers, + raise_for_status=raise_for_status, + ) as resp: + async for chunk in resp.content.iter_chunked(chunk_size): + yield chunk + + @classmethod + async def __download_file( + cls, + destination: str | pathlib.Path, + stream: AsyncGenerator[bytes, None], + ) -> None: + async with await open_file(destination, "wb") as f: + async for chunk in stream: + await f.write(chunk) + + @classmethod + async def __download_file_binary_io( + cls, + destination: BinaryIO, + seek: bool, + stream: AsyncGenerator[bytes, None], + ) -> BinaryIO: + async for chunk in stream: + destination.write(chunk) + destination.flush() + if seek is True: + destination.seek(0) + return destination diff --git a/src/maxo/bot/bot.py b/src/maxo/bot/bot.py index 1de3580d..0ed6e52b 100644 --- a/src/maxo/bot/bot.py +++ b/src/maxo/bot/bot.py @@ -1,10 +1,16 @@ -from collections.abc import AsyncIterator +import json +import pathlib +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager -from typing import Self, TypeVar +from typing import Any, BinaryIO, Self, TypeVar +from adaptix import Retort from unihttp.bind_method import bind_method +from unihttp.middlewares import AsyncMiddleware +from maxo import loggers from maxo.bot.api_client import MaxApiClient +from maxo.bot.defaults import BotDefaults from maxo.bot.methods import ( AddMembers, AnswerOnCallback, @@ -47,24 +53,43 @@ EmptyBotState, RunningBotState, ) -from maxo.enums.text_format import TextFormat -from maxo.types import MaxoType +from maxo.errors import MaxBotApiError +from maxo.serialization import get_retort +from maxo.types import AttachmentPayload, MaxoType _MethodResultT = TypeVar("_MethodResultT", bound=MaxoType) class Bot: - __slots__ = ("_state", "_text_format", "_token", "_warming_up") + __slots__ = ( + "_defaults", + "_json_dumps", + "_json_loads", + "_middleware", + "_retort", + "_state", + "_token", + "_warming_up", + ) def __init__( self, token: str, - text_format: TextFormat | None = None, + *, + defaults: BotDefaults | None = None, warming_up: bool = True, + middleware: list[AsyncMiddleware] | None = None, + json_dumps: Callable[[Any], str] = json.dumps, + json_loads: Callable[[str | bytes | bytearray], Any] = json.loads, ) -> None: + self._defaults = defaults or BotDefaults() self._token = token - self._text_format = text_format self._warming_up = warming_up + self._middleware = middleware + self._json_dumps = json_dumps + self._json_loads = json_loads + + self._retort = get_retort(defaults=self._defaults, warming_up=warming_up) self._state = EmptyBotState() @@ -72,11 +97,31 @@ def __init__( def state(self) -> BotState: return self._state + @property + def retort(self) -> Retort: + return self._retort + + @property + def defaults(self) -> BotDefaults: + return self._defaults + + @property + def token(self) -> str: + """Bot API token. Treat as secret — avoid logging or exposing.""" + return self._token + async def start(self) -> None: if self.state.started: return - api_client = MaxApiClient(self._token, self._warming_up, self._text_format) + api_client = MaxApiClient( + token=self._token, + request_dumper=self._retort, + response_loader=self._retort, + middleware=self._middleware, + json_dumps=self._json_dumps, + json_loads=self._json_loads, + ) self._state = ConnectingBotState(api_client=api_client) info = await self.get_my_info() @@ -97,6 +142,16 @@ async def call_method( ) -> _MethodResultT: return await self.state.api_client.call_method(method) + async def silent_call_method(self, method: MaxoMethod[_MethodResultT]) -> None: + try: + await self.call_method(method) + except MaxBotApiError as e: + # In due to WebHook mechanism doesn't allow getting response for + # requests called in answer to WebHook request. + # Need to skip unsuccessful responses. + # For debugging here is added logging. + loggers.bot.error("Failed to make answer: %s: %s", e.__class__.__name__, e) + async def close(self) -> None: if self.state.closed or not self.state.started: return @@ -104,6 +159,22 @@ async def close(self) -> None: await self.state.api_client.close() self._state = ClosedBotState() + async def download( + self, + url: str | AttachmentPayload, + destination: BinaryIO | pathlib.Path | str | None = None, + timeout: int = 30, + chunk_size: int = 65536, + seek: bool = True, + ) -> BinaryIO | None: + return await self.state.api_client.download( + url=url, + destination=destination, + timeout=timeout, + chunk_size=chunk_size, + seek=seek, + ) + # Bots edit_bot_info = bind_method(EditBotInfo) diff --git a/src/maxo/bot/defaults.py b/src/maxo/bot/defaults.py new file mode 100644 index 00000000..34a57352 --- /dev/null +++ b/src/maxo/bot/defaults.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +from maxo.enums import TextFormat + + +@dataclass +class BotDefaults: + """Default values for bot API calls.""" + + text_format: TextFormat | None = None + """Default text format for messages""" + disable_link_preview: bool | None = None + """Default value for disable_link_preview parameter""" diff --git a/src/maxo/bot/methods/base.py b/src/maxo/bot/methods/base.py index 7ea579a0..3c382223 100644 --- a/src/maxo/bot/methods/base.py +++ b/src/maxo/bot/methods/base.py @@ -1,6 +1,6 @@ from unihttp.method import BaseMethod -from maxo.types import MaxoType +from maxo.types.base import MaxoType class MaxoMethod[MethodResultT](BaseMethod[MethodResultT], MaxoType): diff --git a/src/maxo/bot/methods/subscriptions/subscribe.py b/src/maxo/bot/methods/subscriptions/subscribe.py index e1bec93a..40339750 100644 --- a/src/maxo/bot/methods/subscriptions/subscribe.py +++ b/src/maxo/bot/methods/subscriptions/subscribe.py @@ -8,8 +8,51 @@ class Subscribe(MaxoMethod[SimpleQueryResult]): """ Подписка на обновления - Подписывает бота на получение обновлений через WebHook. После вызова этого метода бот будет получать уведомления о новых событиях в чатах на указанный URL. - Ваш сервер **должен** прослушивать один из следующих портов: `80`, `8080`, `443`, `8443`, `16384`-`32383` + Метод настраивает доставку событий бота через Webhook — основной механизм получения событий в продуктовых интеграциях. При активной подписке Long Polling не работает + + ## Модель доставки событий + + После вызова метода события отправляются на указанный Webhook-endpoint в виде HTTPS POST-запросов с объектом [`Update`](https://dev.max.ru/docs-api/objects/Update) + + Как обрабатывается событие: + + 1. При наступлении события выполняется вызов Webhook-endpoint + 2. Выполняется TLS-валидация целевого endpoint для безопасной передачи данных + 3. На endpoint отправляется HTTP-запрос + 4. Если при создании подписки указан `secret`, проверяется заголовок `X-Max-Bot-Api-Secret` + 5. При успешной валидации возвращается HTTP 200 OK + 6. Выполняется бизнес-логика обработки события + 7. Инициируются вызовы MAX API + + + ## Требования к Webhook-endpoint + + ### URL и порт + + Webhook-endpoint должен быть доступен по HTTPS на порту 443. Ваш сервер должен прослушивать этот порт. Порт в URL не указывается: + + ``` + https://your-domain.com/webhook + ``` + + > Поддерживается только порт 443. Если endpoint недоступен, события не доставляются + + ### Безопасность соединения (TLS) + + Перед отправкой событий устанавливается HTTPS-соединение и проверяется TLS-сертификат Webhook-endpoint. Это необходимо для безопасной передачи информации + + Требования к сертификату: + + - сертификат выдан доверенным центром сертификации + - самоподписанные сертификаты не поддерживаются + - доменное имя в URL совпадает с CN или SAN сертификата + - сервер предоставляет полную цепочку сертификатов + + > Если TLS-проверка не проходит, события не доставляются + + ### Обработка запросов + + Webhook-endpoint должен возвращать **HTTP 200** в течение 30 секунд. Любой другой код ответа или превышение тайм-аута — ошибка доставки Пример запроса: ```bash @@ -23,10 +66,37 @@ class Subscribe(MaxoMethod[SimpleQueryResult]): }' ``` + ### Политика повторных попыток + + Если доставка не удалась, выполняется до 10 повторных попыток с экспоненциально растущим интервалом: + + - 1-я попытка: через 60 секунд + - 2-я попытка: через 150 секунд (60 × 2,5) + - 3-я попытка: через 375 секунд (150 × 2,5) + - и так далее + + > Если в течение 8 часов по URL вебхука не получен успешный ответ, бот автоматически отписывается от вебхука + + ## Безопасность Webhook-запросов + + Параметр `secret` позволяет убедиться, что Webhook-запросы приходят от MAX, а не от третьей стороны. Это необязательный параметр, но мы настоятельно рекомендуем указывать его. Проверяйте значение заголовка `X-Max-Bot-Api-Secret` на Webhook-сервере и отклоняйте запросы при несоответствии + + Если `secret` указан при создании подписки, он передаётся в заголовке `X-Max-Bot-Api-Secret` каждого Webhook-запроса + + + + --- + + ## Формат и типы событий + + Webhook-запрос содержит объект [`Update`](https://dev.max.ru/docs-api/objects/Update) + + Полный список типов событий и структура объекта описаны в разделе [Update](https://dev.max.ru/docs-api/objects/Update) + Args: secret: Cекрет, который должен быть отправлен в заголовке `X-Max-Bot-Api-Secret` в каждом запросе Webhook. Разрешены только символы `A-Z`, `a-z`, `0-9`, и дефис. Заголовок рекомендован, чтобы запрос поступал из установленного веб-узла - update_types: Список типов обновлений, которые ваш бот хочет получать. Для полного списка типов см. объект [Update](https://dev.max.ru/docs-api/objects/Update) - url: URL HTTP(S)-эндпойнта вашего бота. Должен начинаться с `http(s)://` + update_types: Список типов обновлений, которые хочет получать ваш бот. Для полного списка типов см. объект [Update](https://dev.max.ru/docs-api/objects/Update) + url: URL HTTPS-endpoint вашего бота. Должен начинаться с `https://` Источник: https://dev.max.ru/docs-api/methods/POST/subscriptions """ @@ -35,8 +105,8 @@ class Subscribe(MaxoMethod[SimpleQueryResult]): __method__ = "post" url: Body[str] - """URL HTTP(S)-эндпойнта вашего бота. Должен начинаться с `http(s)://`""" + """URL HTTPS-endpoint вашего бота. Должен начинаться с `https://`""" secret: Body[Omittable[str]] = Omitted() """Cекрет, который должен быть отправлен в заголовке `X-Max-Bot-Api-Secret` в каждом запросе Webhook. Разрешены только символы `A-Z`, `a-z`, `0-9`, и дефис. Заголовок рекомендован, чтобы запрос поступал из установленного веб-узла""" update_types: Body[Omittable[list[str]]] = Omitted() - """Список типов обновлений, которые ваш бот хочет получать. Для полного списка типов см. объект [Update](https://dev.max.ru/docs-api/objects/Update)""" + """Список типов обновлений, которые хочет получать ваш бот. Для полного списка типов см. объект [Update](https://dev.max.ru/docs-api/objects/Update)""" diff --git a/src/maxo/dialogs/api/entities/stack.py b/src/maxo/dialogs/api/entities/stack.py index ca96bb24..f60c36eb 100644 --- a/src/maxo/dialogs/api/entities/stack.py +++ b/src/maxo/dialogs/api/entities/stack.py @@ -17,7 +17,7 @@ def new_int_id() -> int: - return int(time.time()) % 100_000_000 + random.randint(0, 99) * 100_000_000 + return int(time.time() * 1000) % 100_000_000 + random.randint(0, 99) * 100_000_000 def id_to_str(int_id: int) -> str: diff --git a/src/maxo/dialogs/dialog.py b/src/maxo/dialogs/dialog.py index b8a39839..ad3af389 100644 --- a/src/maxo/dialogs/dialog.py +++ b/src/maxo/dialogs/dialog.py @@ -237,7 +237,7 @@ async def process_result( manager, ) - def include_router(self, router: BaseRouter) -> BaseRouter: + def include(self, *routers: BaseRouter) -> BaseRouter: raise TypeError("Dialog cannot include routers") async def process_close( diff --git a/src/maxo/dialogs/test_tools/bot_client.py b/src/maxo/dialogs/test_tools/bot_client.py index 8065d177..0f17d3ff 100644 --- a/src/maxo/dialogs/test_tools/bot_client.py +++ b/src/maxo/dialogs/test_tools/bot_client.py @@ -30,7 +30,7 @@ class FakeBot(Bot): def __init__(self) -> None: - super().__init__("", None, warming_up=False) + super().__init__("", warming_up=False) info = BotInfo( user_id=1000, first_name="bot", diff --git a/src/maxo/enums/attachment_type.py b/src/maxo/enums/attachment_type.py index 8dbdb6e7..3cccb495 100644 --- a/src/maxo/enums/attachment_type.py +++ b/src/maxo/enums/attachment_type.py @@ -1,4 +1,5 @@ from enum import StrEnum +from typing import TypeAlias class AttachmentType(StrEnum): @@ -15,6 +16,9 @@ class AttachmentType(StrEnum): STICKER = "sticker" VIDEO = "video" + # Подражание aiogram + DOCUMENT = FILE + # Подражание aiogram -ContentType = AttachmentType +ContentType: TypeAlias = AttachmentType diff --git a/src/maxo/filters.py b/src/maxo/filters.py new file mode 100644 index 00000000..4e00e56c --- /dev/null +++ b/src/maxo/filters.py @@ -0,0 +1,46 @@ +# ruff: noqa: E402 + +import warnings + +warnings.warn( + "Алиас `maxo.filters` сделан для удобного портирования ботов с `aiogram` " + "и будет удален в будущих версиях. " + "Пожалуйста, обновите импорты на 'from maxo.routing.filters import ...'", + FutureWarning, + stacklevel=2, +) + +# `MagicFilter` and `MagicData` in maxo.integrations.magic_filter + +from maxo.routing.filters.always import AlwaysFalseFilter, AlwaysTrueFilter +from maxo.routing.filters.base import BaseFilter +from maxo.routing.filters.command import Command, CommandStart +from maxo.routing.filters.deeplink import DeeplinkFilter +from maxo.routing.filters.exception import ExceptionMessageFilter, ExceptionTypeFilter +from maxo.routing.filters.logic import ( + AndFilter, + InvertFilter, + OrFilter, + and_f, + invert_f, + or_f, +) +from maxo.routing.filters.payload import Payload + +__all__ = ( + "AlwaysFalseFilter", + "AlwaysTrueFilter", + "AndFilter", + "BaseFilter", + "Command", + "CommandStart", + "DeeplinkFilter", + "ExceptionMessageFilter", + "ExceptionTypeFilter", + "InvertFilter", + "OrFilter", + "Payload", + "and_f", + "invert_f", + "or_f", +) diff --git a/src/maxo/loggers.py b/src/maxo/loggers.py index 800d7920..f8c09194 100644 --- a/src/maxo/loggers.py +++ b/src/maxo/loggers.py @@ -4,4 +4,5 @@ long_polling = getLogger("maxo.long_polling") update_context = getLogger("maxo.routing.update_context") utils = getLogger("maxo.utils") +bot = getLogger("maxo.bot") bot_session = getLogger("maxo.bot.session") diff --git a/src/maxo/omit.py b/src/maxo/omit.py index 227fc300..f558c691 100644 --- a/src/maxo/omit.py +++ b/src/maxo/omit.py @@ -6,7 +6,7 @@ _OmittedValueT = TypeVar("_OmittedValueT") Omitted = UniOmitted -Omittable: TypeAlias = _OmittedValueT | Omitted # noqa: UP040 +Omittable: TypeAlias = _OmittedValueT | Omitted def is_omitted(value: Any) -> TypeIs[Omitted]: diff --git a/src/maxo/routing/dispatcher.py b/src/maxo/routing/dispatcher.py index 7c9c0e2c..2c3b07a2 100644 --- a/src/maxo/routing/dispatcher.py +++ b/src/maxo/routing/dispatcher.py @@ -10,9 +10,9 @@ from maxo.routing.middlewares.error import ErrorMiddleware from maxo.routing.middlewares.fsm_context import FSMContextMiddleware from maxo.routing.middlewares.update_context import UpdateContextMiddleware -from maxo.routing.observers.signal import SignalObserver +from maxo.routing.observers import UpdateObserver from maxo.routing.routers.simple import Router -from maxo.routing.sentinels import UNHANDLED +from maxo.routing.sentinels import UNHANDLED, SkipHandler from maxo.routing.signals.base import BaseSignal from maxo.routing.signals.update import MaxoUpdate from maxo.routing.updates.base import BaseUpdate @@ -22,7 +22,7 @@ class Dispatcher(Router): - update: SignalObserver[MaxoUpdate[Any]] + update: UpdateObserver[MaxoUpdate[Any]] def __init__( self, @@ -32,6 +32,7 @@ def __init__( storage: BaseStorage | None = None, events_isolation: BaseEventIsolation | None = None, key_builder: BaseKeyBuilder | None = None, + disable_fsm: bool = False, ) -> None: super().__init__(self.__class__.__name__) @@ -39,23 +40,29 @@ def __init__( self.workflow_data["dispatcher"] = self self.workflow_data["router"] = self - self.update = self._observers[MaxoUpdate] = SignalObserver[MaxoUpdate]() + self.update = self._observers[MaxoUpdate] = UpdateObserver[MaxoUpdate]() self.update.middleware.outer(ErrorMiddleware(self)) self.update.middleware.outer(UpdateContextMiddleware()) self.update.handler(self._feed_update_handler) # State system settings - if key_builder is None: - key_builder = DefaultKeyBuilder() + if not disable_fsm: + if key_builder is None: + key_builder = DefaultKeyBuilder() - if storage is None: - storage = MemoryStorage(key_builder=key_builder) + if storage is None: + storage = MemoryStorage(key_builder=key_builder) - if events_isolation is None: - events_isolation = SimpleEventIsolation(key_builder=key_builder) + if events_isolation is None: + events_isolation = SimpleEventIsolation(key_builder=key_builder) - self.update.middleware.outer(FSMContextMiddleware(storage, events_isolation)) + # Note that when FSM middleware is disabled, + # the event isolation is also disabled + # Because the isolation mechanism is a part of the FS + self.update.middleware.outer( + FSMContextMiddleware(storage, events_isolation), + ) # Facade settings self.update.middleware.outer(FacadeMiddleware()) @@ -100,9 +107,15 @@ async def feed_update(self, update: BaseUpdate, bot: Bot | None = None) -> Any: ctx["ctx"] = ctx return await self.trigger(ctx) - async def _feed_update_handler(self, ctx: Ctx) -> Any: - ctx["update"] = ctx["update"].update - return await self.trigger(ctx) + async def _feed_update_handler(self, update: MaxoUpdate[Any], ctx: Ctx) -> Any: + ctx_copy = Ctx(dict(ctx)) + ctx_copy["update"] = update.update + + result = await self.trigger(ctx_copy) + if result is UNHANDLED: + raise SkipHandler + + return result async def _emit_before_startup_handler(self) -> None: validate_router_graph(self) diff --git a/src/maxo/routing/filters/__init__.py b/src/maxo/routing/filters/__init__.py index 8fb8c7e1..b6562ec2 100644 --- a/src/maxo/routing/filters/__init__.py +++ b/src/maxo/routing/filters/__init__.py @@ -2,7 +2,7 @@ from .always import AlwaysFalseFilter, AlwaysTrueFilter from .base import BaseFilter -from .command import Command, CommandStart +from .command import Command, CommandObject, CommandStart from .deeplink import DeeplinkFilter from .exception import ExceptionMessageFilter, ExceptionTypeFilter from .logic import AndFilter, InvertFilter, OrFilter, and_f, invert_f, or_f @@ -14,6 +14,7 @@ "AndFilter", "BaseFilter", "Command", + "CommandObject", "CommandStart", "DeeplinkFilter", "ExceptionMessageFilter", diff --git a/src/maxo/routing/filters/command.py b/src/maxo/routing/filters/command.py index 0069a186..83582337 100644 --- a/src/maxo/routing/filters/command.py +++ b/src/maxo/routing/filters/command.py @@ -1,16 +1,14 @@ import re from collections.abc import Iterable, Sequence -from dataclasses import replace -from re import Pattern -from typing import ( - cast, -) +from dataclasses import field, replace +from re import Match, Pattern +from typing import cast from maxo import Bot, Ctx from maxo.routing.filters import BaseFilter from maxo.routing.updates import MessageCreated +from maxo.types.base import MaxoType from maxo.types.bot_command import BotCommand -from maxo.types.command_object import CommandObject CommandPatternType = str | re.Pattern | BotCommand @@ -19,6 +17,27 @@ class CommandException(Exception): pass +class CommandObject(MaxoType): + prefix: str = "/" + command: str = "" + mention: str | None = None + args: str | None = field(repr=False, default=None) + regexp_match: Match[str] | None = field(repr=False, default=None) + + @property + def mentioned(self) -> bool: + return bool(self.mention) + + @property + def text(self) -> str: + line = self.prefix + self.command + if self.mention: + line += "@" + self.mention + if self.args: + line += " " + self.args + return line + + class Command(BaseFilter[MessageCreated]): __slots__ = ("commands", "ignore_case", "ignore_mention", "magic", "prefix") diff --git a/src/maxo/routing/handlers/update.py b/src/maxo/routing/handlers/update.py index aea47e94..fe33dadf 100644 --- a/src/maxo/routing/handlers/update.py +++ b/src/maxo/routing/handlers/update.py @@ -71,9 +71,7 @@ async def execute_filter(self, ctx: Ctx) -> bool: async def __call__(self, ctx: Ctx) -> _ReturnT_co: update = ctx.pop("update") wrapped = partial(self._handler_fn, update, **self._prepare_kwargs(ctx)) - try: - if self._awaitable: - return await wrapped() - return await asyncio.to_thread(wrapped) - finally: - ctx["update"] = update + ctx["update"] = update + if self._awaitable: + return await wrapped() + return await asyncio.to_thread(wrapped) diff --git a/src/maxo/routing/interfaces/router.py b/src/maxo/routing/interfaces/router.py index 82f5c5cb..86b7ee84 100644 --- a/src/maxo/routing/interfaces/router.py +++ b/src/maxo/routing/interfaces/router.py @@ -44,6 +44,14 @@ def children_routers(self) -> Sequence["BaseRouter"]: def include(self, *routers: "BaseRouter") -> None: raise NotImplementedError + # Подражание aiogram + def include_router(self, router: "BaseRouter") -> None: + return self.include(router) + + # Подражание aiogram + def include_routers(self, *routers: "BaseRouter") -> None: + return self.include(*routers) + @abstractmethod async def trigger_child(self, ctx: Ctx) -> Any: raise NotImplementedError diff --git a/src/maxo/routing/middlewares/fsm_context.py b/src/maxo/routing/middlewares/fsm_context.py index 8eb9b8df..9108eb8a 100644 --- a/src/maxo/routing/middlewares/fsm_context.py +++ b/src/maxo/routing/middlewares/fsm_context.py @@ -9,8 +9,9 @@ from maxo.routing.signals.update import MaxoUpdate from maxo.types.update_context import UpdateContext -FSM_STORAGE_KEY = "fsm_storage" # and "storage" too +FSM_STORAGE_KEY = "fsm_storage" FSM_CONTEXT_KEY = "fsm_context" +FSM_CONTEXT_STATE_KEY = "state" # same as "fsm_context", Подражаение aiogram RAW_STATE_KEY = "raw_state" @@ -41,11 +42,9 @@ async def __call__( return await next(ctx) async with self._events_isolation.lock(key=storage_key): - fsm_context = FSMContext( - key=storage_key, - storage=self._storage, - ) + fsm_context = FSMContext(key=storage_key, storage=self._storage) ctx[FSM_CONTEXT_KEY] = fsm_context + ctx[FSM_CONTEXT_STATE_KEY] = fsm_context ctx[RAW_STATE_KEY] = await fsm_context.get_state() return await next(ctx) diff --git a/src/maxo/routing/observers/base.py b/src/maxo/routing/observers/base.py index 76b7e04f..b7f17e4c 100644 --- a/src/maxo/routing/observers/base.py +++ b/src/maxo/routing/observers/base.py @@ -8,7 +8,7 @@ from maxo.routing.interfaces.observer import ObserverState from maxo.routing.middlewares.manager import MiddlewareManagerFacade from maxo.routing.observers.state import EmptyObserverState -from maxo.routing.sentinels import UNHANDLED +from maxo.routing.sentinels import UNHANDLED, SkipHandler from maxo.routing.updates.base import BaseUpdate _UpdateT = TypeVar("_UpdateT", bound=BaseUpdate) @@ -79,7 +79,10 @@ async def handler_lookup(self, ctx: Ctx) -> Any: for handler in self._handlers: if await handler.execute_filter(ctx): - return await self.execute_handler(ctx, handler) + try: + return await self.execute_handler(ctx, handler) + except SkipHandler: + continue return UNHANDLED diff --git a/src/maxo/routing/observers/signal.py b/src/maxo/routing/observers/signal.py index 0fe9df4c..17354de7 100644 --- a/src/maxo/routing/observers/signal.py +++ b/src/maxo/routing/observers/signal.py @@ -34,12 +34,12 @@ async def handler_lookup(self, ctx: Ctx) -> Any: if not await self.execute_filter(ctx): return UNHANDLED - result = UNHANDLED for handler in self._handlers: if await handler.execute_filter(ctx): - result = await self.execute_handler(ctx, handler) + await self.execute_handler(ctx, handler) - return result + # Возврат UNHANDLED для того, чтобы сигнал прошёлся по дочерним роутерам + return UNHANDLED if TYPE_CHECKING: diff --git a/src/maxo/routing/routers/simple.py b/src/maxo/routing/routers/simple.py index 3b503040..b46d2a49 100644 --- a/src/maxo/routing/routers/simple.py +++ b/src/maxo/routing/routers/simple.py @@ -12,7 +12,7 @@ from maxo.routing.observers import SignalObserver, UpdateObserver from maxo.routing.observers.state import EmptyObserverState, StartedObserverState from maxo.routing.routers.state import EmptyRouterState, StartedRouterState -from maxo.routing.sentinels import UNHANDLED, SkipHandler +from maxo.routing.sentinels import UNHANDLED from maxo.routing.signals.shutdown import AfterShutdown, BeforeShutdown from maxo.routing.signals.startup import AfterStartup, BeforeStartup from maxo.routing.updates import ( @@ -54,6 +54,9 @@ def __init__(self, name: str | None = None) -> None: self.user_added_to_chat = UpdateObserver[UserAddedToChat]() self.user_removed_from_chat = UpdateObserver[UserRemovedFromChat]() + self.message = self.message_created # Подражание aiogram + self.callback_query = self.message_callback # Подражание aiogram + self.exception = self.exceptions = self.error = self.errors = UpdateObserver[ ErrorEvent[Any, Any] ]() @@ -144,11 +147,7 @@ async def trigger(self, ctx: Ctx) -> Any: return await chain_middlewares(ctx) async def _trigger(self, ctx: Ctx, *, observer: Observer) -> Any: - try: - result = await observer.handler_lookup(ctx) - except SkipHandler: - result = UNHANDLED - + result = await observer.handler_lookup(ctx) if result is UNHANDLED: return await self.trigger_child(ctx) return result diff --git a/src/maxo/routing/signals/update.py b/src/maxo/routing/signals/update.py index 1533fb9f..3cc052ab 100644 --- a/src/maxo/routing/signals/update.py +++ b/src/maxo/routing/signals/update.py @@ -1,12 +1,11 @@ from typing import Generic, TypeVar from maxo.omit import Omittable, Omitted -from maxo.routing.signals.base import BaseSignal from maxo.routing.updates.base import BaseUpdate _UpdateT = TypeVar("_UpdateT", bound=BaseUpdate) -class MaxoUpdate(BaseSignal, Generic[_UpdateT]): +class MaxoUpdate(BaseUpdate, Generic[_UpdateT]): update: _UpdateT marker: Omittable[int | None] = Omitted() diff --git a/src/maxo/routing/updates/__init__.py b/src/maxo/routing/updates/__init__.py index 120275f1..50985c28 100644 --- a/src/maxo/routing/updates/__init__.py +++ b/src/maxo/routing/updates/__init__.py @@ -9,7 +9,7 @@ from .dialog_removed import DialogRemoved from .dialog_unmuted import DialogUnmuted from .error import ErrorEvent -from .message_callback import MessageCallback +from .message_callback import CallbackQuery, MessageCallback from .message_created import MessageCreated from .message_edited import MessageEdited from .message_removed import MessageRemoved @@ -23,6 +23,7 @@ "BotRemovedFromChat", "BotStarted", "BotStopped", + "CallbackQuery", "ChatTitleChanged", "DialogCleared", "DialogMuted", diff --git a/src/maxo/routing/updates/message_callback.py b/src/maxo/routing/updates/message_callback.py index 830c9b6a..1d2a4f79 100644 --- a/src/maxo/routing/updates/message_callback.py +++ b/src/maxo/routing/updates/message_callback.py @@ -1,3 +1,5 @@ +from typing import TypeAlias + from maxo.enums.update_type import UpdateType from maxo.errors import AttributeIsEmptyError from maxo.omit import Omittable, Omitted, is_defined @@ -59,3 +61,6 @@ def payload(self) -> str | None: @property def user(self) -> User: return self.callback.user + + +CallbackQuery: TypeAlias = MessageCallback # Подражание aiogram diff --git a/src/maxo/serialization.py b/src/maxo/serialization.py new file mode 100644 index 00000000..65adef77 --- /dev/null +++ b/src/maxo/serialization.py @@ -0,0 +1,190 @@ +from datetime import UTC, datetime + +from adaptix import Chain, P, Retort, dumper, loader +from unihttp.markers import QueryMarker +from unihttp.serializers.adaptix import DEFAULT_RETORT, for_marker + +from maxo._internal._adaptix.concat_provider import concat_provider +from maxo._internal._adaptix.has_tag_provider import has_tag_provider +from maxo.bot.defaults import BotDefaults +from maxo.bot.warming_up import WarmingUpType, warming_up_retort +from maxo.enums import ( + AttachmentRequestType, + AttachmentType, + ButtonType, + MarkupElementType, + TextFormat, + UpdateType, +) +from maxo.omit import Omittable +from maxo.routing.updates import ( + BotAddedToChat, + BotRemovedFromChat, + BotStarted, + BotStopped, + ChatTitleChanged, + DialogCleared, + DialogMuted, + DialogRemoved, + DialogUnmuted, + MessageCallback, + MessageCreated, + MessageEdited, + MessageRemoved, + UserAddedToChat, + UserRemovedFromChat, +) +from maxo.types import ( + Attachments, + AudioAttachment, + AudioAttachmentRequest, + CallbackButton, + ContactAttachment, + ContactAttachmentRequest, + EmphasizedMarkup, + FileAttachment, + FileAttachmentRequest, + InlineKeyboardAttachment, + InlineKeyboardAttachmentRequest, + LinkButton, + LinkMarkup, + LocationAttachment, + LocationAttachmentRequest, + MessageButton, + MonospacedMarkup, + OpenAppButton, + PhotoAttachment, + PhotoAttachmentRequest, + RequestContactButton, + RequestGeoLocationButton, + ShareAttachment, + ShareAttachmentRequest, + StickerAttachment, + StickerAttachmentRequest, + StrikethroughMarkup, + StrongMarkup, + UnderlineMarkup, + UserMentionMarkup, + VideoAttachment, + VideoAttachmentRequest, +) + +TAG_PROVIDERS = concat_provider( + # ---> UpdateType <--- + has_tag_provider(BotAddedToChat, "update_type", UpdateType.BOT_ADDED), + has_tag_provider(BotRemovedFromChat, "update_type", UpdateType.BOT_REMOVED), + has_tag_provider(BotStarted, "update_type", UpdateType.BOT_STARTED), + has_tag_provider(BotStopped, "update_type", UpdateType.BOT_STOPPED), + has_tag_provider(ChatTitleChanged, "update_type", UpdateType.CHAT_TITLE_CHANGED), + has_tag_provider(DialogCleared, "update_type", UpdateType.DIALOG_CLEARED), + has_tag_provider(DialogMuted, "update_type", UpdateType.DIALOG_MUTED), + has_tag_provider(DialogRemoved, "update_type", UpdateType.DIALOG_REMOVED), + has_tag_provider(DialogUnmuted, "update_type", UpdateType.DIALOG_UNMUTED), + has_tag_provider(MessageCallback, "update_type", UpdateType.MESSAGE_CALLBACK), + has_tag_provider(MessageCreated, "update_type", UpdateType.MESSAGE_CREATED), + has_tag_provider(MessageEdited, "update_type", UpdateType.MESSAGE_EDITED), + has_tag_provider(MessageRemoved, "update_type", UpdateType.MESSAGE_REMOVED), + has_tag_provider(UserAddedToChat, "update_type", UpdateType.USER_ADDED), + has_tag_provider(UserRemovedFromChat, "update_type", UpdateType.USER_REMOVED), + # ---> AttachmentType <--- + has_tag_provider(AudioAttachment, "type", AttachmentType.AUDIO), + has_tag_provider(ContactAttachment, "type", AttachmentType.CONTACT), + has_tag_provider(FileAttachment, "type", AttachmentType.FILE), + has_tag_provider(PhotoAttachment, "type", AttachmentType.IMAGE), + has_tag_provider(InlineKeyboardAttachment, "type", AttachmentType.INLINE_KEYBOARD), + has_tag_provider(LocationAttachment, "type", AttachmentType.LOCATION), + has_tag_provider(ShareAttachment, "type", AttachmentType.SHARE), + has_tag_provider(StickerAttachment, "type", AttachmentType.STICKER), + has_tag_provider(VideoAttachment, "type", AttachmentType.VIDEO), + # ---> MarkupElementType <--- + has_tag_provider(EmphasizedMarkup, "type", MarkupElementType.EMPHASIZED), + has_tag_provider(LinkMarkup, "type", MarkupElementType.LINK), + has_tag_provider(MonospacedMarkup, "type", MarkupElementType.MONOSPACED), + has_tag_provider(StrikethroughMarkup, "type", MarkupElementType.STRIKETHROUGH), + has_tag_provider(StrongMarkup, "type", MarkupElementType.STRONG), + has_tag_provider(UnderlineMarkup, "type", MarkupElementType.UNDERLINE), + has_tag_provider(UserMentionMarkup, "type", MarkupElementType.USER_MENTION), + # ---> AttachmentRequestType <--- + has_tag_provider(PhotoAttachmentRequest, "type", AttachmentRequestType.IMAGE), + has_tag_provider(VideoAttachmentRequest, "type", AttachmentRequestType.VIDEO), + has_tag_provider(AudioAttachmentRequest, "type", AttachmentRequestType.AUDIO), + has_tag_provider(FileAttachmentRequest, "type", AttachmentRequestType.FILE), + has_tag_provider(StickerAttachmentRequest, "type", AttachmentRequestType.STICKER), + has_tag_provider(ContactAttachmentRequest, "type", AttachmentRequestType.CONTACT), + has_tag_provider( + InlineKeyboardAttachmentRequest, + "type", + AttachmentRequestType.INLINE_KEYBOARD, + ), + has_tag_provider(LocationAttachmentRequest, "type", AttachmentRequestType.LOCATION), + has_tag_provider(ShareAttachmentRequest, "type", AttachmentRequestType.SHARE), + # ---> KeyboardButtonType <--- + has_tag_provider(CallbackButton, "type", ButtonType.CALLBACK), + has_tag_provider(LinkButton, "type", ButtonType.LINK), + has_tag_provider(RequestContactButton, "type", ButtonType.REQUEST_CONTACT), + has_tag_provider(RequestGeoLocationButton, "type", ButtonType.REQUEST_GEO_LOCATION), + has_tag_provider(OpenAppButton, "type", ButtonType.OPEN_APP), + has_tag_provider(MessageButton, "type", ButtonType.MESSAGE), +) + + +_retort: Retort | None = None + + +def get_retort( + *, + defaults: BotDefaults | None = None, + warming_up: bool = True, +) -> Retort: + global _retort + + if _retort is not None: + return _retort + + _retort = create_retort(defaults=defaults, warming_up=warming_up) + return _retort + + +def create_retort( + *, + defaults: BotDefaults | None = None, + warming_up: bool = True, +) -> Retort: + if defaults is None: + defaults = BotDefaults() + + retort = DEFAULT_RETORT.extend( + recipe=[ + TAG_PROVIDERS, + dumper( + for_marker(QueryMarker, P[None]), + lambda _: "null", + ), + dumper( + for_marker(QueryMarker, P[bool]), + lambda item: int(item), + ), + dumper( + for_marker(QueryMarker, P[list[str]] | P[list[int]]), + lambda seq: ",".join(str(el) for el in seq), + ), + dumper( + P[TextFormat] + | P[TextFormat | None] + | P[Omittable[TextFormat]] + | P[Omittable[TextFormat | None]], + lambda item: item or defaults.text_format, + ), + dumper( + P[Attachments], + lambda attachment: attachment.to_request(), + chain=Chain.FIRST, + ), + loader(P[datetime], lambda x: datetime.fromtimestamp(x / 1000, tz=UTC)), + ], + ) + if warming_up: + retort = warming_up_retort(retort, warming_up=WarmingUpType.TYPES) + retort = warming_up_retort(retort, warming_up=WarmingUpType.METHOD) + + return retort diff --git a/src/maxo/types/__init__.py b/src/maxo/types/__init__.py index 440c6731..79b4441b 100644 --- a/src/maxo/types/__init__.py +++ b/src/maxo/types/__init__.py @@ -23,7 +23,6 @@ from .chat_list import ChatList from .chat_member import ChatMember from .chat_members_list import ChatMembersList -from .command_object import CommandObject from .contact_attachment import ContactAttachment from .contact_attachment_payload import ContactAttachmentPayload from .contact_attachment_request import ContactAttachmentRequest @@ -114,7 +113,6 @@ "ChatList", "ChatMember", "ChatMembersList", - "CommandObject", "ContactAttachment", "ContactAttachmentPayload", "ContactAttachmentRequest", diff --git a/src/maxo/types/buttons.py b/src/maxo/types/buttons.py index 629ede8a..5901158b 100644 --- a/src/maxo/types/buttons.py +++ b/src/maxo/types/buttons.py @@ -1,4 +1,3 @@ -from maxo.types.button import Button from maxo.types.callback_button import CallbackButton from maxo.types.chat_button import ChatButton from maxo.types.link_button import LinkButton @@ -15,5 +14,4 @@ | OpenAppButton | MessageButton | ChatButton - | Button ) diff --git a/src/maxo/types/command_object.py b/src/maxo/types/command_object.py deleted file mode 100644 index 51609211..00000000 --- a/src/maxo/types/command_object.py +++ /dev/null @@ -1,26 +0,0 @@ -from dataclasses import field -from re import Match - -from maxo.types.base import MaxoType - - -# Самодельный объект -class CommandObject(MaxoType): - prefix: str = "/" - command: str = "" - mention: str | None = None - args: str | None = field(repr=False, default=None) - regexp_match: Match[str] | None = field(repr=False, default=None) - - @property - def mentioned(self) -> bool: - return bool(self.mention) - - @property - def text(self) -> str: - line = self.prefix + self.command - if self.mention: - line += "@" + self.mention - if self.args: - line += " " + self.args - return line diff --git a/src/maxo/types/contact_attachment.py b/src/maxo/types/contact_attachment.py index 5ab341b2..57f8a4af 100644 --- a/src/maxo/types/contact_attachment.py +++ b/src/maxo/types/contact_attachment.py @@ -45,7 +45,7 @@ def to_request(self) -> ContactAttachmentRequest: name=( self.payload.max_info.first_name if is_defined(self.payload.max_info) - else Omitted() + else None ), contact_id=( self.payload.max_info.user_id diff --git a/src/maxo/types/message_body.py b/src/maxo/types/message_body.py index b972987f..38641ffc 100644 --- a/src/maxo/types/message_body.py +++ b/src/maxo/types/message_body.py @@ -1,10 +1,18 @@ from maxo.errors import AttributeIsEmptyError from maxo.omit import Omittable, Omitted, is_defined from maxo.types.attachments import Attachments +from maxo.types.audio_attachment import AudioAttachment from maxo.types.base import MaxoType +from maxo.types.contact_attachment import ContactAttachment +from maxo.types.file_attachment import FileAttachment from maxo.types.inline_keyboard_attachment import InlineKeyboardAttachment from maxo.types.keyboard import Keyboard +from maxo.types.location_attachment import LocationAttachment from maxo.types.markup_elements import MarkupElements +from maxo.types.photo_attachment import PhotoAttachment +from maxo.types.share_attachment import ShareAttachment +from maxo.types.sticker_attachment import StickerAttachment +from maxo.types.video_attachment import VideoAttachment from maxo.utils.text_decorations import ( TextDecoration, html_decoration, @@ -43,9 +51,7 @@ def id(self) -> str: @property def keyboard(self) -> Keyboard | None: - if not self.attachments: - return None - for attachment in self.attachments: + for attachment in self.attachments or []: if isinstance(attachment, InlineKeyboardAttachment): return attachment.payload return None @@ -54,6 +60,64 @@ def keyboard(self) -> Keyboard | None: def reply_markup(self) -> Keyboard | None: return self.keyboard + @property + def photo(self) -> list[PhotoAttachment]: + return [ + attachment + for attachment in self.attachments or [] + if isinstance(attachment, PhotoAttachment) + ] + + @property + def video(self) -> list[VideoAttachment]: + return [ + attachment + for attachment in self.attachments or [] + if isinstance(attachment, VideoAttachment) + ] + + @property + def audio(self) -> AudioAttachment | None: + for attachment in self.attachments or []: + if isinstance(attachment, AudioAttachment): + return attachment + return None + + @property + def file(self) -> FileAttachment | None: + for attachment in self.attachments or []: + if isinstance(attachment, FileAttachment): + return attachment + return None + + @property + def sticker(self) -> StickerAttachment | None: + for attachment in self.attachments or []: + if isinstance(attachment, StickerAttachment): + return attachment + return None + + @property + def contact(self) -> ContactAttachment | None: + for attachment in self.attachments or []: + if isinstance(attachment, ContactAttachment): + return attachment + return None + + @property + def share(self) -> ShareAttachment | None: + for attachment in self.attachments or []: + if isinstance(attachment, ShareAttachment): + return attachment + return None + + @property + def location(self) -> LocationAttachment | None: + for attachment in self.attachments or []: + if isinstance(attachment, LocationAttachment): + return attachment + return None + def _unparse_entities(self, text_decoration: TextDecoration) -> str: text = self.text or "" entities = self.markup or [] diff --git a/src/maxo/types/simple_query_result.py b/src/maxo/types/simple_query_result.py index 7a498376..e22ea1fa 100644 --- a/src/maxo/types/simple_query_result.py +++ b/src/maxo/types/simple_query_result.py @@ -9,11 +9,11 @@ class SimpleQueryResult(MaxoType): Args: message: Объяснительное сообщение, если результат не был успешным - success: `true`, если запрос был успешным, `false` в противном случае + success: `true`, если запрос был успешным, `false` — в противном случае """ success: bool - """`true`, если запрос был успешным, `false` в противном случае""" + """`true`, если запрос был успешным, `false` — в противном случае""" message: Omittable[str] = Omitted() """Объяснительное сообщение, если результат не был успешным""" diff --git a/src/maxo/utils/deeplink.py b/src/maxo/utils/deeplink.py index a4ba76ea..2df31cbd 100644 --- a/src/maxo/utils/deeplink.py +++ b/src/maxo/utils/deeplink.py @@ -1,8 +1,8 @@ __all__ = ( "create_deep_link", + "create_max_http_link", "create_start_link", "create_startapp_link", - "create_telegram_link", "decode_payload", "encode_payload", ) @@ -11,7 +11,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Literal, cast -from maxo.utils.link import create_telegram_link +from maxo.utils.link import create_max_http_link from maxo.utils.payload import decode_payload, encode_payload if TYPE_CHECKING: @@ -78,9 +78,9 @@ def create_deep_link( raise ValueError(f"Payload must be up to {PAYLOAD_MAX_LEN} characters long.") if not app_name: - deep_link = create_telegram_link(username, **{cast(str, link_type): payload}) + deep_link = create_max_http_link(username, **{cast(str, link_type): payload}) else: - deep_link = create_telegram_link( + deep_link = create_max_http_link( username, app_name, **{cast(str, link_type): payload}, diff --git a/src/maxo/utils/facades/methods/message.py b/src/maxo/utils/facades/methods/message.py index 1835fac8..55df141b 100644 --- a/src/maxo/utils/facades/methods/message.py +++ b/src/maxo/utils/facades/methods/message.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from maxo.enums import MessageLinkType, TextFormat -from maxo.omit import Omittable, Omitted +from maxo.omit import Omittable, Omitted, is_omitted from maxo.types.buttons import InlineButtons from maxo.types.chat import Chat from maxo.types.chat_members_list import ChatMembersList @@ -45,6 +45,12 @@ async def send_message( chat_type=recipient.chat_type, ) + if ( + is_omitted(disable_link_preview) + and self.bot.defaults.disable_link_preview is not None + ): + disable_link_preview = self.bot.defaults.disable_link_preview + attachments = await self.build_attachments( base=[], keyboard=keyboard, diff --git a/src/maxo/utils/formatting.py b/src/maxo/utils/formatting.py new file mode 100644 index 00000000..7a1737ec --- /dev/null +++ b/src/maxo/utils/formatting.py @@ -0,0 +1,374 @@ +import dataclasses +import textwrap +from collections.abc import Generator, Iterable, Iterator +from typing import Any, ClassVar, Self + +from maxo.enums import MarkupElementType +from maxo.types.emphasized_markup import EmphasizedMarkup +from maxo.types.link_markup import LinkMarkup +from maxo.types.markup_element import MarkupElement +from maxo.types.markup_elements import MarkupElements +from maxo.types.monospaced_markup import MonospacedMarkup +from maxo.types.strikethrough_markup import StrikethroughMarkup +from maxo.types.strong_markup import StrongMarkup +from maxo.types.underline_markup import UnderlineMarkup +from maxo.types.user_mention_markup import UserMentionMarkup +from maxo.utils.text_decorations import ( + add_surrogates, + html_decoration, + markdown_decoration, + remove_surrogates, +) + +NodeType = Any + + +def sizeof(value: str) -> int: + return len(value.encode("utf-16-le")) // 2 + + +class Text(Iterable[NodeType]): + type: ClassVar[str | None] = None + + __slots__ = ("_body", "_params") + + def __init__( + self, + *body: NodeType, + **params: Any, + ) -> None: + self._body: tuple[NodeType, ...] = body + self._params: dict[str, Any] = params + + @classmethod + def from_entities(cls, text: str, entities: list[MarkupElements]) -> "Text": + return cls( + *_unparse_entities( + text=add_surrogates(text), + entities=( + sorted(entities, key=lambda item: item.offset) if entities else [] + ), + ), + ) + + def render( + self, + *, + _offset: int = 0, + _sort: bool = True, + _collect_entities: bool = True, + ) -> tuple[str, list[MarkupElements]]: + """Render elements tree as text with entities list.""" + text = "" + entities = [] + offset = _offset + + for node in self._body: + if not isinstance(node, Text): + node = str(node) + text += node + offset += sizeof(node) + else: + node_text, node_entities = node.render( + _offset=offset, + _sort=False, + _collect_entities=_collect_entities, + ) + text += node_text + offset += sizeof(node_text) + if _collect_entities: + entities.extend(node_entities) + + if _collect_entities and self.type: + entities.append( + self._render_entity(offset=_offset, length=offset - _offset), + ) + + if _collect_entities and _sort: + entities.sort(key=lambda entity: entity.offset) + + return text, entities + + def _render_entity(self, *, offset: int, length: int) -> MarkupElements: + if self.type is None: + raise ValueError("Node without type can't be rendered as entity") + + markup_map = { + MarkupElementType.STRONG: StrongMarkup, + MarkupElementType.EMPHASIZED: EmphasizedMarkup, + MarkupElementType.UNDERLINE: UnderlineMarkup, + MarkupElementType.STRIKETHROUGH: StrikethroughMarkup, + MarkupElementType.MONOSPACED: MonospacedMarkup, + MarkupElementType.LINK: LinkMarkup, + MarkupElementType.USER_MENTION: UserMentionMarkup, + } + markup_class: type[MarkupElements] = markup_map.get(self.type, MarkupElement) + return markup_class( + type=self.type, + from_=offset, + length=length, + **self._params, + ) + + def as_kwargs( + self, + *, + text_key: str = "text", + replace_format: bool = True, + format_key: str = "format", + ) -> dict[str, Any]: + """ + Render element tree as keyword arguments for usage in an API call. + + .. code-block:: python + + entities = Text(...) + await facade.answer_text(**entities.as_kwargs()) + """ + text_value, _ = self.render() + result: dict[str, Any] = {text_key: text_value} + if replace_format: + result[format_key] = None + return result + + def as_html(self) -> str: + """Render elements tree as HTML markup.""" + text, entities = self.render() + return html_decoration.unparse(text, entities) + + def as_markdown(self) -> str: + """Render elements tree as Markdown markup.""" + text, entities = self.render() + return markdown_decoration.unparse(text, entities) + + def replace(self: Self, *args: Any, **kwargs: Any) -> Self: + return type(self)(*args, **{**self._params, **kwargs}) + + def as_pretty_string(self, indent: bool = False) -> str: + sep = ",\n" if indent else ", " + body = sep.join( + ( + item.as_pretty_string(indent=indent) + if isinstance(item, Text) + else repr(item) + ) + for item in self._body + ) + params = sep.join( + f"{k}={v!r}" for k, v in self._params.items() if v is not None + ) + + args = [] + if body: + args.append(body) + if params: + args.append(params) + + args_str = sep.join(args) + if indent: + args_str = textwrap.indent("\n" + args_str + "\n", " ") + return f"{type(self).__name__}({args_str})" + + def __add__(self, other: NodeType) -> "Text": + if ( + isinstance(other, Text) + and other.type == self.type + and self._params == other._params + ): + return type(self)(*self, *other, **self._params) + if type(self) is Text and isinstance(other, str): + return type(self)(*self, other, **self._params) + return Text(self, other) + + def __iter__(self) -> Iterator[NodeType]: + yield from self._body + + def __len__(self) -> int: + text, _ = self.render(_collect_entities=False) + return sizeof(text) + + def __getitem__(self, item: slice) -> "Text": + if not isinstance(item, slice): + raise TypeError("Can only be sliced") + if (item.start is None or item.start == 0) and item.stop is None: + return self.replace(*self._body) + start = 0 if item.start is None else item.start + stop = len(self) if item.stop is None else item.stop + if start == stop: + return self.replace() + + nodes = [] + position = 0 + + for node in self._body: + node_size = len(node) + current_position = position + position += node_size + if position < start: + continue + if current_position > stop: + break + a = max((0, start - current_position)) + b = min((node_size, stop - current_position)) + new_node = node[a:b] + if not new_node: + continue + nodes.append(new_node) + + return self.replace(*nodes) + + +class Bold(Text): + type = MarkupElementType.STRONG + + +class Italic(Text): + type = MarkupElementType.EMPHASIZED + + +class Underline(Text): + type = MarkupElementType.UNDERLINE + + +class Strikethrough(Text): + type = MarkupElementType.STRIKETHROUGH + + +class Monospaced(Text): + type = MarkupElementType.MONOSPACED + + +class Link(Text): + type = MarkupElementType.LINK + + def __init__(self, *body: NodeType, url: str, **params: Any) -> None: + super().__init__(*body, url=url, **params) + + +class Mention(Text): + type = MarkupElementType.USER_MENTION + + def __init__(self, *body: NodeType, user_id: int, **params: Any) -> None: + super().__init__(*body, user_id=user_id, **params) + + +NODE_TYPES: dict[str | None, type[Text]] = { + Text.type: Text, + Bold.type: Bold, + Italic.type: Italic, + Underline.type: Underline, + Strikethrough.type: Strikethrough, + Link.type: Link, + Mention.type: Mention, + Monospaced.type: Monospaced, +} + + +def _apply_entity(entity: MarkupElements, *nodes: NodeType) -> NodeType: + """Apply single entity to text.""" + node_type = NODE_TYPES.get(entity.type, Text) + + entity_dict = dataclasses.asdict(entity) + for key in ("type", "from_", "length"): + entity_dict.pop(key, None) + + return node_type( + *nodes, + **entity_dict, + ) + + +def _unparse_entities( + text: bytes, + entities: list[MarkupElements], + offset: int | None = None, + length: int | None = None, +) -> Generator[NodeType, None, None]: + if offset is None: + offset = 0 + length = length or len(text) + + for index, entity in enumerate(entities): + if entity.offset * 2 < offset: + continue + if entity.offset * 2 > offset: + yield remove_surrogates(text[offset : entity.offset * 2]) + start = entity.offset * 2 + offset = entity.offset * 2 + entity.length * 2 + + sub_entities = list( + filter(lambda e: e.offset * 2 < (offset or 0), entities[index + 1 :]), + ) + yield _apply_entity( + entity, + *_unparse_entities(text, sub_entities, offset=start, length=offset), + ) + + if offset < length: + yield remove_surrogates(text[offset:length]) + + +def as_line(*items: NodeType, end: str = "\n", sep: str = "") -> Text: + r"""Wrap multiple nodes into line with :code:`\n` at the end of line.""" + if not items: + return Text(end) + if sep: + nodes = [] + for item in items[:-1]: + nodes.extend([item, sep]) + nodes.extend([items[-1], end]) + else: + nodes = [*items, end] + return Text(*nodes) + + +def as_list(*items: NodeType, sep: str = "\n") -> Text: + """Wrap each element to separated lines.""" + if not items: + return Text() + nodes = [] + for item in items[:-1]: + nodes.extend([item, sep]) + nodes.append(items[-1]) + return Text(*nodes) + + +def as_marked_list(*items: NodeType, marker: str = "- ") -> Text: + """Wrap elements as marked list.""" + return as_list(*(Text(marker, item) for item in items)) + + +def as_numbered_list(*items: NodeType, start: int = 1, fmt: str = "{}. ") -> Text: + """Wrap elements as numbered list.""" + return as_list( + *(Text(fmt.format(index), item) for index, item in enumerate(items, start)), + ) + + +def as_section(title: NodeType, *body: NodeType) -> Text: + """Wrap elements as simple section, section has title and body.""" + return Text(title, "\n", *body) + + +def as_marked_section( + title: NodeType, + *body: NodeType, + marker: str = "- ", +) -> Text: + """Wrap elements as section with marked list.""" + return as_section(title, as_marked_list(*body, marker=marker)) + + +def as_numbered_section( + title: NodeType, + *body: NodeType, + start: int = 1, + fmt: str = "{}. ", +) -> Text: + """Wrap elements as section with numbered list.""" + return as_section(title, as_numbered_list(*body, start=start, fmt=fmt)) + + +def as_key_value(key: NodeType, value: NodeType) -> Text: + """Wrap elements pair as key-value line. (:code:`{key}: {value}`).""" + return Text(Bold(key, ":"), " ", value) diff --git a/src/maxo/utils/helpers/attachments.py b/src/maxo/utils/helpers/attachments.py index d37bc0cd..e955f5e4 100644 --- a/src/maxo/utils/helpers/attachments.py +++ b/src/maxo/utils/helpers/attachments.py @@ -1,5 +1,3 @@ -# ruff: noqa: E501 - from typing import assert_never from maxo.types import ( @@ -46,10 +44,11 @@ def request_to_attachment(request: AttachmentsRequests) -> Attachments: ), ): raise TypeError( - f"Cannot convert {type(request).__name__} to an Attachment object directly. " - "Request objects lack server-generated data like IDs, URLs, or resolved user info. " - "This conversion is only possible for request types that have a 1:1 mapping of fields " - "(e.g., LocationAttachmentRequest, InlineKeyboardAttachmentRequest).", + f"Cannot convert {type(request).__name__} to an Attachment object " + "directly. Request objects lack server-generated data like IDs, " + "URLs, or resolved user info. This conversion is only possible for " + "request types that have a 1:1 mapping of fields (e.g., " + "LocationAttachmentRequest, InlineKeyboardAttachmentRequest).", ) assert_never(request) diff --git a/src/maxo/utils/link.py b/src/maxo/utils/link.py index 4bf763a9..230457eb 100644 --- a/src/maxo/utils/link.py +++ b/src/maxo/utils/link.py @@ -16,9 +16,9 @@ def _format_url( return url -def create_tg_link(link: str, **kwargs: Any) -> str: +def create_max_link(link: str, **kwargs: Any) -> str: return _format_url(f"max://{link}", **kwargs) -def create_telegram_link(*path: str, **kwargs: Any) -> str: +def create_max_http_link(*path: str, **kwargs: Any) -> str: return _format_url("https://max.ru", *path, **kwargs) diff --git a/src/maxo/utils/long_polling.py b/src/maxo/utils/long_polling.py index 157b00b0..ac2f67b5 100644 --- a/src/maxo/utils/long_polling.py +++ b/src/maxo/utils/long_polling.py @@ -136,6 +136,7 @@ async def _get_updates( type(exception).__name__, exception, ) + backoff.next() loggers.dispatcher.warning( "Sleep for %f seconds and try again... " "(tryings = %d, username = @%s, bot id = %d)", @@ -144,8 +145,7 @@ async def _get_updates( bot_username, bot_id, ) - await asyncio.sleep(backoff.current_delay) - backoff.next() + await backoff.sleep() continue if failed: diff --git a/src/maxo/webhook/__init__.py b/src/maxo/webhook/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxo/webhook/adapters/__init__.py b/src/maxo/webhook/adapters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxo/webhook/adapters/aiohttp/__init__.py b/src/maxo/webhook/adapters/aiohttp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxo/webhook/adapters/aiohttp/adapter.py b/src/maxo/webhook/adapters/aiohttp/adapter.py new file mode 100644 index 00000000..d8978c45 --- /dev/null +++ b/src/maxo/webhook/adapters/aiohttp/adapter.py @@ -0,0 +1,73 @@ +from asyncio import Transport +from collections.abc import Awaitable, Callable +from ipaddress import IPv4Address, IPv6Address +from json import JSONDecodeError +from typing import Any, cast + +from aiohttp import ContentTypeError +from aiohttp.web import Application, Request +from aiohttp.web_response import Response, json_response +from aiosignal import Signal + +from maxo.webhook.adapters.aiohttp.mapping import ( + AiohttpHeadersMapping, + AiohttpQueryMapping, +) +from maxo.webhook.adapters.base_adapter import BoundRequest, WebAdapter + + +class AiohttpBoundRequest(BoundRequest[Request]): + def __init__(self, request: Request) -> None: + super().__init__(request) + self._headers = AiohttpHeadersMapping(self.request.headers) + self._query_params = AiohttpQueryMapping(self.request.query) + + async def json(self) -> dict[str, Any]: + try: + return await self.request.json() + except ContentTypeError as e: + raise JSONDecodeError from e + + @property + def client_ip(self) -> IPv4Address | IPv6Address | str | None: + peer_name = cast(Transport, self.request.transport).get_extra_info("peername") + if peer_name: + return peer_name[0] + return None + + @property + def headers(self) -> AiohttpHeadersMapping: + return self._headers + + @property + def query_params(self) -> AiohttpQueryMapping: + return self._query_params + + @property + def path_params(self) -> dict[str, Any]: + return self.request.match_info + + +class AiohttpWebAdapter(WebAdapter): + def bind(self, request: Request) -> AiohttpBoundRequest: + return AiohttpBoundRequest(request=request) + + def register( + self, + app: Application, + path: str, + handler: Callable[[BoundRequest[Any]], Awaitable[Any]], + on_startup: Signal[Application] | None = None, + on_shutdown: Signal[Application] | None = None, + ) -> None: + async def endpoint(request: Request) -> Any: + return await handler(self.bind(request)) + + app.router.add_route(method="POST", path=path, handler=endpoint) + if on_startup is not None: + app.on_startup.append(on_startup) + if on_shutdown is not None: + app.on_shutdown.append(on_shutdown) + + def create_json_response(self, status: int, payload: dict[str, Any]) -> Response: + return json_response(status=status, data=payload) diff --git a/src/maxo/webhook/adapters/aiohttp/mapping.py b/src/maxo/webhook/adapters/aiohttp/mapping.py new file mode 100644 index 00000000..8c89ad6e --- /dev/null +++ b/src/maxo/webhook/adapters/aiohttp/mapping.py @@ -0,0 +1,15 @@ +from typing import Any + +from multidict import CIMultiDictProxy, MultiMapping + +from maxo.webhook.adapters.base_mapping import MappingABC + + +class AiohttpHeadersMapping(MappingABC[CIMultiDictProxy[str]]): + def getlist(self, name: str) -> list[Any]: + return self._mapping.getall(name, []) + + +class AiohttpQueryMapping(MappingABC[MultiMapping[str]]): + def getlist(self, name: str) -> list[Any]: + return self._mapping.getall(name, []) diff --git a/src/maxo/webhook/adapters/base_adapter.py b/src/maxo/webhook/adapters/base_adapter.py new file mode 100644 index 00000000..1d73e633 --- /dev/null +++ b/src/maxo/webhook/adapters/base_adapter.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from ipaddress import IPv4Address, IPv6Address +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from maxo.webhook.adapters.base_mapping import MappingABC + +R = TypeVar("R") + + +class BoundRequest(ABC, Generic[R]): + """Unified abstraction for requests across frameworks.""" + + __slots__ = ("request",) + + def __init__(self, request: R) -> None: + self.request = request + + @abstractmethod + async def json(self) -> dict[str, Any]: + """Get JSON data from request.""" + raise NotImplementedError + + @property + @abstractmethod + def client_ip(self) -> IPv4Address | IPv6Address | str | None: + """Get client IP address.""" + raise NotImplementedError + + @property + @abstractmethod + def headers(self) -> MappingABC: + """Get request headers.""" + raise NotImplementedError + + @property + @abstractmethod + def query_params(self) -> MappingABC: + """Get request query parameters.""" + raise NotImplementedError + + @property + @abstractmethod + def path_params(self) -> dict[str, Any]: + """Get request path parameters.""" + raise NotImplementedError + + +class WebAdapter(ABC): + """Abstraction for web framework adapters.""" + + @abstractmethod + def bind(self, request: Any) -> BoundRequest: + """Bind request to BoundRequest.""" + raise NotImplementedError + + @abstractmethod + def register( + self, + app: Any, + path: str, + handler: Callable[[BoundRequest], Awaitable[Any]], + on_startup: Callable[..., Awaitable[Any]] | None = None, + on_shutdown: Callable[..., Awaitable[Any]] | None = None, + ) -> None: + """ + Register webhook handler. + + :param app: Web application instance. + :param path: Webhook path. + :param handler: Handler function. + :param on_startup: Optional startup callback. + :param on_shutdown: Optional shutdown callback. + """ + raise NotImplementedError + + @abstractmethod + def create_json_response(self, status: int, payload: dict[str, Any]) -> Any: + """Create JSON response with given status and data.""" + raise NotImplementedError diff --git a/src/maxo/webhook/adapters/base_mapping.py b/src/maxo/webhook/adapters/base_mapping.py new file mode 100644 index 00000000..e0f1a28e --- /dev/null +++ b/src/maxo/webhook/adapters/base_mapping.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod +from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView +from typing import Any, Generic, TypeVar + +M = TypeVar("M", bound=Mapping) + + +class MappingABC(ABC, Generic[M]): + def __init__(self, mapping: M) -> None: + self._mapping = mapping + + def get(self, name: str, default: Any = None) -> Any: + return self._mapping.get(name, default) + + @abstractmethod + def getlist(self, name: str) -> list[Any]: + raise NotImplementedError + + def __getitem__(self, name: str) -> Any: + return self._mapping[name] + + def __contains__(self, name: str) -> bool: + return name in self.keys() + + def __len__(self) -> int: + return len(self._mapping) + + def __iter__(self) -> Iterator: + return iter(self._mapping) + + def keys(self) -> KeysView: + return self._mapping.keys() + + def values(self) -> ValuesView: + return self._mapping.values() + + def items(self) -> ItemsView: + return self._mapping.items() diff --git a/src/maxo/webhook/adapters/fastapi/__init__.py b/src/maxo/webhook/adapters/fastapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxo/webhook/adapters/fastapi/adapter.py b/src/maxo/webhook/adapters/fastapi/adapter.py new file mode 100644 index 00000000..25f48ed3 --- /dev/null +++ b/src/maxo/webhook/adapters/fastapi/adapter.py @@ -0,0 +1,70 @@ +from collections.abc import Awaitable, Callable +from ipaddress import IPv4Address, IPv6Address +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +from maxo.webhook.adapters.base_adapter import BoundRequest, WebAdapter +from maxo.webhook.adapters.fastapi.mapping import ( + FastAPIHeadersMapping, + FastAPIQueryMapping, +) + + +class FastAPIBoundRequest(BoundRequest[Request]): + def __init__(self, request: Request) -> None: + super().__init__(request) + self._headers = FastAPIHeadersMapping(self.request.headers) + self._query_params = FastAPIQueryMapping(self.request.query_params) + + async def json(self) -> dict[str, Any]: + return await self.request.json() + + @property + def client_ip(self) -> IPv4Address | IPv6Address | str | None: + if self.request.client: + return self.request.client.host + return None + + @property + def headers(self) -> FastAPIHeadersMapping: + return self._headers + + @property + def query_params(self) -> FastAPIQueryMapping: + return self._query_params + + @property + def path_params(self) -> dict[str, Any]: + return self.request.path_params + + +class FastApiWebAdapter(WebAdapter): + def bind(self, request: Request) -> FastAPIBoundRequest: + return FastAPIBoundRequest(request=request) + + def register( + self, + app: FastAPI, + path: str, + handler: Callable[[BoundRequest], Awaitable[Any]], + on_startup: Callable[..., Awaitable[Any]] | None = None, + on_shutdown: Callable[..., Awaitable[Any]] | None = None, + ) -> None: + async def endpoint(request: Request) -> Any: + return await handler(self.bind(request)) + + app.add_api_route(path=path, endpoint=endpoint, methods=["POST"]) + + if on_startup is not None: + app.add_event_handler("startup", on_startup) + if on_shutdown is not None: + app.add_event_handler("shutdown", on_shutdown) + + def create_json_response( + self, + status: int, + payload: dict[str, Any], + ) -> JSONResponse: + return JSONResponse(status_code=status, content=payload) diff --git a/src/maxo/webhook/adapters/fastapi/mapping.py b/src/maxo/webhook/adapters/fastapi/mapping.py new file mode 100644 index 00000000..1435cbd5 --- /dev/null +++ b/src/maxo/webhook/adapters/fastapi/mapping.py @@ -0,0 +1,15 @@ +from typing import Any + +from starlette.datastructures import Headers, QueryParams + +from maxo.webhook.adapters.base_mapping import MappingABC + + +class FastAPIHeadersMapping(MappingABC[Headers]): + def getlist(self, name: str) -> list[Any]: + return self._mapping.getlist(name) + + +class FastAPIQueryMapping(MappingABC[QueryParams]): + def getlist(self, name: str) -> list[Any]: + return self._mapping.getlist(name) diff --git a/src/maxo/webhook/config/__init__.py b/src/maxo/webhook/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxo/webhook/config/bot.py b/src/maxo/webhook/config/bot.py new file mode 100644 index 00000000..5af49bbc --- /dev/null +++ b/src/maxo/webhook/config/bot.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + +from maxo.bot.defaults import BotDefaults + + +@dataclass +class BotConfig: + defaults: BotDefaults | None = None + """Default values for bot API calls.""" diff --git a/src/maxo/webhook/engines/__init__.py b/src/maxo/webhook/engines/__init__.py new file mode 100644 index 00000000..51ea2b1f --- /dev/null +++ b/src/maxo/webhook/engines/__init__.py @@ -0,0 +1,9 @@ +from maxo.webhook.engines.base import WebhookEngine +from maxo.webhook.engines.simple import SimpleEngine +from maxo.webhook.engines.token import TokenEngine + +__all__ = ( + "SimpleEngine", + "TokenEngine", + "WebhookEngine", +) diff --git a/src/maxo/webhook/engines/base.py b/src/maxo/webhook/engines/base.py new file mode 100644 index 00000000..9670aee5 --- /dev/null +++ b/src/maxo/webhook/engines/base.py @@ -0,0 +1,138 @@ +import asyncio +from abc import ABC, abstractmethod +from json import JSONDecodeError +from typing import Any + +from maxo import Bot, Dispatcher +from maxo.bot.methods.base import MaxoMethod +from maxo.routing.signals import MaxoUpdate +from maxo.routing.updates import Updates +from maxo.webhook.adapters.base_adapter import BoundRequest, WebAdapter +from maxo.webhook.routing.base import BaseRouting +from maxo.webhook.security.security import Security + + +class WebhookEngine(ABC): + """ + Base webhook engine for processing Telegram bot updates. + + Handles incoming webhook requests, bot resolution, security checks, + routing, and dispatching updates to the aiogram dispatcher. Supports + both synchronous and background processing. + """ + + def __init__( + self, + dispatcher: Dispatcher, + /, + web_adapter: WebAdapter, + routing: BaseRouting, + security: Security | None = None, + handle_in_background: bool = True, + ) -> None: + self.dispatcher = dispatcher + self.web_adapter = web_adapter + self.routing = routing + self.security = security + self.handle_in_background = handle_in_background + self._background_feed_update_tasks: set[asyncio.Task[Any]] = set() + + @abstractmethod + def _get_bot_from_request(self, bound_request: BoundRequest) -> Bot | None: + raise NotImplementedError + + @abstractmethod + async def set_webhook(self, *args: Any, **kwargs: Any) -> Bot: + raise NotImplementedError + + @abstractmethod + async def on_startup(self, app: Any, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + + @abstractmethod + async def on_shutdown(self, app: Any, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + + def _build_workflow_data(self, app: Any, **kwargs: Any) -> dict[str, Any]: + """Build workflow data for startup/shutdown events.""" + return { + "app": app, + "dispatcher": self.dispatcher, + "webhook_engine": self, + **self.dispatcher.workflow_data, + **kwargs, + } + + async def handle_request(self, bound_request: BoundRequest) -> Any: + bot = self._get_bot_from_request(bound_request) + if bot is None: + return self.web_adapter.create_json_response( + status=400, + payload={"detail": "Bot not found"}, + ) + + if self.security is not None and not await self.security.verify( + bot=bot, + bound_request=bound_request, + ): + return self.web_adapter.create_json_response( + status=403, + payload={"detail": "Forbidden"}, + ) + + try: + raw_update = await bound_request.json() + except JSONDecodeError: + return self.web_adapter.create_json_response( + status=400, + payload={"detail": "Bad request"}, + ) + + update = MaxoUpdate(update=bot.retort.load(raw_update, Updates)) + + if self.handle_in_background: + return await self._handle_request_background(bot=bot, update=update) + return await self._handle_request(bot=bot, update=update) + + def register(self, app: Any) -> None: + self.web_adapter.register( + app=app, + path=self.routing.path, + handler=self.handle_request, + on_startup=self.on_startup, + on_shutdown=self.on_shutdown, + ) + + async def _handle_request( + self, + bot: Bot, + update: MaxoUpdate[Any], + ) -> dict[str, Any]: + result = await self.dispatcher.feed_max_update(bot=bot, update=update) + + if not isinstance(result, MaxoMethod): + return self.web_adapter.create_json_response(status=200, payload={}) + + await bot.silent_call_method(method=result) + return self.web_adapter.create_json_response(status=200, payload={}) + + async def _background_feed_update(self, bot: Bot, update: MaxoUpdate[Any]) -> None: + result = await self.dispatcher.feed_max_update( + bot=bot, + update=update, + ) # **self.data + if isinstance(result, MaxoMethod): + await bot.silent_call_method(method=result) + + async def _handle_request_background( + self, + bot: Bot, + update: MaxoUpdate[Any], + ) -> Any: + feed_update_task = asyncio.create_task( + self._background_feed_update(bot=bot, update=update), + ) + self._background_feed_update_tasks.add(feed_update_task) + feed_update_task.add_done_callback(self._background_feed_update_tasks.discard) + + return self.web_adapter.create_json_response(status=200, payload={}) diff --git a/src/maxo/webhook/engines/simple.py b/src/maxo/webhook/engines/simple.py new file mode 100644 index 00000000..ac01c729 --- /dev/null +++ b/src/maxo/webhook/engines/simple.py @@ -0,0 +1,90 @@ +from typing import Any + +from maxo import Bot, Dispatcher +from maxo.routing.signals import ( + AfterShutdown, + AfterStartup, + BeforeShutdown, + BeforeStartup, +) +from maxo.webhook.adapters.base_adapter import BoundRequest, WebAdapter +from maxo.webhook.engines.base import WebhookEngine +from maxo.webhook.routing.base import BaseRouting +from maxo.webhook.security.security import Security + + +class SimpleEngine(WebhookEngine): + """ + Simple webhook engine for single-bot applications. + + Uses a single Bot instance for all webhook requests. + Ideal for applications that handle only one bot. + """ + + def __init__( + self, + dispatcher: Dispatcher, + bot: Bot, + /, + web_adapter: WebAdapter, + routing: BaseRouting, + security: Security | None = None, + handle_in_background: bool = True, + ) -> None: + self.bot = bot + super().__init__( + dispatcher, + web_adapter=web_adapter, + routing=routing, + security=security, + handle_in_background=handle_in_background, + ) + + def _get_bot_from_request(self, bound_request: BoundRequest) -> Bot | None: + """ + Return the single Bot instance for any request. + + :param bound_request: The incoming bound request. + :return: The single Bot instance + """ + return self.bot + + async def set_webhook( + self, + *, + update_types: list[str] | None = None, + ) -> Bot: + """Set the webhook for the Bot instance.""" + secret_token = None + if self.security is not None: + secret_token = await self.security.get_secret_token(bot=self.bot) + + await self.bot.subscribe( + url=self.routing.webhook_point(self.bot), + secret=secret_token, + update_types=update_types, + ) + return self.bot + + async def on_startup(self, app: Any, *args: Any, **kwargs: Any) -> None: + """Call on application startup. Emits dispatcher startup event.""" + workflow_data = self._build_workflow_data(app=app, bot=self.bot, **kwargs) + self.dispatcher.workflow_data.update(workflow_data) + + await self.dispatcher.feed_signal(BeforeStartup(), self.bot) + await self.dispatcher.feed_signal(AfterStartup(), self.bot) + + async def on_shutdown(self, app: Any, *args: Any, **kwargs: Any) -> None: + """ + Call on application shutdown. + + Emits dispatcher shutdown event and closes bot session. + """ + workflow_data = self._build_workflow_data(app=app, bot=self.bot, **kwargs) + self.dispatcher.workflow_data.update(workflow_data) + + await self.dispatcher.feed_signal(BeforeShutdown(), self.bot) + + await self.bot.close() + + await self.dispatcher.feed_signal(AfterShutdown(), self.bot) diff --git a/src/maxo/webhook/engines/token.py b/src/maxo/webhook/engines/token.py new file mode 100644 index 00000000..5ef5defb --- /dev/null +++ b/src/maxo/webhook/engines/token.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from typing import Any + +from maxo import Bot, Dispatcher +from maxo.routing.signals import ( + AfterShutdown, + AfterStartup, + BeforeShutdown, + BeforeStartup, +) +from maxo.webhook.adapters.base_adapter import BoundRequest, WebAdapter +from maxo.webhook.config.bot import BotConfig +from maxo.webhook.engines.base import WebhookEngine +from maxo.webhook.routing.base import TokenRouting +from maxo.webhook.security.security import Security + + +class TokenEngine(WebhookEngine): + """ + Multi-bot webhook engine with dynamic bot resolution. + + Resolves Bot instances from request tokens. + Creates and caches Bot instances on-demand. Suitable for multi-tenant applications. + """ + + def __init__( + self, + dispatcher: Dispatcher, + /, + web_adapter: WebAdapter, + routing: TokenRouting, + security: Security | None = None, + bot_config: BotConfig | None = None, + handle_in_background: bool = True, + ) -> None: + super().__init__( + dispatcher, + web_adapter=web_adapter, + routing=routing, + security=security, + handle_in_background=handle_in_background, + ) + self.routing: TokenRouting = routing # for type checker + self.bot_config = bot_config or BotConfig() + self._bots: dict[str, Bot] = {} + + def _get_bot_from_request(self, bound_request: BoundRequest) -> Bot | None: + """ + Get a :class:`Bot` instance from request by token. + + If the bot is not yet created, it will be created automatically. + + :param bound_request: Incoming request + :return: Bot instance or None + """ + token = self.routing.extract_token(bound_request) + if not token: + return None + return self.get_bot(token) + + def get_bot(self, token: str) -> Bot: + """ + Resolve or create a Bot instance by token and cache it. + + :param token: The bot token + :return: Bot + + .. note:: + To connect the bot to Telegram API and set up webhook, + use :meth:`set_webhook`. + """ + bot = self._bots.get(token) + if not bot: + bot = Bot( + token=token, + defaults=self.bot_config.defaults, + ) + self._bots[token] = bot + return bot + + async def set_webhook( + self, + token: str, + *, + update_types: list[str] | None = None, + ) -> Bot: + """Set the webhook for the Bot instance resolved by token.""" + bot = self.get_bot(token) + secret_token = None + if self.security is not None: + secret_token = await self.security.get_secret_token(bot=bot) + + await bot.subscribe( + url=self.routing.webhook_point(bot), + secret=secret_token, + update_types=update_types, + ) + return bot + + async def on_startup( + self, + app: Any, + *args: Any, + bots: set[Bot] | None = None, + **kwargs: Any, + ) -> None: + all_bots = ( + set(bots) | set(self._bots.values()) if bots else set(self._bots.values()) + ) + workflow_data = self._build_workflow_data(app=app, bots=all_bots, **kwargs) + self.dispatcher.workflow_data.update(workflow_data) + + await self.dispatcher.feed_signal(BeforeStartup()) + await self.dispatcher.feed_signal(AfterStartup()) + + async def on_shutdown(self, app: Any, *args: Any, **kwargs: Any) -> None: + workflow_data = self._build_workflow_data( + app=app, + bots=set(self._bots.values()), + **kwargs, + ) + self.dispatcher.workflow_data.update(workflow_data) + + await self.dispatcher.feed_signal(BeforeShutdown()) + + for bot in self._bots.values(): + await bot.close() + self._bots.clear() + + await self.dispatcher.feed_signal(AfterShutdown()) diff --git a/src/maxo/webhook/routing/__init__.py b/src/maxo/webhook/routing/__init__.py new file mode 100644 index 00000000..fea1ce0f --- /dev/null +++ b/src/maxo/webhook/routing/__init__.py @@ -0,0 +1,12 @@ +from maxo.webhook.routing.base import BaseRouting, TokenRouting +from maxo.webhook.routing.path import PathRouting +from maxo.webhook.routing.query import QueryRouting +from maxo.webhook.routing.static import StaticRouting + +__all__ = ( + "BaseRouting", + "PathRouting", + "QueryRouting", + "StaticRouting", + "TokenRouting", +) diff --git a/src/maxo/webhook/routing/base.py b/src/maxo/webhook/routing/base.py new file mode 100644 index 00000000..29660a92 --- /dev/null +++ b/src/maxo/webhook/routing/base.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod + +from yarl import URL + +from maxo import Bot +from maxo.webhook.adapters.base_adapter import BoundRequest + + +class BaseRouting(ABC): + """ + Abstract base class for webhook routing. + + Defines how webhook URLs are constructed and how keys (tokens) + are extracted from incoming requests. + """ + + def __init__(self, url: str) -> None: + self.url = URL(url) + self.base = self.url.origin() + self.path = self.url.path + + @abstractmethod + def webhook_point(self, bot: Bot) -> str: + """Get the webhook URL for the given bot.""" + raise NotImplementedError + + +class TokenRouting(BaseRouting, ABC): + """Routing by token parameter.""" + + def __init__(self, url: str, param: str = "bot_token") -> None: + super().__init__(url=url) + self.param = param + + @abstractmethod + def extract_token(self, bound_request: BoundRequest) -> str | None: + """Extract the bot token from the incoming request.""" + raise NotImplementedError diff --git a/src/maxo/webhook/routing/path.py b/src/maxo/webhook/routing/path.py new file mode 100644 index 00000000..b053d34d --- /dev/null +++ b/src/maxo/webhook/routing/path.py @@ -0,0 +1,31 @@ +from typing import Any + +from maxo import Bot +from maxo.webhook.adapters.base_adapter import BoundRequest +from maxo.webhook.routing.base import TokenRouting + + +class PathRouting(TokenRouting): + """ + Path-based routing for webhook requests. + + Extracts the bot token from a path parameter in the URL. + Example: https://example.com/webhook/{token} will extract the token from + the path segment. + """ + + def __init__(self, url: str, param: str = "bot_token") -> None: + super().__init__(url=url, param=param) + self.url_template = self.url.human_repr() + + if f"{{{self.param}}}" not in self.url_template: + raise ValueError( + f"Parameter '{self.param}' not found in URL template. " + f"Expected placeholder '{{{self.param}}}' in: {self.url_template}", + ) + + def webhook_point(self, bot: Bot) -> str: + return self.url_template.format_map({self.param: bot.token}) + + def extract_token(self, bound_request: BoundRequest[Any]) -> str | None: + return bound_request.path_params.get(self.param) diff --git a/src/maxo/webhook/routing/query.py b/src/maxo/webhook/routing/query.py new file mode 100644 index 00000000..20881cb8 --- /dev/null +++ b/src/maxo/webhook/routing/query.py @@ -0,0 +1,21 @@ +from typing import Any + +from maxo import Bot +from maxo.webhook.adapters.base_adapter import BoundRequest +from maxo.webhook.routing.base import TokenRouting + + +class QueryRouting(TokenRouting): + """ + Routing strategy based on the URL query parameter. + + Extracts the bot token from a query parameter in the URL. + Example: https://example.com/webhook?token=f9LHodD0 will extract the + token from the query string. + """ + + def webhook_point(self, bot: Bot) -> str: + return self.url.update_query({self.param: bot.token}).human_repr() + + def extract_token(self, bound_request: BoundRequest[Any]) -> str | None: + return bound_request.query_params.get(self.param) diff --git a/src/maxo/webhook/routing/static.py b/src/maxo/webhook/routing/static.py new file mode 100644 index 00000000..43e6075d --- /dev/null +++ b/src/maxo/webhook/routing/static.py @@ -0,0 +1,13 @@ +from maxo import Bot +from maxo.webhook.routing.base import BaseRouting + + +class StaticRouting(BaseRouting): + """Routing without token, static webhook URL.""" + + def __init__(self, url: str) -> None: + super().__init__(url=url) + self.url_template = self.url.human_repr() + + def webhook_point(self, bot: Bot) -> str: + return self.url_template diff --git a/src/maxo/webhook/security/__init__.py b/src/maxo/webhook/security/__init__.py new file mode 100644 index 00000000..f669dbcc --- /dev/null +++ b/src/maxo/webhook/security/__init__.py @@ -0,0 +1,10 @@ +from maxo.webhook.security.base_check import SecurityCheck +from maxo.webhook.security.secret_token import SecretToken, StaticSecretToken +from maxo.webhook.security.security import Security + +__all__ = ( + "SecretToken", + "Security", + "SecurityCheck", + "StaticSecretToken", +) diff --git a/src/maxo/webhook/security/base_check.py b/src/maxo/webhook/security/base_check.py new file mode 100644 index 00000000..3fc0d817 --- /dev/null +++ b/src/maxo/webhook/security/base_check.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod + +from maxo import Bot +from maxo.webhook.adapters.base_adapter import BoundRequest + + +class SecurityCheck(ABC): + """Abstract class for security check on webhook requests.""" + + @abstractmethod + async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + """ + Perform a security check. + + :return: True if the check passes, False otherwise. + """ + raise NotImplementedError diff --git a/src/maxo/webhook/security/secret_token.py b/src/maxo/webhook/security/secret_token.py new file mode 100644 index 00000000..08d8e41b --- /dev/null +++ b/src/maxo/webhook/security/secret_token.py @@ -0,0 +1,51 @@ +import re +from abc import abstractmethod +from hmac import compare_digest + +from maxo import Bot +from maxo.webhook.adapters.base_adapter import BoundRequest +from maxo.webhook.security.base_check import SecurityCheck + +SECRET_HEADER = "X-Max-Bot-Api-Secret" # noqa: S105 +SECRET_TOKEN_PATTERN = re.compile(r"^[A-Za-z0-9-]{5,256}$") + + +class SecretToken(SecurityCheck): + """Abstract base class for secret token verification in webhook requests.""" + + secret_header: str = SECRET_HEADER + + @abstractmethod + def secret_token(self, bot: Bot) -> str: + """Return the secret token for the given bot.""" + raise NotImplementedError + + +class StaticSecretToken(SecretToken): + """ + Static secret token implementation for webhook security. + + Token format: 5-256 characters, only `A-Z, a-z, 0-9, -` are allowed. + See: https://dev.max.ru/docs-api/methods/POST/subscriptions + """ + + def __init__(self, token: str) -> None: + if not SECRET_TOKEN_PATTERN.match(token): + raise ValueError( + "Invalid secret token format. Must be 1-256 characters, " + "only A-Z, a-z, 0-9, -", + ) + self._token = token + + async def verify( + self, + bot: Bot, + bound_request: BoundRequest, + ) -> bool: + incoming = bound_request.headers.get(self.secret_header) + if incoming is None: + return False + return compare_digest(incoming, self._token) + + def secret_token(self, bot: Bot) -> str: + return self._token diff --git a/src/maxo/webhook/security/security.py b/src/maxo/webhook/security/security.py new file mode 100644 index 00000000..de08d27a --- /dev/null +++ b/src/maxo/webhook/security/security.py @@ -0,0 +1,47 @@ +from maxo import Bot +from maxo.webhook.adapters.base_adapter import BoundRequest +from maxo.webhook.security.base_check import SecurityCheck +from maxo.webhook.security.secret_token import SecretToken + + +class Security: + """ + Security management for webhook requests. + + Provides methods to verify requests and manage secret tokens. + """ + + def __init__( + self, + *checks: SecurityCheck, + secret_token: SecretToken | None = None, + ) -> None: + self._secret_token = secret_token + self._checks: tuple[SecurityCheck, ...] = checks + + async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + """ + Verify the security of a webhook request. + + :return: True if the request passes security checks, False otherwise. + """ + if self._secret_token is not None: + ok = await self._secret_token.verify(bot=bot, bound_request=bound_request) + if not ok: + return False + + for checker in self._checks: + if not await checker.verify(bot=bot, bound_request=bound_request): + return False + + return True + + async def get_secret_token(self, *, bot: Bot) -> str | None: + """ + Get the secret token for the given bot, if configured. + + :return: The secret token as a string. + """ + if self._secret_token is None: + return None + return self._secret_token.secret_token(bot=bot) diff --git a/tests/maxo/routing/conftest.py b/tests/maxo/routing/conftest.py new file mode 100644 index 00000000..517fe2b2 --- /dev/null +++ b/tests/maxo/routing/conftest.py @@ -0,0 +1,32 @@ +from typing import Any + +import pytest + +from maxo import Bot, Ctx + + +class MockBotInfo: + def __init__(self, user_id: int) -> None: + self.user_id = user_id + + +class MockBotState: + def __init__(self, user_id: int) -> None: + self.info = MockBotInfo(user_id) + + +class MockBot: + def __init__(self, user_id: int = 1) -> None: + self.state = MockBotState(user_id) + + +@pytest.fixture +def bot() -> MockBot: + return MockBot() + + +@pytest.fixture +def ctx(update: Any, bot: Bot) -> Ctx: + ctx = Ctx({"update": update, "bot": bot}) + ctx["ctx"] = ctx + return ctx diff --git a/tests/maxo/routing/test_middleware.py b/tests/maxo/routing/test_middleware.py index 7b512f57..e846eea7 100644 --- a/tests/maxo/routing/test_middleware.py +++ b/tests/maxo/routing/test_middleware.py @@ -9,6 +9,7 @@ from maxo.routing.dispatcher import Dispatcher from maxo.routing.filters import AlwaysFalseFilter, AlwaysTrueFilter, BaseFilter from maxo.routing.interfaces import NextMiddleware +from maxo.routing.middlewares.fsm_context import FSMContextMiddleware from maxo.routing.routers.simple import Router from maxo.routing.sentinels import UNHANDLED from maxo.routing.signals import BeforeStartup @@ -16,28 +17,8 @@ from maxo.types import Message, MessageBody, Recipient, User -class MockBotInfo: - def __init__(self, user_id: int) -> None: - self.user_id = user_id - - -class MockBotState: - def __init__(self, user_id: int) -> None: - self.info = MockBotInfo(user_id) - - -class MockBot: - def __init__(self, user_id: int = 1) -> None: - self.state = MockBotState(user_id) - - @pytest.fixture -def bot() -> MockBot: - return MockBot() - - -@pytest.fixture -def message_created_update() -> MessageCreated: +def update() -> MessageCreated: return MessageCreated( message=Message( body=MessageBody(mid="test", seq=1), @@ -54,13 +35,6 @@ def message_created_update() -> MessageCreated: ) -@pytest.fixture -def context(message_created_update: MessageCreated, bot: MockBot) -> Ctx: - ctx = Ctx({"update": message_created_update, "bot": bot}) - ctx["ctx"] = ctx - return ctx - - async def handler(_: Any, ctx: Ctx) -> Any: ctx["execution_order"].append("handler") return "OK" @@ -81,7 +55,7 @@ async def middleware( @pytest.mark.asyncio -async def test_middleware_execution_order(context: Ctx) -> None: +async def test_middleware_execution_order(ctx: Ctx) -> None: dp = Dispatcher() dp.message_created.handler(handler) @@ -95,11 +69,11 @@ async def test_middleware_execution_order(context: Ctx) -> None: ) await dp.feed_signal(BeforeStartup()) - context["execution_order"] = [] - result = await dp.trigger(context) + ctx["execution_order"] = [] + result = await dp.trigger(ctx) assert result == "OK" - assert context["execution_order"] == [ + assert ctx["execution_order"] == [ "outer_1_pre", "outer_2_pre", "inner_1_pre", @@ -113,7 +87,7 @@ async def test_middleware_execution_order(context: Ctx) -> None: @pytest.mark.asyncio -async def test_middleware_stops_propagation(context: Ctx) -> None: +async def test_middleware_stops_propagation(ctx: Ctx) -> None: dp = Dispatcher() async def stopping_middleware( @@ -129,11 +103,11 @@ async def stopping_middleware( dp.message_created.handler(handler) await dp.feed_signal(BeforeStartup()) - context["execution_order"] = [] - result = await dp.trigger(context) + ctx["execution_order"] = [] + result = await dp.trigger(ctx) assert result == "STOPPED" - assert context["execution_order"] == [ + assert ctx["execution_order"] == [ "outer_pre", "stopping_middleware", "outer_post", @@ -141,7 +115,7 @@ async def stopping_middleware( @pytest.mark.asyncio -async def test_outer_middleware_runs_if_filter_fails(context: Ctx) -> None: +async def test_outer_middleware_runs_if_filter_fails(ctx: Ctx) -> None: dp = Dispatcher() class UpdateFilter(BaseFilter[MessageCreated]): @@ -154,11 +128,11 @@ async def __call__(self, update: MessageCreated, ctx: Ctx) -> bool: dp.message_created.middleware.outer.add(middleware_factory("outer")) await dp.feed_signal(BeforeStartup()) - context["execution_order"] = [] - result = await dp.trigger(context) + ctx["execution_order"] = [] + result = await dp.trigger(ctx) assert result is UNHANDLED - assert context["execution_order"] == [ + assert ctx["execution_order"] == [ "outer_pre", "filter", "outer_post", @@ -166,7 +140,7 @@ async def __call__(self, update: MessageCreated, ctx: Ctx) -> bool: @pytest.mark.asyncio -async def test_nested_router_middleware_execution(context: Ctx) -> None: +async def test_nested_router_middleware_execution(ctx: Ctx) -> None: dp = Dispatcher() root_router = Router("root") child_router = Router("child") @@ -181,11 +155,11 @@ async def test_nested_router_middleware_execution(context: Ctx) -> None: child_router.message_created.handler(handler) await dp.feed_signal(BeforeStartup()) - context["execution_order"] = [] - result = await dp.trigger(context) + ctx["execution_order"] = [] + result = await dp.trigger(ctx) assert result == "OK" - assert context["execution_order"] == [ + assert ctx["execution_order"] == [ "dp_pre", "root_pre", "child_pre", @@ -199,7 +173,7 @@ async def test_nested_router_middleware_execution(context: Ctx) -> None: @pytest.mark.asyncio -async def test_one_call_per_event_with_routers(context: Ctx) -> None: +async def test_one_call_per_event_with_routers(ctx: Ctx) -> None: async def outer_middleware( update: MessageCreated, ctx: Ctx, @@ -226,10 +200,30 @@ async def successful_handler(_: Any, ctx: Ctx) -> str: return "OK" await dp.feed_signal(BeforeStartup()) - context["calls"] = 0 - context["handler_calls"] = 0 - result = await dp.trigger(context) + ctx["calls"] = 0 + ctx["handler_calls"] = 0 + result = await dp.trigger(ctx) assert result == "OK" - assert context["calls"] == 1 - assert context["handler_calls"] == 1 + assert ctx["calls"] == 1 + assert ctx["handler_calls"] == 1 + + +@pytest.mark.asyncio +async def test_fsm_disabled() -> None: + dp = Dispatcher(disable_fsm=True) + + assert not any( + isinstance(middleware, FSMContextMiddleware) + for middleware in dp.update.middleware.outer.middlewares + ) + + +@pytest.mark.asyncio +async def test_fsm_enabled_by_default() -> None: + dp = Dispatcher() + + assert any( + isinstance(m, FSMContextMiddleware) + for m in dp.update.middleware.outer.middlewares + ) diff --git a/tests/maxo/routing/test_signals.py b/tests/maxo/routing/test_signals.py new file mode 100644 index 00000000..f113bea7 --- /dev/null +++ b/tests/maxo/routing/test_signals.py @@ -0,0 +1,174 @@ +from datetime import UTC, datetime + +import pytest + +from maxo import Router +from maxo.enums import ChatType +from maxo.routing.dispatcher import Dispatcher +from maxo.routing.middlewares.state import ( + EmptyMiddlewareManagerState, + StartedMiddlewareManagerState, +) +from maxo.routing.observers.state import EmptyObserverState, StartedObserverState +from maxo.routing.signals import ( + AfterShutdown, + AfterStartup, + BeforeShutdown, + BeforeStartup, + MaxoUpdate, +) +from maxo.routing.updates.message_created import MessageCreated +from maxo.types import Message, MessageBody, Recipient, User + + +@pytest.fixture +def update() -> MessageCreated: + return MessageCreated( + message=Message( + body=MessageBody(mid="test", seq=1), + recipient=Recipient(chat_type=ChatType.DIALOG, chat_id=1), + timestamp=datetime.now(UTC), + sender=User( + user_id=1, + first_name="Test", + is_bot=False, + last_activity_time=datetime.now(UTC), + ), + ), + timestamp=datetime.now(UTC), + ) + + +@pytest.mark.asyncio +async def test_dp_signals() -> None: + dp = Dispatcher() + order = [] + + @dp.before_startup() + async def before_startup() -> None: + order.append("before_startup") + + @dp.after_startup() + async def after_startup() -> None: + order.append("after_startup") + + @dp.before_shutdown() + async def before_shutdown() -> None: + order.append("before_shutdown") + + @dp.after_shutdown() + async def after_shutdown() -> None: + order.append("after_shutdown") + + await dp.feed_signal(BeforeStartup()) + await dp.feed_signal(AfterStartup()) + await dp.feed_signal(BeforeShutdown()) + await dp.feed_signal(AfterShutdown()) + + assert order == [ + "before_startup", + "after_startup", + "before_shutdown", + "after_shutdown", + ] + + +@pytest.mark.asyncio +async def test_included_router_signals() -> None: + dp = Dispatcher() + deep_router = Router() + deeper_router = Router() + deepest_router = Router() + + dp.include(deep_router) + deep_router.include(deeper_router) + deeper_router.include(deepest_router) + + order = [] + + @dp.before_startup() + @deep_router.before_startup() + @deeper_router.before_startup() + @deepest_router.before_startup() + async def before_startup() -> None: + order.append("before_startup") + + @dp.after_startup() + @deep_router.after_startup() + async def after_startup() -> None: + order.append("after_startup") + + @deep_router.before_shutdown() + @deeper_router.before_shutdown() + async def before_shutdown() -> None: + order.append("before_shutdown") + + @deeper_router.after_shutdown() + @deepest_router.after_shutdown() + async def after_shutdown() -> None: + order.append("after_shutdown") + + await dp.feed_signal(BeforeStartup()) + await dp.feed_signal(AfterStartup()) + await dp.feed_signal(BeforeShutdown()) + await dp.feed_signal(AfterShutdown()) + + assert order == [ + *(["before_startup"] * 4), + *(["after_startup"] * 2), + *(["before_shutdown"] * 2), + *(["after_shutdown"] * 2), + ] + + +@pytest.mark.asyncio +async def test_included_router_observers_state() -> None: + # ruff: noqa: E721 + dp = Dispatcher() + deep_router = Router() + deeper_router = Router() + + dp.include(deep_router) + deep_router.include(deeper_router) + + for router in (dp, deep_router, deeper_router): + for observer in router.observers.values(): + assert type(observer.state) == EmptyObserverState + assert type(observer.middleware.inner.state) == EmptyMiddlewareManagerState + + await dp.feed_signal(BeforeStartup()) + + for router in (dp, deep_router, deeper_router): + for observer in router.observers.values(): + assert type(observer.state) == StartedObserverState + assert ( + type(observer.middleware.inner.state) == StartedMiddlewareManagerState + ) + + await dp.feed_signal(AfterStartup()) + await dp.feed_signal(BeforeShutdown()) + + for router in (dp, deep_router, deeper_router): + for observer in router.observers.values(): + assert type(observer.state) == EmptyObserverState + assert type(observer.middleware.inner.state) == EmptyMiddlewareManagerState + + await dp.feed_signal(AfterShutdown()) + + +@pytest.mark.asyncio +async def test_dp_update_handler(update: MessageCreated, bot) -> None: + dp = Dispatcher() + + triggered = False + + @dp.update() + async def update_handler(_) -> None: + nonlocal triggered + triggered = True + + await dp.feed_signal(BeforeStartup()) + await dp.feed_signal(AfterStartup()) + + await dp.feed_max_update(MaxoUpdate(update=update), bot) + assert triggered diff --git a/tests/maxo/utils/test_formatting.py b/tests/maxo/utils/test_formatting.py new file mode 100644 index 00000000..42575fd8 --- /dev/null +++ b/tests/maxo/utils/test_formatting.py @@ -0,0 +1,269 @@ +import pytest + +from maxo.enums import MarkupElementType +from maxo.types.emphasized_markup import EmphasizedMarkup +from maxo.types.markup_element import MarkupElement +from maxo.types.strong_markup import StrongMarkup +from maxo.types.underline_markup import UnderlineMarkup +from maxo.types.user_mention_markup import UserMentionMarkup +from maxo.utils.formatting import ( + Bold, + Italic, + Link, + Mention, + Monospaced, + Strikethrough, + Text, + Underline, + _apply_entity, + as_key_value, + as_line, + as_list, + as_marked_list, + as_marked_section, + as_numbered_list, + as_numbered_section, + as_section, +) +from maxo.utils.text_decorations import html_decoration + + +class TestNode: + @pytest.mark.parametrize( + ("node", "result"), + [ + ( + Text("test"), + "test", + ), + ( + Bold("test"), + "test", + ), + ( + Italic("test"), + "test", + ), + ( + Underline("test"), + "test", + ), + ( + Strikethrough("test"), + "test", + ), + ( + Monospaced("test"), + "
test
", + ), + ( + Link("test", url="https://example.com"), + 'test', + ), + ( + Mention("test", user_id=42), + 'test', + ), + ], + ) + def test_render_plain_only(self, node: Text, result: str): + text, entities = node.render() + if node.type: + assert len(entities) == 1 + entity = entities[0] + assert entity.type == node.type + + content = html_decoration.unparse(text, entities) + assert content == result + + def test_render_text(self): + node = Text("Hello, ", "World", "!") + text, entities = node.render() + assert text == "Hello, World!" + assert not entities + + def test_render_nested(self): + node = Text( + Text("Hello, ", Bold("World"), "!"), + "\n", + Text(Bold("This ", Underline("is"), " test", Italic("!"))), + ) + text, entities = node.render() + assert text == "Hello, World!\nThis is test!" + assert entities == [ + StrongMarkup(type=MarkupElementType.STRONG, from_=7, length=5), + StrongMarkup(type=MarkupElementType.STRONG, from_=14, length=13), + UnderlineMarkup(type=MarkupElementType.UNDERLINE, from_=19, length=2), + EmphasizedMarkup(type=MarkupElementType.EMPHASIZED, from_=26, length=1), + ] + + def test_as_html(self): + node = Text("Hello, ", Bold("World"), "!") + assert node.as_html() == "Hello, World!" + + def test_as_markdown(self): + node = Text("Hello, ", Bold("World"), "!") + assert node.as_markdown() == r"Hello, **World**\!" + + def test_replace(self): + node0 = Text("test0", param0="test1") + node1 = node0.replace("test1", "test2", param1="test1") + assert node0._body != node1._body + assert node0._params != node1._params + assert "param1" not in node0._params + assert "param1" in node1._params + + def test_add(self): + node0 = Text("Hello") + node1 = Bold("World") + + node2 = node0 + Text(", ") + node1 + "!" + assert node0 != node2 + assert node1 != node2 + assert len(node0._body) == 1 + assert len(node1._body) == 1 + assert len(node2._body) == 3 + + text, _ = node2.render() + assert text == "Hello, World!" + + def test_getitem_position(self): + node = Text("Hello, ", Bold("World"), "!") + with pytest.raises(TypeError): + node[2] + + def test_getitem_empty_slice(self): + node = Text("Hello, ", Bold("World"), "!") + new_node = node[:] + assert new_node is not node + assert isinstance(new_node, Text) + assert new_node._body == node._body + + def test_getitem_slice_zero(self): + node = Text("Hello, ", Bold("World"), "!") + new_node = node[2:2] + assert node is not new_node + assert isinstance(new_node, Text) + assert not new_node._body + + def test_getitem_slice_simple(self): + node = Text("Hello, ", Bold("World"), "!") + new_node = node[2:10] + assert isinstance(new_node, Text) + text, entities = new_node.render() + assert text == "llo, Wor" + assert len(entities) == 1 + assert entities[0].type == MarkupElementType.STRONG + + def test_getitem_slice_inside_child(self): + node = Text("Hello, ", Bold("World"), "!") + new_node = node[8:10] + assert isinstance(new_node, Text) + text, entities = new_node.render() + assert text == "or" + assert len(entities) == 1 + assert entities[0].type == MarkupElementType.STRONG + + def test_getitem_slice_tail(self): + node = Text("Hello, ", Bold("World"), "!") + new_node = node[12:13] + assert isinstance(new_node, Text) + text, entities = new_node.render() + assert text == "!" + assert not entities + + def test_from_entities(self): + # Most of the cases covered by text_decorations module + + node = Strikethrough.from_entities( + text="test1 test2 test3 test4 test5 test6", + entities=[ + MarkupElement(type=MarkupElementType.STRONG, from_=6, length=23), + MarkupElement(type=MarkupElementType.UNDERLINE, from_=12, length=5), + MarkupElement(type=MarkupElementType.EMPHASIZED, from_=24, length=5), + ], + ) + assert len(node._body) == 3 + assert isinstance(node, Strikethrough) + rendered = node.as_html() + assert ( + rendered + == "test1 test2 test3 test4 test5 test6" + ) + + def test_pretty_string(self): + node = Strikethrough.from_entities( + text="X", + entities=[ + UserMentionMarkup( + type=MarkupElementType.USER_MENTION, + from_=0, + length=1, + user_id=42, + ), + ], + ) + assert ( + node.as_pretty_string(indent=True) + == r"""Strikethrough( + Mention( + 'X', + user_id=42, + user_link= + ) +)""" + ) + + +class TestUtils: + def test_apply_entity(self): + node = _apply_entity( + MarkupElement(type=MarkupElementType.STRONG, from_=0, length=4), + "test", + ) + assert isinstance(node, Bold) + assert node._body == ("test",) + + def test_as_line(self): + node = as_line("test", "test", "test") + assert isinstance(node, Text) + assert len(node._body) == 4 # 3 + '\\n' + + def test_line_with_sep(self): + node = as_line("test", "test", "test", sep=" ") + assert isinstance(node, Text) + assert len(node._body) == 6 # 3 + 2 * ' ' + '\\n' + + def test_as_line_single_element_with_sep(self): + node = as_line("test", sep=" ") + assert isinstance(node, Text) + assert len(node._body) == 2 # 1 + '\\n' + + def test_as_list(self): + node = as_list("test", "test", "test") + assert isinstance(node, Text) + assert len(node._body) == 5 # 3 + 2 * '\\n' between lines + + def test_as_marked_list(self): + node = as_marked_list("test 1", "test 2", "test 3") + assert node.as_html() == "- test 1\n- test 2\n- test 3" + + def test_as_numbered_list(self): + node = as_numbered_list("test 1", "test 2", "test 3", start=5) + assert node.as_html() == "5. test 1\n6. test 2\n7. test 3" + + def test_as_section(self): + node = as_section("title", "test 1", "test 2", "test 3") + assert node.as_html() == "title\ntest 1test 2test 3" + + def test_as_marked_section(self): + node = as_marked_section("Section", "test 1", "test 2", "test 3") + assert node.as_html() == "Section\n- test 1\n- test 2\n- test 3" + + def test_as_numbered_section(self): + node = as_numbered_section("Section", "test 1", "test 2", "test 3", start=5) + assert node.as_html() == "Section\n5. test 1\n6. test 2\n7. test 3" + + def test_as_key_value(self): + node = as_key_value("key", "test 1") + assert node.as_html() == "key: test 1" diff --git a/tests/maxo_webhook/__init__.py b/tests/maxo_webhook/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/maxo_webhook/conftest.py b/tests/maxo_webhook/conftest.py new file mode 100644 index 00000000..c2a465ee --- /dev/null +++ b/tests/maxo_webhook/conftest.py @@ -0,0 +1,15 @@ +from ipaddress import IPv4Address + +import pytest + +from maxo import Bot + + +@pytest.fixture +def bot(): + return Bot("42:TEST") + + +@pytest.fixture +def localhost_ip() -> IPv4Address: + return IPv4Address("127.0.0.1") diff --git a/tests/maxo_webhook/fixtures/__init__.py b/tests/maxo_webhook/fixtures/__init__.py new file mode 100644 index 00000000..e27863b8 --- /dev/null +++ b/tests/maxo_webhook/fixtures/__init__.py @@ -0,0 +1,11 @@ +from .fixtures_bound_request import DummyAdapter, DummyBoundRequest, DummyRequest +from .fixtures_checks import ConditionalCheck, FailingCheck, PassingCheck + +__all__ = ( + "ConditionalCheck", + "DummyAdapter", + "DummyBoundRequest", + "DummyRequest", + "FailingCheck", + "PassingCheck", +) diff --git a/tests/maxo_webhook/fixtures/fixtures_bound_request.py b/tests/maxo_webhook/fixtures/fixtures_bound_request.py new file mode 100644 index 00000000..4a6ae26c --- /dev/null +++ b/tests/maxo_webhook/fixtures/fixtures_bound_request.py @@ -0,0 +1,46 @@ +from typing import Any + +from maxo.webhook.adapters.base_adapter import BoundRequest, WebAdapter + + +class DummyAdapter(WebAdapter): + def bind(self, request): + raise NotImplementedError("DummyAdapter.bind is not implemented") + + def register(self, app, path, handler, on_startup=None, on_shutdown=None): + raise NotImplementedError("DummyAdapter.register is not implemented") + + def create_json_response(self, status, payload): + return status, payload + + +class DummyRequest: + def __init__(self, *, path_params=None, query_params=None, headers=None, ip=None): + self.path_params = path_params or {} + self.query_params = query_params or {} + self.headers = headers or {} + self.ip = ip + + +class DummyBoundRequest(BoundRequest[DummyRequest]): + def __init__(self, request: DummyRequest | None = None): + super().__init__(request or DummyRequest()) + + async def json(self) -> dict[str, Any]: + return {} + + @property + def client_ip(self) -> str | None: + return self.request.ip + + @property + def headers(self): + return self.request.headers + + @property + def query_params(self): + return self.request.query_params + + @property + def path_params(self): + return self.request.path_params diff --git a/tests/maxo_webhook/fixtures/fixtures_checks.py b/tests/maxo_webhook/fixtures/fixtures_checks.py new file mode 100644 index 00000000..ac989730 --- /dev/null +++ b/tests/maxo_webhook/fixtures/fixtures_checks.py @@ -0,0 +1,21 @@ +from maxo import Bot +from maxo.webhook.adapters.base_adapter import BoundRequest +from maxo.webhook.security.base_check import SecurityCheck + + +class PassingCheck(SecurityCheck): + async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + return True + + +class FailingCheck(SecurityCheck): + async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + return False + + +class ConditionalCheck(SecurityCheck): + def __init__(self, condition: bool): + self.condition = condition + + async def verify(self, bot: Bot, bound_request: BoundRequest) -> bool: + return self.condition diff --git a/tests/maxo_webhook/test_aiohttp_adapter.py b/tests/maxo_webhook/test_aiohttp_adapter.py new file mode 100644 index 00000000..654a3a2d --- /dev/null +++ b/tests/maxo_webhook/test_aiohttp_adapter.py @@ -0,0 +1,35 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from aiohttp import web + +from maxo.webhook.adapters.aiohttp.adapter import AiohttpBoundRequest, AiohttpWebAdapter + + +@pytest.fixture +def aiohttp_app(): + return web.Application() + + +@pytest.fixture +def mocked_engine(): + engine = MagicMock() + engine.feed_request = AsyncMock() + return engine + + +@pytest.mark.skip("Разобраться с ошибкой") +@pytest.mark.asyncio +async def test_adapter(aiohttp_client, aiohttp_app): + engine = AsyncMock(return_value=web.Response(status=200)) + + adapter = AiohttpWebAdapter() + adapter.register(aiohttp_app, "/webhook", engine) + + client = await aiohttp_client(aiohttp_app) + await client.post("/webhook", json={"foo": "bar"}) + + engine.assert_awaited_once() + request = engine.call_args.args[0] + assert isinstance(request, AiohttpBoundRequest) + assert await request.json() == {"foo": "bar"} diff --git a/tests/maxo_webhook/test_engines.py b/tests/maxo_webhook/test_engines.py new file mode 100644 index 00000000..ea13911a --- /dev/null +++ b/tests/maxo_webhook/test_engines.py @@ -0,0 +1,222 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from maxo.bot.bot import Bot +from maxo.routing.dispatcher import Dispatcher +from maxo.routing.signals import ( + AfterShutdown, + AfterStartup, + BeforeShutdown, + BeforeStartup, +) +from maxo.webhook.engines.simple import SimpleEngine +from maxo.webhook.engines.token import TokenEngine + + +class TestSimpleEngine: + @pytest.fixture + def dispatcher(self) -> Dispatcher: + return Dispatcher() + + @pytest.fixture + def bot(self) -> MagicMock: + return MagicMock(spec=Bot) + + @pytest.fixture + def web_adapter(self) -> MagicMock: + return MagicMock() + + @pytest.fixture + def routing(self) -> MagicMock: + return MagicMock() + + @pytest.fixture + def security(self) -> MagicMock: + security = MagicMock() + security.get_secret_token = AsyncMock(return_value="secret") + return security + + @pytest.fixture + def engine( + self, + dispatcher: Dispatcher, + bot: MagicMock, + web_adapter: MagicMock, + routing: MagicMock, + security: MagicMock, + ) -> SimpleEngine: + return SimpleEngine( + dispatcher, + bot, + web_adapter=web_adapter, + routing=routing, + security=security, + ) + + def test_get_bot_from_request(self, engine: SimpleEngine, bot: MagicMock): + assert engine._get_bot_from_request(MagicMock()) is bot + + @pytest.mark.asyncio + async def test_set_webhook( + self, + engine: SimpleEngine, + bot: MagicMock, + routing: MagicMock, + ): + routing.webhook_point.return_value = "https://example.com/webhook" + bot.subscribe = AsyncMock() + + await engine.set_webhook( + update_types=["message"], + ) + + bot.subscribe.assert_called_once() + call_kwargs = bot.subscribe.call_args.kwargs + assert call_kwargs["url"] == "https://example.com/webhook" + assert call_kwargs["update_types"] == ["message"] + + @pytest.mark.asyncio + async def test_on_startup(self, engine: SimpleEngine, dispatcher: Dispatcher): + dispatcher.feed_signal = AsyncMock() + await engine.on_startup(app=MagicMock()) + assert dispatcher.feed_signal.await_count == 2 + assert isinstance( + dispatcher.feed_signal.await_args_list[0].args[0], + BeforeStartup, + ) + assert isinstance( + dispatcher.feed_signal.await_args_list[1].args[0], + AfterStartup, + ) + + @pytest.mark.asyncio + async def test_on_shutdown( + self, + engine: SimpleEngine, + dispatcher: Dispatcher, + bot: MagicMock, + ): + dispatcher.feed_signal = AsyncMock() + bot.close = AsyncMock() + await engine.on_shutdown(app=MagicMock()) + assert dispatcher.feed_signal.await_count == 2 + bot.close.assert_awaited_once() + assert isinstance( + dispatcher.feed_signal.await_args_list[0].args[0], + BeforeShutdown, + ) + assert isinstance( + dispatcher.feed_signal.await_args_list[1].args[0], + AfterShutdown, + ) + + +class TestTokenEngine: + @pytest.fixture + def dispatcher(self) -> Dispatcher: + return Dispatcher() + + @pytest.fixture + def web_adapter(self) -> MagicMock: + return MagicMock() + + @pytest.fixture + def routing(self) -> MagicMock: + routing = MagicMock() + routing.extract_token.return_value = "42:TEST" + return routing + + @pytest.fixture + def security(self) -> MagicMock: + security = MagicMock() + security.get_secret_token = AsyncMock(return_value="secret") + return security + + @pytest.fixture + def engine( + self, + dispatcher: Dispatcher, + web_adapter: MagicMock, + routing: MagicMock, + security: MagicMock, + ) -> TokenEngine: + return TokenEngine( + dispatcher, + web_adapter=web_adapter, + routing=routing, + security=security, + ) + + def test_get_bot(self, engine: TokenEngine): + with patch("maxo.webhook.engines.token.Bot") as bot_mock: + bot_mock.side_effect = [MagicMock(spec=Bot), MagicMock(spec=Bot)] + + bot1 = engine.get_bot("42:TEST") + bot2 = engine.get_bot("42:TEST") + bot3 = engine.get_bot("43:TEST") + + assert bot1 is bot2 + assert bot1 is not bot3 + bot_mock.assert_any_call(token="42:TEST", defaults=None) # noqa: S106 + bot_mock.assert_any_call(token="43:TEST", defaults=None) # noqa: S106 + assert bot_mock.call_count == 2 + + def test_get_bot_from_request(self, engine: TokenEngine): + with patch.object(engine, "get_bot") as get_bot_mock: + engine._get_bot_from_request(MagicMock()) + get_bot_mock.assert_called_once_with("42:TEST") + + @pytest.mark.asyncio + async def test_set_webhook(self, engine: TokenEngine, routing: MagicMock): + routing.webhook_point.return_value = "https://example.com/webhook/42:TEST" + + with patch.object(engine, "get_bot") as get_bot_mock: + bot_mock = get_bot_mock.return_value + bot_mock.subscribe = AsyncMock() + + await engine.set_webhook( + "42:TEST", + update_types=["message"], + ) + + get_bot_mock.assert_called_once_with("42:TEST") + bot_mock.subscribe.assert_called_once() + call_kwargs = bot_mock.subscribe.call_args.kwargs + assert call_kwargs["url"] == "https://example.com/webhook/42:TEST" + assert call_kwargs["update_types"] == ["message"] + + @pytest.mark.asyncio + async def test_on_startup(self, engine: TokenEngine, dispatcher: Dispatcher): + dispatcher.feed_signal = AsyncMock() + await engine.on_startup(app=MagicMock()) + assert dispatcher.feed_signal.await_count == 2 + assert isinstance( + dispatcher.feed_signal.await_args_list[0].args[0], + BeforeStartup, + ) + assert isinstance( + dispatcher.feed_signal.await_args_list[1].args[0], + AfterStartup, + ) + + @pytest.mark.asyncio + async def test_on_shutdown(self, engine: TokenEngine, dispatcher: Dispatcher): + with patch.object(engine, "get_bot") as get_bot_mock: + bot_mock = get_bot_mock.return_value + bot_mock.close = AsyncMock() + engine._bots["42:TEST"] = bot_mock + + dispatcher.feed_signal = AsyncMock() + await engine.on_shutdown(app=MagicMock()) + assert dispatcher.feed_signal.await_count == 2 + bot_mock.close.assert_awaited_once() + assert not engine._bots + assert isinstance( + dispatcher.feed_signal.await_args_list[0].args[0], + BeforeShutdown, + ) + assert isinstance( + dispatcher.feed_signal.await_args_list[1].args[0], + AfterShutdown, + ) diff --git a/tests/maxo_webhook/test_fastapi_adapter.py b/tests/maxo_webhook/test_fastapi_adapter.py new file mode 100644 index 00000000..559ba1be --- /dev/null +++ b/tests/maxo_webhook/test_fastapi_adapter.py @@ -0,0 +1,28 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from maxo.webhook.adapters.base_adapter import BoundRequest +from maxo.webhook.adapters.fastapi.adapter import FastApiWebAdapter + + +@pytest.mark.asyncio +async def test_adapter(): + engine = MagicMock() + engine.feed_request = AsyncMock() + + async def handler(request: BoundRequest) -> None: + await engine.feed_request(request) + + app = FastAPI() + adapter = FastApiWebAdapter() + adapter.register(app, "/webhook", handler) + + client = TestClient(app) + client.post("/webhook", json={"foo": "bar"}) + + engine.feed_request.assert_awaited_once() + request = engine.feed_request.call_args.args[0] + assert await request.json() == {"foo": "bar"} diff --git a/tests/maxo_webhook/test_routing.py b/tests/maxo_webhook/test_routing.py new file mode 100644 index 00000000..43213030 --- /dev/null +++ b/tests/maxo_webhook/test_routing.py @@ -0,0 +1,170 @@ +import pytest +from yarl import URL + +from maxo import Bot +from maxo.webhook.routing.path import PathRouting +from maxo.webhook.routing.query import QueryRouting +from maxo.webhook.routing.static import StaticRouting + +from .fixtures import DummyBoundRequest, DummyRequest + + +@pytest.mark.parametrize( + "url", + [ + "https://example.com/webhook", + "https://example.com/webhook/", + "https://example.com/webhook/any/path", + "https://example.com/webhook?foo=bar", + ], +) +def test_static_routing(url, bot): + routing = StaticRouting(url=url) + assert routing.webhook_point(bot) == url + + +@pytest.mark.parametrize( + ("url", "param", "token", "path_params", "expected_url", "expected_token"), + [ + ( + "https://example.com/webhook/{token}", + "token", + "42:TEST", + {"token": "42:TEST"}, + "https://example.com/webhook/42:TEST", + "42:TEST", + ), + ( + "https://example.com/webhook/{token}", + "token", + "42:TEST", + {}, + "https://example.com/webhook/42:TEST", + None, + ), + ( + "https://example.com/webhook/{mytoken}", + "mytoken", + "42:TEST", + {"mytoken": "42:TEST"}, + "https://example.com/webhook/42:TEST", + "42:TEST", + ), + ( + "https://example.com/webhook/{mytoken}", + "mytoken", + "42:TEST", + {}, + "https://example.com/webhook/42:TEST", + None, + ), + ], + ids=[ + "standard-param-present", + "standard-param-missing", + "custom-param-present", + "custom-param-missing", + ], +) +def test_path_routing(url, param, token, path_params, expected_url, expected_token): + routing = PathRouting(url=url, param=param) + assert routing.webhook_point(Bot(token)) == expected_url + req = DummyBoundRequest(DummyRequest(path_params=path_params)) + assert routing.extract_token(req) == expected_token + + +@pytest.mark.parametrize( + ("url", "param", "token", "query_params", "expected_url", "expected_token"), + [ + ( + "https://example.com/webhook", + "token", + "42:TEST", + {"token": "42:TEST"}, + "https://example.com/webhook?token=42:TEST", + "42:TEST", + ), + ( + "https://example.com/webhook", + "token", + "42:TEST", + {}, + "https://example.com/webhook?token=42:TEST", + None, + ), + ( + "https://example.com/webhook", + "mytoken", + "42:TEST", + {"mytoken": "42:TEST"}, + "https://example.com/webhook?mytoken=42:TEST", + "42:TEST", + ), + ( + "https://example.com/webhook", + "mytoken", + "42:TEST", + {}, + "https://example.com/webhook?mytoken=42:TEST", + None, + ), + ( + "https://example.com/webhook?other=value", + "token", + "42:TEST", + {"token": "42:TEST", "other": "value"}, + "https://example.com/webhook?other=value&token=42:TEST", + "42:TEST", + ), + ( + "https://example.com/webhook?foo=bar&baz=qux", + "token", + "42:TEST", + {"token": "42:TEST", "foo": "bar", "baz": "qux"}, + "https://example.com/webhook?foo=bar&baz=qux&token=42:TEST", + "42:TEST", + ), + ( + "https://example.com/webhook?token=old_value&other=value", + "token", + "42:TEST", + {"token": "42:TEST", "other": "value"}, + "https://example.com/webhook?token=42:TEST&other=value", + "42:TEST", + ), + ( + "https://example.com/webhook?api_key=secret&debug=true", + "bot_token", + "123:ABC", + {"bot_token": "123:ABC", "api_key": "secret", "debug": "true"}, + "https://example.com/webhook?api_key=secret&debug=true&bot_token=123:ABC", + "123:ABC", + ), + ], + ids=[ + "standard-param-present", + "standard-param-missing", + "custom-param-present", + "custom-param-missing", + "preserve-existing-params", + "preserve-multiple-params", + "override-token-param", + "complex-params", + ], +) +def test_query_routing(url, param, token, query_params, expected_url, expected_token): + routing = QueryRouting(url=url, param=param) + webhook_url = routing.webhook_point(Bot(token)) + + # Parse both URLs to compare query params (order may differ) + expected = URL(expected_url) + actual = URL(webhook_url) + + # Check that all expected query params are present + assert dict(actual.query) == dict( + expected.query, + ), f"Query parameters mismatch. Expected: {dict(expected.query)}, Got: {dict(actual.query)}" + + # Check token extraction + req = DummyBoundRequest(DummyRequest(query_params=query_params)) + assert routing.extract_token(req) == expected_token diff --git a/tests/maxo_webhook/test_secret_token.py b/tests/maxo_webhook/test_secret_token.py new file mode 100644 index 00000000..8119ab68 --- /dev/null +++ b/tests/maxo_webhook/test_secret_token.py @@ -0,0 +1,37 @@ +import pytest + +from maxo.webhook.security import Security, StaticSecretToken +from maxo.webhook.security.secret_token import SECRET_HEADER + +from .fixtures import DummyBoundRequest, DummyRequest + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("secret_token", "request_token", "expected"), + [ + ("my-secret", "my-secret", True), + ("my-secret", "wrong-secret", False), + ("my-secret", None, False), + ], + ids=["match", "mismatch", "none"], +) +async def test_security_secret_token(secret_token, request_token, expected, bot): + sec = Security(secret_token=StaticSecretToken(secret_token)) + headers = {SECRET_HEADER: request_token} if request_token is not None else {} + req = DummyBoundRequest(DummyRequest(headers=headers)) + assert await sec.verify(bot, req) is expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("secret_token", "expected"), + [ + (StaticSecretToken("test-secret"), "test-secret"), + (None, None), + ], + ids=["with-secret", "without-secret"], +) +async def test_security_get_secret_token(secret_token, expected, bot): + sec = Security(secret_token=secret_token) + assert await sec.get_secret_token(bot=bot) == expected diff --git a/tests/maxo_webhook/test_security.py b/tests/maxo_webhook/test_security.py new file mode 100644 index 00000000..9a695250 --- /dev/null +++ b/tests/maxo_webhook/test_security.py @@ -0,0 +1,88 @@ +import pytest + +from maxo.webhook.security.secret_token import ( + SECRET_HEADER, + StaticSecretToken, +) +from maxo.webhook.security.security import Security + +from .fixtures import DummyBoundRequest, DummyRequest, FailingCheck, PassingCheck + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("checks", "expected"), + [ + # No checks - should pass + ([], True), + # Single check + ([PassingCheck()], True), + ([FailingCheck()], False), + # Two checks + ([PassingCheck(), PassingCheck()], True), + ([PassingCheck(), FailingCheck()], False), + ([FailingCheck(), PassingCheck()], False), + ([FailingCheck(), FailingCheck()], False), + # Three+ checks + ([PassingCheck(), PassingCheck(), PassingCheck()], True), + ([FailingCheck(), PassingCheck(), PassingCheck()], False), + ([PassingCheck(), PassingCheck(), FailingCheck()], False), + ], + ids=[ + "no-checks", + "single-passing", + "single-failing", + "two-passing", + "passing-then-failing", + "failing-then-passing", + "two-failing", + "three-passing", + "failing-first-passing", + "failing-last-passing", + ], +) +async def test_security_checks(checks, expected, bot): + sec = Security(*checks) + req = DummyBoundRequest() + assert await sec.verify(bot, req) is expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("checks", "secret_token", "request_token", "expected"), + [ + # Both present and working + ([PassingCheck()], StaticSecretToken("secret"), "secret", True), + ([FailingCheck()], StaticSecretToken("secret"), "secret", False), + ([PassingCheck()], StaticSecretToken("secret"), "wrong", False), + # No checks + ([], StaticSecretToken("secret"), "secret", True), + ([], StaticSecretToken("secret"), "wrong", False), + # No secret token + ([PassingCheck()], None, None, True), + ([FailingCheck()], None, None, False), + # No checks and no secret token + ([], None, None, True), + ], + ids=[ + "both-pass", + "check-fails", + "secret-fails", + "no-checks-secret-pass", + "no-checks-secret-fail", + "no-secret-check-pass", + "no-secret-check-fail", + "no-checks-no-secret", + ], +) +async def test_security_checks_and_secret_token( + checks, + secret_token, + request_token, + expected, + bot, +): + sec = Security(*checks, secret_token=secret_token) + headers = {SECRET_HEADER: request_token} if request_token is not None else {} + req = DummyBoundRequest(DummyRequest(headers=headers)) + assert await sec.verify(bot, req) is expected