1313import urllib .parse
1414from typing import Iterable , List , Optional , Union
1515
16+ import outcome
1617import trio
1718import trio .abc
1819from wsproto import ConnectionType , WSConnection
4445logger = logging .getLogger ('trio-websocket' )
4546
4647
48+ class TrioWebsocketInternalError (Exception ):
49+ ...
50+
51+
4752def _ignore_cancel (exc ):
4853 return None if isinstance (exc , trio .Cancelled ) else exc
4954
@@ -125,10 +130,10 @@ async def open_websocket(
125130 client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
126131 or server rejection (:exc:`ConnectionRejected`) during handshakes.
127132 '''
128- async with trio .open_nursery () as new_nursery :
133+ async def open_connection ( nursery : trio .Nursery ) -> WebSocketConnection :
129134 try :
130135 with trio .fail_after (connect_timeout ):
131- connection = await connect_websocket (new_nursery , host , port ,
136+ return await connect_websocket (nursery , host , port ,
132137 resource , use_ssl = use_ssl , subprotocols = subprotocols ,
133138 extra_headers = extra_headers ,
134139 message_queue_size = message_queue_size ,
@@ -137,14 +142,59 @@ async def open_websocket(
137142 raise ConnectionTimeout from None
138143 except OSError as e :
139144 raise HandshakeError from e
145+
146+ async def close_connection (connection : WebSocketConnection ) -> None :
140147 try :
141- yield connection
142- finally :
143- try :
144- with trio .fail_after (disconnect_timeout ):
145- await connection .aclose ()
146- except trio .TooSlowError :
147- raise DisconnectionTimeout from None
148+ with trio .fail_after (disconnect_timeout ):
149+ await connection .aclose ()
150+ except trio .TooSlowError :
151+ raise DisconnectionTimeout from None
152+
153+ connection : WebSocketConnection | None = None
154+ result2 : outcome .Maybe [None ] | None = None
155+ user_error = None
156+
157+ try :
158+ async with trio .open_nursery () as new_nursery :
159+ result = await outcome .acapture (open_connection , new_nursery )
160+
161+ if isinstance (result , outcome .Value ):
162+ connection = result .unwrap ()
163+ try :
164+ yield connection
165+ except BaseException as e :
166+ user_error = e
167+ raise
168+ finally :
169+ result2 = await outcome .acapture (close_connection , connection )
170+ # This exception handler should only be entered if:
171+ # 1. The _reader_task started in connect_websocket raises
172+ # 2. User code raises an exception
173+ except BaseExceptionGroup as e :
174+ # user_error, or exception bubbling up from _reader_task
175+ if len (e .exceptions ) == 1 :
176+ raise e .exceptions [0 ]
177+ # if the group contains two exceptions, one being Cancelled, and the other
178+ # is user_error => drop Cancelled and raise user_error
179+ # This Cancelled should only have been able to come from _reader_task
180+ if (
181+ len (e .exceptions ) == 2
182+ and user_error is not None
183+ and user_error in e .exceptions
184+ and any (isinstance (exc , trio .Cancelled ) for exc in e .exceptions )
185+ ):
186+ raise user_error # pylint: disable=raise-missing-from,raising-bad-type
187+ raise TrioWebsocketInternalError from e # pragma: no cover
188+ ## TODO: handle keyboardinterrupt?
189+
190+ finally :
191+ if result2 is not None :
192+ result2 .unwrap ()
193+
194+
195+ # error setting up, unwrap that exception
196+ if connection is None :
197+ result .unwrap ()
148198
149199
150200async def connect_websocket (nursery , host , port , resource , * , use_ssl ,
0 commit comments