diff --git a/packages/apps/src/microsoft_teams/apps/http_stream.py b/packages/apps/src/microsoft_teams/apps/http_stream.py index 645025c8..353aff3f 100644 --- a/packages/apps/src/microsoft_teams/apps/http_stream.py +++ b/packages/apps/src/microsoft_teams/apps/http_stream.py @@ -8,6 +8,7 @@ from collections import deque from typing import Awaitable, Callable, List, Optional, Union +from httpx import HTTPStatusError from microsoft_teams.api import ( ApiClient, Attachment, @@ -62,6 +63,7 @@ def __init__(self, client: ApiClient, ref: ConversationReference): self._total_wait_timeout: float = 30.0 self._state_changed = asyncio.Event() + self._canceled = False self._reset_state() def _reset_state(self) -> None: @@ -74,6 +76,14 @@ def _reset_state(self) -> None: self._entities: List[Entity] = [] self._queue: deque[Union[MessageActivityInput, TypingActivityInput, str]] = deque() + @property + def canceled(self) -> bool: + """ + Whether the stream has been canceled. + For example when the user pressed the Stop button or the 2-minute timeout has exceeded. + """ + return self._canceled + @property def closed(self) -> bool: """Whether the final stream message has been sent.""" @@ -103,6 +113,9 @@ def emit(self, activity: Union[MessageActivityInput, TypingActivityInput, str]) activity: The activity to emit. """ + if self._canceled: + return + if isinstance(activity, str): activity = MessageActivityInput(text=activity, type="message") self._queue.append(activity) @@ -265,12 +278,25 @@ async def _send(self, to_send: Union[TypingActivityInput, MessageActivityInput]) Args: activity: The activity to send. """ + if self._canceled: + logger.warning("Teams channel stopped the stream.") + raise asyncio.CancelledError("Teams channel stopped the stream.") + to_send.from_ = self._ref.bot to_send.conversation = self._ref.conversation - if to_send.id and not any(e.type == "streaminfo" for e in (to_send.entities or [])): - res = await self._client.conversations.activities(self._ref.conversation.id).update(to_send.id, to_send) - else: - res = await self._client.conversations.activities(self._ref.conversation.id).create(to_send) - - return SentActivity.merge(to_send, res) + try: + if to_send.id and not any(e.type == "streaminfo" for e in (to_send.entities or [])): + res = await self._client.conversations.activities(self._ref.conversation.id).update(to_send.id, to_send) + else: + res = await self._client.conversations.activities(self._ref.conversation.id).create(to_send) + + return SentActivity.merge(to_send, res) + except HTTPStatusError as e: + if e.response.status_code == 403: + self._canceled = True + logger.warning("Teams channel stopped the stream.") + raise asyncio.CancelledError("Teams channel stopped the stream.") from e + raise e + except Exception as e: + raise e diff --git a/packages/apps/src/microsoft_teams/apps/plugins/streamer.py b/packages/apps/src/microsoft_teams/apps/plugins/streamer.py index c73b6fa9..e8e6cf7b 100644 --- a/packages/apps/src/microsoft_teams/apps/plugins/streamer.py +++ b/packages/apps/src/microsoft_teams/apps/plugins/streamer.py @@ -13,6 +13,14 @@ class StreamerProtocol(Protocol): """Component that can send streamed chunks of an activity.""" + @property + def canceled(self) -> bool: + """ + Whether the stream has been canceled. + For example when the user pressed the Stop button or the 2-minute timeout has exceeded. + """ + ... + @property def closed(self) -> bool: """Whether the final stream message has been sent.""" diff --git a/packages/apps/tests/test_http_stream.py b/packages/apps/tests/test_http_stream.py index f6c9905a..9ad8e3c9 100644 --- a/packages/apps/tests/test_http_stream.py +++ b/packages/apps/tests/test_http_stream.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch import pytest +from httpx import HTTPStatusError, Request, Response from microsoft_teams.api import ( Account, ApiClient, @@ -219,3 +220,102 @@ async def emit_task(): await self._run_scheduled_flushes(scheduled) assert max_concurrent_entries == 1 + + @pytest.mark.asyncio + async def test_stream_canceled_on_403(self, mock_api_client, conversation_reference, patch_loop_call_later): + loop = asyncio.get_running_loop() + patcher, scheduled = patch_loop_call_later(loop) + with patcher: + + async def mock_send_403(activity): + raise HTTPStatusError( + "Forbidden", + request=Request("POST", "https://example.com"), + response=Response(403), + ) + + mock_api_client.conversations.activities().create = mock_send_403 + stream = HttpStream(mock_api_client, conversation_reference) + + stream.emit("Test message") + await asyncio.sleep(0) + await self._run_scheduled_flushes(scheduled) + + assert stream.canceled is True + + @pytest.mark.asyncio + async def test_emit_blocked_after_cancel(self, mock_api_client, conversation_reference, patch_loop_call_later): + loop = asyncio.get_running_loop() + patcher, scheduled = patch_loop_call_later(loop) + with patcher: + + async def mock_send_403(activity): + raise HTTPStatusError( + "Forbidden", + request=Request("POST", "https://example.com"), + response=Response(403), + ) + + mock_api_client.conversations.activities().create = mock_send_403 + stream = HttpStream(mock_api_client, conversation_reference) + + stream.emit("First message") + await asyncio.sleep(0) + await self._run_scheduled_flushes(scheduled) + + assert stream.canceled is True + + # Emit after cancel should be a no-op + stream.emit("Should be ignored") + assert stream.count == 0 + + @pytest.mark.asyncio + async def test_send_blocked_after_cancel(self, mock_api_client, conversation_reference): + stream = HttpStream(mock_api_client, conversation_reference) + stream._canceled = True + + with pytest.raises(asyncio.CancelledError, match="Teams channel stopped the stream."): + await stream._send(TypingActivityInput(text="test")) + + @pytest.mark.asyncio + async def test_stream_canceled_after_successful_message( + self, mock_api_client, conversation_reference, patch_loop_call_later + ): + call_count = 0 + loop = asyncio.get_running_loop() + patcher, scheduled = patch_loop_call_later(loop) + with patcher: + + async def mock_send_then_403(activity): + nonlocal call_count + call_count += 1 + if call_count == 1: + return SentActivity(id="activity-1", activity_params=activity) + raise HTTPStatusError( + "Forbidden", + request=Request("POST", "https://example.com"), + response=Response(403), + ) + + mock_api_client.conversations.activities().create = mock_send_then_403 + stream = HttpStream(mock_api_client, conversation_reference) + + # First emit succeeds + stream.emit("First message") + await asyncio.sleep(0) + await self._run_scheduled_flushes(scheduled) + + assert stream.canceled is False + assert call_count == 1 + + # Second emit triggers 403 + stream.emit("Second message") + await asyncio.sleep(0) + await self._run_scheduled_flushes(scheduled) + + assert stream.canceled is True + assert call_count == 2 + + # Further emits are blocked + stream.emit("Should be ignored") + assert stream.count == 0