11from __future__ import annotations
22
33import json
4- from typing import TYPE_CHECKING
4+ from functools import partial
5+ from typing import TYPE_CHECKING , Any
56from unittest .mock import Mock
67
78import pytest
9+ from streamdeck .event_listener import StopStreaming
810from streamdeck .websocket import WebSocketClient
9- from websockets import ConnectionClosedOK , WebSocketException
11+ from websockets import (
12+ ConnectionClosed ,
13+ ConnectionClosedError ,
14+ ConnectionClosedOK ,
15+ InvalidHeader ,
16+ WebSocketException ,
17+ )
1018
1119
1220if TYPE_CHECKING :
@@ -18,13 +26,16 @@ def mock_connection() -> Mock:
1826 """Fixture to mock the ClientConnection object returned by websockets.sync.client.connect."""
1927 return Mock ()
2028
29+
2130@pytest .fixture
2231def patched_connect (mocker : MockerFixture , mock_connection : Mock ) -> Mock :
2332 """Fixture to mock the ClientConnection object returned by websockets.sync.client.connect."""
2433 return mocker .patch ("streamdeck.websocket.connect" , return_value = mock_connection )
2534
2635
27- def test_initialization_calls_connect_correctly (patched_connect : Mock , mock_connection : Mock , port_number : int ) -> None :
36+ def test_initialization_calls_connect_correctly (
37+ patched_connect : Mock , mock_connection : Mock , port_number : int
38+ ) -> None :
2839 """Test that WebSocketClient initializes correctly by calling the connect function with the appropriate URI."""
2940 with WebSocketClient (port = port_number ) as client :
3041 # Assert that 'connect' was called once with the correct URI.
@@ -50,9 +61,38 @@ def test_send_event_serializes_and_sends(mock_connection: Mock, port_number: int
5061def test_listen_yields_messages (mock_connection : Mock , port_number : int ) -> None :
5162 """Test that listen yields messages from the WebSocket connection."""
5263 # Set up the mocked connection to return messages until closing
53- mock_connection .recv .side_effect = ["message1" , b"message2" , WebSocketException ()]
64+ expected_results = ["message1" , b"message2" , "message3" ]
65+ mock_connection .recv .side_effect = expected_results
5466
5567 with WebSocketClient (port = port_number ) as client :
56- messages = list (client .listen ())
68+ actual_messages : list [Any ] = []
69+ for i , msg in enumerate (client .listen ()):
70+ actual_messages .append (msg )
71+ if i == 2 :
72+ break
73+
74+ assert actual_messages == expected_results
75+
76+
77+ @pytest .mark .parametrize (
78+ "exception_class" ,
79+ [
80+ partial (ConnectionClosedOK , None , None ),
81+ partial (ConnectionClosedError , None , None ),
82+ partial (InvalidHeader , "header-name" , None ),
83+ partial (ConnectionClosed , None , None ),
84+ WebSocketException ,
85+ ],
86+ )
87+ @pytest .mark .usefixtures ("patched_connect" )
88+ def test_listen_raises_StopStreaming_from_WebSocketException (
89+ mock_connection : Mock , port_number : int , exception_class : type [WebSocketException ]
90+ ) -> None :
91+ """Test that listen raises a StopStreaming exception when a WebSocketException is raised."""
92+ # Set up the mocked connection to return messages until closing
93+ mock_connection .recv .side_effect = ["message1" , b"message2" , exception_class ()]
5794
58- assert messages == ["message1" , b"message2" ]
95+ # This should raise a StopStreaming exception when any WebSocketException is raised
96+ with WebSocketClient (port = port_number ) as client , pytest .raises (StopStreaming ):
97+ for _ in client .listen ():
98+ pass
0 commit comments