Skip to content

Commit 4da6d74

Browse files
Added wait argument to client's connect method (Fixes #634)
1 parent f341abe commit 4da6d74

File tree

4 files changed

+256
-10
lines changed

4 files changed

+256
-10
lines changed

socketio/asyncio_client.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def is_asyncio_based(self):
6363
return True
6464

6565
async def connect(self, url, headers={}, transports=None,
66-
namespaces=None, socketio_path='socket.io'):
66+
namespaces=None, socketio_path='socket.io', wait=True,
67+
wait_timeout=1):
6768
"""Connect to a Socket.IO server.
6869
6970
:param url: The URL of the Socket.IO server. It can include custom
@@ -80,18 +81,26 @@ async def connect(self, url, headers={}, transports=None,
8081
:param socketio_path: The endpoint where the Socket.IO server is
8182
installed. The default value is appropriate for
8283
most cases.
84+
:param wait: if set to ``True`` (the default) the call only returns
85+
when all the namespaces are connected. If set to
86+
``False``, the call returns as soon as the Engine.IO
87+
transport is connected, and the namespaces will connect
88+
in the background.
89+
:param wait_timeout: How long the client should wait for the
90+
connection. The default is 1 second. This
91+
argument is only considered when ``wait`` is set
92+
to ``True``.
8393
8494
Note: this method is a coroutine.
8595
86-
Note: The connection mechannism occurs in the background and will
87-
complete at some point after this function returns. The connection
88-
will be established when the ``connect`` event is invoked.
89-
9096
Example usage::
9197
9298
sio = socketio.AsyncClient()
9399
sio.connect('http://localhost:5000')
94100
"""
101+
if self.connected:
102+
raise exceptions.ConnectionError('Already connected')
103+
95104
self.connection_url = url
96105
self.connection_headers = headers
97106
self.connection_transports = transports
@@ -106,6 +115,11 @@ async def connect(self, url, headers={}, transports=None,
106115
elif isinstance(namespaces, str):
107116
namespaces = [namespaces]
108117
self.connection_namespaces = namespaces
118+
self.namespaces = {}
119+
if self._connect_event is None:
120+
self._connect_event = self.eio.create_event()
121+
else:
122+
self._connect_event.clear()
109123
try:
110124
await self.eio.connect(url, headers=headers,
111125
transports=transports,
@@ -115,6 +129,22 @@ async def connect(self, url, headers={}, transports=None,
115129
'connect_error', '/',
116130
exc.args[1] if len(exc.args) > 1 else exc.args[0])
117131
raise exceptions.ConnectionError(exc.args[0]) from None
132+
133+
if wait:
134+
try:
135+
while True:
136+
await asyncio.wait_for(self._connect_event.wait(),
137+
wait_timeout)
138+
self._connect_event.clear()
139+
if set(self.namespaces) == set(self.connection_namespaces):
140+
break
141+
except asyncio.TimeoutError:
142+
pass
143+
if set(self.namespaces) != set(self.connection_namespaces):
144+
await self.disconnect()
145+
raise exceptions.ConnectionError(
146+
'One or more namespaces failed to connect')
147+
118148
self.connected = True
119149

120150
async def wait(self):
@@ -301,6 +331,7 @@ async def _handle_connect(self, namespace, data):
301331
self.logger.info('Namespace {} is connected'.format(namespace))
302332
self.namespaces[namespace] = (data or {}).get('sid', self.sid)
303333
await self._trigger_event('connect', namespace=namespace)
334+
self._connect_event.set()
304335

305336
async def _handle_disconnect(self, namespace):
306337
if not self.connected:
@@ -355,6 +386,7 @@ async def _handle_error(self, namespace, data):
355386
elif not isinstance(data, (tuple, list)):
356387
data = (data,)
357388
await self._trigger_event('connect_error', namespace, *data)
389+
self._connect_event.set()
358390
if namespace in self.namespaces:
359391
del self.namespaces[namespace]
360392
if namespace == '/':

socketio/client.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __init__(self, reconnection=True, reconnection_attempts=0,
131131
self.namespace_handlers = {}
132132
self.callbacks = {}
133133
self._binary_packet = None
134+
self._connect_event = None
134135
self._reconnect_task = None
135136
self._reconnect_abort = None
136137

@@ -233,7 +234,8 @@ def register_namespace(self, namespace_handler):
233234
namespace_handler
234235

235236
def connect(self, url, headers={}, transports=None,
236-
namespaces=None, socketio_path='socket.io'):
237+
namespaces=None, socketio_path='socket.io', wait=True,
238+
wait_timeout=1):
237239
"""Connect to a Socket.IO server.
238240
239241
:param url: The URL of the Socket.IO server. It can include custom
@@ -250,16 +252,24 @@ def connect(self, url, headers={}, transports=None,
250252
:param socketio_path: The endpoint where the Socket.IO server is
251253
installed. The default value is appropriate for
252254
most cases.
253-
254-
Note: The connection mechannism occurs in the background and will
255-
complete at some point after this function returns. The connection
256-
will be established when the ``connect`` event is invoked.
255+
:param wait: if set to ``True`` (the default) the call only returns
256+
when all the namespaces are connected. If set to
257+
``False``, the call returns as soon as the Engine.IO
258+
transport is connected, and the namespaces will connect
259+
in the background.
260+
:param wait_timeout: How long the client should wait for the
261+
connection. The default is 1 second. This
262+
argument is only considered when ``wait`` is set
263+
to ``True``.
257264
258265
Example usage::
259266
260267
sio = socketio.Client()
261268
sio.connect('http://localhost:5000')
262269
"""
270+
if self.connected:
271+
raise exceptions.ConnectionError('Already connected')
272+
263273
self.connection_url = url
264274
self.connection_headers = headers
265275
self.connection_transports = transports
@@ -274,6 +284,11 @@ def connect(self, url, headers={}, transports=None,
274284
elif isinstance(namespaces, str):
275285
namespaces = [namespaces]
276286
self.connection_namespaces = namespaces
287+
self.namespaces = {}
288+
if self._connect_event is None:
289+
self._connect_event = self.eio.create_event()
290+
else:
291+
self._connect_event.clear()
277292
try:
278293
self.eio.connect(url, headers=headers, transports=transports,
279294
engineio_path=socketio_path)
@@ -282,6 +297,17 @@ def connect(self, url, headers={}, transports=None,
282297
'connect_error', '/',
283298
exc.args[1] if len(exc.args) > 1 else exc.args[0])
284299
raise exceptions.ConnectionError(exc.args[0]) from None
300+
301+
if wait:
302+
while self._connect_event.wait(timeout=wait_timeout):
303+
self._connect_event.clear()
304+
if set(self.namespaces) == set(self.connection_namespaces):
305+
break
306+
if set(self.namespaces) != set(self.connection_namespaces):
307+
self.disconnect()
308+
raise exceptions.ConnectionError(
309+
'One or more namespaces failed to connect')
310+
285311
self.connected = True
286312

287313
def wait(self):
@@ -483,6 +509,7 @@ def _handle_connect(self, namespace, data):
483509
self.logger.info('Namespace {} is connected'.format(namespace))
484510
self.namespaces[namespace] = (data or {}).get('sid', self.sid)
485511
self._trigger_event('connect', namespace=namespace)
512+
self._connect_event.set()
486513

487514
def _handle_disconnect(self, namespace):
488515
if not self.connected:
@@ -534,6 +561,7 @@ def _handle_error(self, namespace, data):
534561
elif not isinstance(data, (tuple, list)):
535562
data = (data,)
536563
self._trigger_event('connect_error', namespace, *data)
564+
self._connect_event.set()
537565
if namespace in self.namespaces:
538566
del self.namespaces[namespace]
539567
if namespace == '/':

0 commit comments

Comments
 (0)