diff --git a/slack_bolt/context/say_stream/async_say_stream.py b/slack_bolt/context/say_stream/async_say_stream.py index af776891b..df9b362e2 100644 --- a/slack_bolt/context/say_stream/async_say_stream.py +++ b/slack_bolt/context/say_stream/async_say_stream.py @@ -34,6 +34,9 @@ async def __call__( recipient_team_id: Optional[str] = None, recipient_user_id: Optional[str] = None, thread_ts: Optional[str] = None, + icon_emoji: Optional[str] = None, + icon_url: Optional[str] = None, + username: Optional[str] = None, **kwargs, ) -> AsyncChatStream: """Starts a new chat stream with context.""" @@ -51,6 +54,9 @@ async def __call__( recipient_team_id=recipient_team_id or self.recipient_team_id, recipient_user_id=recipient_user_id or self.recipient_user_id, thread_ts=thread_ts, + icon_emoji=icon_emoji, + icon_url=icon_url, + username=username, **kwargs, ) return await self.client.chat_stream( @@ -58,5 +64,8 @@ async def __call__( recipient_team_id=recipient_team_id or self.recipient_team_id, recipient_user_id=recipient_user_id or self.recipient_user_id, thread_ts=thread_ts, + icon_emoji=icon_emoji, + icon_url=icon_url, + username=username, **kwargs, ) diff --git a/slack_bolt/context/say_stream/say_stream.py b/slack_bolt/context/say_stream/say_stream.py index b6a5ca797..15bdcc110 100644 --- a/slack_bolt/context/say_stream/say_stream.py +++ b/slack_bolt/context/say_stream/say_stream.py @@ -34,6 +34,9 @@ def __call__( recipient_team_id: Optional[str] = None, recipient_user_id: Optional[str] = None, thread_ts: Optional[str] = None, + icon_emoji: Optional[str] = None, + icon_url: Optional[str] = None, + username: Optional[str] = None, **kwargs, ) -> ChatStream: """Starts a new chat stream with context.""" @@ -51,6 +54,9 @@ def __call__( recipient_team_id=recipient_team_id or self.recipient_team_id, recipient_user_id=recipient_user_id or self.recipient_user_id, thread_ts=thread_ts, + icon_emoji=icon_emoji, + icon_url=icon_url, + username=username, **kwargs, ) return self.client.chat_stream( @@ -58,5 +64,8 @@ def __call__( recipient_team_id=recipient_team_id or self.recipient_team_id, recipient_user_id=recipient_user_id or self.recipient_user_id, thread_ts=thread_ts, + icon_emoji=icon_emoji, + icon_url=icon_url, + username=username, **kwargs, ) diff --git a/slack_bolt/context/set_status/async_set_status.py b/slack_bolt/context/set_status/async_set_status.py index e2c451f46..f10cc195c 100644 --- a/slack_bolt/context/set_status/async_set_status.py +++ b/slack_bolt/context/set_status/async_set_status.py @@ -23,6 +23,9 @@ async def __call__( self, status: str, loading_messages: Optional[List[str]] = None, + icon_emoji: Optional[str] = None, + icon_url: Optional[str] = None, + username: Optional[str] = None, **kwargs, ) -> AsyncSlackResponse: return await self.client.assistant_threads_setStatus( @@ -30,5 +33,8 @@ async def __call__( thread_ts=self.thread_ts, status=status, loading_messages=loading_messages, + icon_emoji=icon_emoji, + icon_url=icon_url, + username=username, **kwargs, ) diff --git a/slack_bolt/context/set_status/set_status.py b/slack_bolt/context/set_status/set_status.py index 0ed612e16..055a5cab7 100644 --- a/slack_bolt/context/set_status/set_status.py +++ b/slack_bolt/context/set_status/set_status.py @@ -23,6 +23,9 @@ def __call__( self, status: str, loading_messages: Optional[List[str]] = None, + icon_emoji: Optional[str] = None, + icon_url: Optional[str] = None, + username: Optional[str] = None, **kwargs, ) -> SlackResponse: return self.client.assistant_threads_setStatus( @@ -30,5 +33,8 @@ def __call__( thread_ts=self.thread_ts, status=status, loading_messages=loading_messages, + icon_emoji=icon_emoji, + icon_url=icon_url, + username=username, **kwargs, ) diff --git a/tests/slack_bolt/context/test_say_stream.py b/tests/slack_bolt/context/test_say_stream.py index 29d244a65..04a52e419 100644 --- a/tests/slack_bolt/context/test_say_stream.py +++ b/tests/slack_bolt/context/test_say_stream.py @@ -1,21 +1,14 @@ import pytest +from unittest.mock import patch, MagicMock + from slack_sdk import WebClient from slack_bolt.context.say_stream.say_stream import SayStream -from tests.mock_web_api_server import cleanup_mock_web_api_server, setup_mock_web_api_server class TestSayStream: - default_chat_stream_buffer_size = WebClient.chat_stream.__kwdefaults__["buffer_size"] - def setup_method(self): - setup_mock_web_api_server(self) - valid_token = "xoxb-valid" - mock_api_server_base_url = "http://localhost:8888" - self.web_client = WebClient(token=valid_token, base_url=mock_api_server_base_url) - - def teardown_method(self): - cleanup_mock_web_api_server(self) + self.web_client = WebClient(token="xoxb-valid") def test_missing_channel_raises(self): say_stream = SayStream(client=self.web_client, channel=None, thread_ts="111.222") @@ -35,16 +28,17 @@ def test_default_params(self): recipient_user_id="U111", thread_ts="111.222", ) - stream = say_stream() - - assert stream._buffer_size == self.default_chat_stream_buffer_size - assert stream._stream_args == { - "channel": "C111", - "thread_ts": "111.222", - "recipient_team_id": "T111", - "recipient_user_id": "U111", - "task_display_mode": None, - } + with patch.object(self.web_client, "chat_stream", return_value=MagicMock()) as mock_chat_stream: + say_stream() + mock_chat_stream.assert_called_once_with( + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + icon_emoji=None, + icon_url=None, + username=None, + ) def test_parameter_overrides(self): say_stream = SayStream( @@ -54,16 +48,17 @@ def test_parameter_overrides(self): recipient_user_id="U111", thread_ts="111.222", ) - stream = say_stream(channel="C222", thread_ts="333.444", recipient_team_id="T222", recipient_user_id="U222") - - assert stream._buffer_size == self.default_chat_stream_buffer_size - assert stream._stream_args == { - "channel": "C222", - "thread_ts": "333.444", - "recipient_team_id": "T222", - "recipient_user_id": "U222", - "task_display_mode": None, - } + with patch.object(self.web_client, "chat_stream", return_value=MagicMock()) as mock_chat_stream: + say_stream(channel="C222", thread_ts="333.444", recipient_team_id="T222", recipient_user_id="U222") + mock_chat_stream.assert_called_once_with( + channel="C222", + recipient_team_id="T222", + recipient_user_id="U222", + thread_ts="333.444", + icon_emoji=None, + icon_url=None, + username=None, + ) def test_buffer_size_overrides(self): say_stream = SayStream( @@ -73,19 +68,41 @@ def test_buffer_size_overrides(self): recipient_user_id="U111", thread_ts="111.222", ) - stream = say_stream( - buffer_size=100, - channel="C222", - thread_ts="333.444", - recipient_team_id="T222", - recipient_user_id="U222", - ) + with patch.object(self.web_client, "chat_stream", return_value=MagicMock()) as mock_chat_stream: + say_stream( + buffer_size=100, + channel="C222", + thread_ts="333.444", + recipient_team_id="T222", + recipient_user_id="U222", + ) + mock_chat_stream.assert_called_once_with( + buffer_size=100, + channel="C222", + recipient_team_id="T222", + recipient_user_id="U222", + thread_ts="333.444", + icon_emoji=None, + icon_url=None, + username=None, + ) - assert stream._buffer_size == 100 - assert stream._stream_args == { - "channel": "C222", - "thread_ts": "333.444", - "recipient_team_id": "T222", - "recipient_user_id": "U222", - "task_display_mode": None, - } + def test_authorship_overrides(self): + say_stream = SayStream( + client=self.web_client, + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + ) + with patch.object(self.web_client, "chat_stream", return_value=MagicMock()) as mock_chat_stream: + say_stream(icon_emoji=":maple_leaf:", username="Charlie Brown") + mock_chat_stream.assert_called_once_with( + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + icon_emoji=":maple_leaf:", + icon_url=None, + username="Charlie Brown", + ) diff --git a/tests/slack_bolt/context/test_set_status.py b/tests/slack_bolt/context/test_set_status.py index fe998df5e..bb5807e96 100644 --- a/tests/slack_bolt/context/test_set_status.py +++ b/tests/slack_bolt/context/test_set_status.py @@ -32,6 +32,15 @@ def test_set_status_loading_messages(self): ) assert response.status_code == 200 + def test_set_status_authorship(self): + set_status = SetStatus(client=self.web_client, channel_id="C111", thread_ts="123.123") + response: SlackResponse = set_status( + status="Thinking...", + icon_emoji=":maple_leaf:", + username="Charlie Brown", + ) + assert response.status_code == 200 + def test_set_status_invalid(self): set_status = SetStatus(client=self.web_client, channel_id="C111", thread_ts="123.123") with pytest.raises(TypeError): diff --git a/tests/slack_bolt_async/context/test_async_say_stream.py b/tests/slack_bolt_async/context/test_async_say_stream.py index 016549bd6..7ac084044 100644 --- a/tests/slack_bolt_async/context/test_async_say_stream.py +++ b/tests/slack_bolt_async/context/test_async_say_stream.py @@ -1,28 +1,20 @@ import pytest +from unittest.mock import patch, MagicMock + from slack_sdk.web.async_client import AsyncWebClient from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream -from tests.mock_web_api_server import ( - cleanup_mock_web_api_server, - setup_mock_web_api_server, -) from tests.utils import remove_os_env_temporarily, restore_os_env class TestAsyncSayStream: - default_chat_stream_buffer_size = AsyncWebClient.chat_stream.__kwdefaults__["buffer_size"] - @pytest.fixture(scope="function", autouse=True) def setup_teardown(self): old_os_env = remove_os_env_temporarily() - setup_mock_web_api_server(self) - valid_token = "xoxb-valid" - mock_api_server_base_url = "http://localhost:8888" try: - self.web_client = AsyncWebClient(token=valid_token, base_url=mock_api_server_base_url) - yield # run the test here + self.web_client = AsyncWebClient(token="xoxb-valid") + yield finally: - cleanup_mock_web_api_server(self) restore_os_env(old_os_env) @pytest.mark.asyncio @@ -46,16 +38,22 @@ async def test_default_params(self): recipient_user_id="U111", thread_ts="111.222", ) - stream = await say_stream() + mock_chat_stream = MagicMock() - assert stream._buffer_size == self.default_chat_stream_buffer_size - assert stream._stream_args == { - "channel": "C111", - "thread_ts": "111.222", - "recipient_team_id": "T111", - "recipient_user_id": "U111", - "task_display_mode": None, - } + async def fake_chat_stream(**kwargs): + return mock_chat_stream(**kwargs) + + with patch.object(self.web_client, "chat_stream", side_effect=fake_chat_stream): + await say_stream() + mock_chat_stream.assert_called_once_with( + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + icon_emoji=None, + icon_url=None, + username=None, + ) @pytest.mark.asyncio async def test_parameter_overrides(self): @@ -66,16 +64,22 @@ async def test_parameter_overrides(self): recipient_user_id="U111", thread_ts="111.222", ) - stream = await say_stream(channel="C222", thread_ts="333.444", recipient_team_id="T222", recipient_user_id="U222") + mock_chat_stream = MagicMock() - assert stream._buffer_size == self.default_chat_stream_buffer_size - assert stream._stream_args == { - "channel": "C222", - "thread_ts": "333.444", - "recipient_team_id": "T222", - "recipient_user_id": "U222", - "task_display_mode": None, - } + async def fake_chat_stream(**kwargs): + return mock_chat_stream(**kwargs) + + with patch.object(self.web_client, "chat_stream", side_effect=fake_chat_stream): + await say_stream(channel="C222", thread_ts="333.444", recipient_team_id="T222", recipient_user_id="U222") + mock_chat_stream.assert_called_once_with( + channel="C222", + recipient_team_id="T222", + recipient_user_id="U222", + thread_ts="333.444", + icon_emoji=None, + icon_url=None, + username=None, + ) @pytest.mark.asyncio async def test_buffer_size_overrides(self): @@ -86,19 +90,52 @@ async def test_buffer_size_overrides(self): recipient_user_id="U111", thread_ts="111.222", ) - stream = await say_stream( - buffer_size=100, - channel="C222", - thread_ts="333.444", - recipient_team_id="T222", - recipient_user_id="U222", + mock_chat_stream = MagicMock() + + async def fake_chat_stream(**kwargs): + return mock_chat_stream(**kwargs) + + with patch.object(self.web_client, "chat_stream", side_effect=fake_chat_stream): + await say_stream( + buffer_size=100, + channel="C222", + thread_ts="333.444", + recipient_team_id="T222", + recipient_user_id="U222", + ) + mock_chat_stream.assert_called_once_with( + buffer_size=100, + channel="C222", + recipient_team_id="T222", + recipient_user_id="U222", + thread_ts="333.444", + icon_emoji=None, + icon_url=None, + username=None, + ) + + @pytest.mark.asyncio + async def test_authorship_overrides(self): + say_stream = AsyncSayStream( + client=self.web_client, + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", ) + mock_chat_stream = MagicMock() + + async def fake_chat_stream(**kwargs): + return mock_chat_stream(**kwargs) - assert stream._buffer_size == 100 - assert stream._stream_args == { - "channel": "C222", - "thread_ts": "333.444", - "recipient_team_id": "T222", - "recipient_user_id": "U222", - "task_display_mode": None, - } + with patch.object(self.web_client, "chat_stream", side_effect=fake_chat_stream): + await say_stream(icon_emoji=":maple_leaf:", username="Charlie Brown") + mock_chat_stream.assert_called_once_with( + channel="C111", + recipient_team_id="T111", + recipient_user_id="U111", + thread_ts="111.222", + icon_emoji=":maple_leaf:", + icon_url=None, + username="Charlie Brown", + ) diff --git a/tests/slack_bolt_async/context/test_async_set_status.py b/tests/slack_bolt_async/context/test_async_set_status.py index e785ff89e..bcf1fcf19 100644 --- a/tests/slack_bolt_async/context/test_async_set_status.py +++ b/tests/slack_bolt_async/context/test_async_set_status.py @@ -40,6 +40,16 @@ async def test_set_status_loading_messages(self): ) assert response.status_code == 200 + @pytest.mark.asyncio + async def test_set_status_authorship(self): + set_status = AsyncSetStatus(client=self.web_client, channel_id="C111", thread_ts="123.123") + response: AsyncSlackResponse = await set_status( + status="Thinking...", + icon_emoji=":maple_leaf:", + username="Charlie Brown", + ) + assert response.status_code == 200 + @pytest.mark.asyncio async def test_set_status_invalid(self): set_status = AsyncSetStatus(client=self.web_client, channel_id="C111", thread_ts="123.123")