Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions packages/apps/src/microsoft_teams/apps/http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import deque
from typing import Awaitable, Callable, List, Optional, Union

from httpx import HTTPStatusError
from microsoft_teams.api import (
ApiClient,
Attachment,
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(self, client: ApiClient, ref: ConversationReference):
self._total_wait_timeout: float = 30.0
self._state_changed = asyncio.Event()

self._canceled = False
self._reset_state()

def _reset_state(self) -> None:
Expand All @@ -74,6 +76,14 @@ def _reset_state(self) -> None:
self._entities: List[Entity] = []
self._queue: deque[Union[MessageActivityInput, TypingActivityInput, str]] = deque()

@property
def canceled(self) -> bool:
"""
Whether the stream has been canceled.
For example when the user pressed the Stop button or the 2-minute timeout has exceeded.
"""
return self._canceled

@property
def closed(self) -> bool:
"""Whether the final stream message has been sent."""
Expand Down Expand Up @@ -103,6 +113,9 @@ def emit(self, activity: Union[MessageActivityInput, TypingActivityInput, str])
activity: The activity to emit.
"""

if self._canceled:
return

if isinstance(activity, str):
activity = MessageActivityInput(text=activity, type="message")
self._queue.append(activity)
Expand Down Expand Up @@ -265,12 +278,25 @@ async def _send(self, to_send: Union[TypingActivityInput, MessageActivityInput])
Args:
activity: The activity to send.
"""
if self._canceled:
logger.warning("Teams channel stopped the stream.")
raise asyncio.CancelledError("Teams channel stopped the stream.")

to_send.from_ = self._ref.bot
to_send.conversation = self._ref.conversation

if to_send.id and not any(e.type == "streaminfo" for e in (to_send.entities or [])):
res = await self._client.conversations.activities(self._ref.conversation.id).update(to_send.id, to_send)
else:
res = await self._client.conversations.activities(self._ref.conversation.id).create(to_send)

return SentActivity.merge(to_send, res)
try:
if to_send.id and not any(e.type == "streaminfo" for e in (to_send.entities or [])):
res = await self._client.conversations.activities(self._ref.conversation.id).update(to_send.id, to_send)
else:
res = await self._client.conversations.activities(self._ref.conversation.id).create(to_send)

return SentActivity.merge(to_send, res)
except HTTPStatusError as e:
if e.response.status_code == 403:
self._canceled = True
logger.warning("Teams channel stopped the stream.")
raise asyncio.CancelledError("Teams channel stopped the stream.") from e
raise e
except Exception as e:
raise e
8 changes: 8 additions & 0 deletions packages/apps/src/microsoft_teams/apps/plugins/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
class StreamerProtocol(Protocol):
"""Component that can send streamed chunks of an activity."""

@property
def canceled(self) -> bool:
"""
Whether the stream has been canceled.
For example when the user pressed the Stop button or the 2-minute timeout has exceeded.
"""
...

@property
def closed(self) -> bool:
"""Whether the final stream message has been sent."""
Expand Down
100 changes: 100 additions & 0 deletions packages/apps/tests/test_http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unittest.mock import MagicMock, patch

import pytest
from httpx import HTTPStatusError, Request, Response
from microsoft_teams.api import (
Account,
ApiClient,
Expand Down Expand Up @@ -219,3 +220,102 @@ async def emit_task():
await self._run_scheduled_flushes(scheduled)

assert max_concurrent_entries == 1

@pytest.mark.asyncio
async def test_stream_canceled_on_403(self, mock_api_client, conversation_reference, patch_loop_call_later):
loop = asyncio.get_running_loop()
patcher, scheduled = patch_loop_call_later(loop)
with patcher:

async def mock_send_403(activity):
raise HTTPStatusError(
"Forbidden",
request=Request("POST", "https://example.com"),
response=Response(403),
)

mock_api_client.conversations.activities().create = mock_send_403
stream = HttpStream(mock_api_client, conversation_reference)

stream.emit("Test message")
await asyncio.sleep(0)
await self._run_scheduled_flushes(scheduled)

assert stream.canceled is True

@pytest.mark.asyncio
async def test_emit_blocked_after_cancel(self, mock_api_client, conversation_reference, patch_loop_call_later):
loop = asyncio.get_running_loop()
patcher, scheduled = patch_loop_call_later(loop)
with patcher:

async def mock_send_403(activity):
raise HTTPStatusError(
"Forbidden",
request=Request("POST", "https://example.com"),
response=Response(403),
)

mock_api_client.conversations.activities().create = mock_send_403
stream = HttpStream(mock_api_client, conversation_reference)

stream.emit("First message")
await asyncio.sleep(0)
await self._run_scheduled_flushes(scheduled)

assert stream.canceled is True

# Emit after cancel should be a no-op
stream.emit("Should be ignored")
assert stream.count == 0

@pytest.mark.asyncio
async def test_send_blocked_after_cancel(self, mock_api_client, conversation_reference):
stream = HttpStream(mock_api_client, conversation_reference)
stream._canceled = True

with pytest.raises(asyncio.CancelledError, match="Teams channel stopped the stream."):
await stream._send(TypingActivityInput(text="test"))

@pytest.mark.asyncio
async def test_stream_canceled_after_successful_message(
self, mock_api_client, conversation_reference, patch_loop_call_later
):
call_count = 0
loop = asyncio.get_running_loop()
patcher, scheduled = patch_loop_call_later(loop)
with patcher:

async def mock_send_then_403(activity):
nonlocal call_count
call_count += 1
if call_count == 1:
return SentActivity(id="activity-1", activity_params=activity)
raise HTTPStatusError(
"Forbidden",
request=Request("POST", "https://example.com"),
response=Response(403),
)

mock_api_client.conversations.activities().create = mock_send_then_403
stream = HttpStream(mock_api_client, conversation_reference)

# First emit succeeds
stream.emit("First message")
await asyncio.sleep(0)
await self._run_scheduled_flushes(scheduled)

assert stream.canceled is False
assert call_count == 1

# Second emit triggers 403
stream.emit("Second message")
await asyncio.sleep(0)
await self._run_scheduled_flushes(scheduled)

assert stream.canceled is True
assert call_count == 2

# Further emits are blocked
stream.emit("Should be ignored")
assert stream.count == 0
Loading