Skip to content

Commit 7d2e7f7

Browse files
Allow functions to be used for URL, headers and auth data in client connection (Fixes #588)
1 parent 2538df8 commit 7d2e7f7

File tree

4 files changed

+136
-10
lines changed

4 files changed

+136
-10
lines changed

socketio/asyncio_client.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,19 @@ async def connect(self, url, headers={}, auth=None, transports=None,
6868
"""Connect to a Socket.IO server.
6969
7070
:param url: The URL of the Socket.IO server. It can include custom
71-
query string parameters if required by the server.
71+
query string parameters if required by the server. If a
72+
function is provided, the client will invoke it to obtain
73+
the URL each time a connection or reconnection is
74+
attempted.
7275
:param headers: A dictionary with custom headers to send with the
73-
connection request.
76+
connection request. If a function is provided, the
77+
client will invoke it to obtain the headers dictionary
78+
each time a connection or reconnection is attempted.
7479
:param auth: Authentication data passed to the server with the
7580
connection request, normally a dictionary with one or
76-
more string key/value pairs.
81+
more string key/value pairs. If a function is provided,
82+
the client will invoke it to obtain the authentication
83+
data each time a connection or reconnection is attempted.
7784
:param transports: The list of allowed transports. Valid transports
7885
are ``'polling'`` and ``'websocket'``. If not
7986
given, the polling transport is connected first,
@@ -124,8 +131,10 @@ async def connect(self, url, headers={}, auth=None, transports=None,
124131
self._connect_event = self.eio.create_event()
125132
else:
126133
self._connect_event.clear()
134+
real_url = await self._get_real_value(self.connection_url)
135+
real_headers = await self._get_real_value(self.connection_headers)
127136
try:
128-
await self.eio.connect(url, headers=headers,
137+
await self.eio.connect(real_url, headers=real_headers,
129138
transports=transports,
130139
engineio_path=socketio_path)
131140
except engineio.exceptions.ConnectionError as exc:
@@ -320,6 +329,15 @@ async def sleep(self, seconds=0):
320329
"""
321330
return await self.eio.sleep(seconds)
322331

332+
async def _get_real_value(self, value):
333+
"""Return the actual value, for parameters that can also be given as
334+
callables."""
335+
if not callable(value):
336+
return value
337+
if asyncio.iscoroutinefunction(value):
338+
return await value()
339+
return value()
340+
323341
async def _send_packet(self, pkt):
324342
"""Send a Socket.IO packet to the server."""
325343
encoded_packet = pkt.encode()
@@ -462,9 +480,10 @@ async def _handle_eio_connect(self):
462480
"""Handle the Engine.IO connection event."""
463481
self.logger.info('Engine.IO connection established')
464482
self.sid = self.eio.sid
483+
real_auth = await self._get_real_value(self.connection_auth)
465484
for n in self.connection_namespaces:
466485
await self._send_packet(packet.Packet(
467-
packet.CONNECT, data=self.connection_auth, namespace=n))
486+
packet.CONNECT, data=real_auth, namespace=n))
468487

469488
async def _handle_eio_message(self, data):
470489
"""Dispatch Engine.IO messages."""

socketio/client.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,19 @@ def connect(self, url, headers={}, auth=None, transports=None,
240240
"""Connect to a Socket.IO server.
241241
242242
:param url: The URL of the Socket.IO server. It can include custom
243-
query string parameters if required by the server.
243+
query string parameters if required by the server. If a
244+
function is provided, the client will invoke it to obtain
245+
the URL each time a connection or reconnection is
246+
attempted.
244247
:param headers: A dictionary with custom headers to send with the
245-
connection request.
248+
connection request. If a function is provided, the
249+
client will invoke it to obtain the headers dictionary
250+
each time a connection or reconnection is attempted.
246251
:param auth: Authentication data passed to the server with the
247252
connection request, normally a dictionary with one or
248-
more string key/value pairs.
253+
more string key/value pairs. If a function is provided,
254+
the client will invoke it to obtain the authentication
255+
data each time a connection or reconnection is attempted.
249256
:param transports: The list of allowed transports. Valid transports
250257
are ``'polling'`` and ``'websocket'``. If not
251258
given, the polling transport is connected first,
@@ -294,8 +301,11 @@ def connect(self, url, headers={}, auth=None, transports=None,
294301
self._connect_event = self.eio.create_event()
295302
else:
296303
self._connect_event.clear()
304+
real_url = self._get_real_value(self.connection_url)
305+
real_headers = self._get_real_value(self.connection_headers)
297306
try:
298-
self.eio.connect(url, headers=headers, transports=transports,
307+
self.eio.connect(real_url, headers=real_headers,
308+
transports=transports,
299309
engineio_path=socketio_path)
300310
except engineio.exceptions.ConnectionError as exc:
301311
self._trigger_event(
@@ -490,6 +500,13 @@ def sleep(self, seconds=0):
490500
"""
491501
return self.eio.sleep(seconds)
492502

503+
def _get_real_value(self, value):
504+
"""Return the actual value, for parameters that can also be given as
505+
callables."""
506+
if not callable(value):
507+
return value
508+
return value()
509+
493510
def _send_packet(self, pkt):
494511
"""Send a Socket.IO packet to the server."""
495512
encoded_packet = pkt.encode()
@@ -628,9 +645,10 @@ def _handle_eio_connect(self):
628645
"""Handle the Engine.IO connection event."""
629646
self.logger.info('Engine.IO connection established')
630647
self.sid = self.eio.sid
648+
real_auth = self._get_real_value(self.connection_auth)
631649
for n in self.connection_namespaces:
632650
self._send_packet(packet.Packet(
633-
packet.CONNECT, data=self.connection_auth, namespace=n))
651+
packet.CONNECT, data=real_auth, namespace=n))
634652

635653
def _handle_eio_message(self, data):
636654
"""Dispatch Engine.IO messages."""

tests/asyncio/test_asyncio_client.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,30 @@ def test_connect(self):
7575
engineio_path='path',
7676
)
7777

78+
def test_connect_functions(self):
79+
async def headers():
80+
return 'headers'
81+
82+
c = asyncio_client.AsyncClient()
83+
c.eio.connect = AsyncMock()
84+
_run(
85+
c.connect(
86+
lambda: 'url',
87+
headers=headers,
88+
auth='auth',
89+
transports='transports',
90+
namespaces=['/foo', '/', '/bar'],
91+
socketio_path='path',
92+
wait=False,
93+
)
94+
)
95+
c.eio.connect.mock.assert_called_once_with(
96+
'url',
97+
headers='headers',
98+
transports='transports',
99+
engineio_path='path',
100+
)
101+
78102
def test_connect_one_namespace(self):
79103
c = asyncio_client.AsyncClient()
80104
c.eio.connect = AsyncMock()
@@ -960,6 +984,29 @@ def test_handle_eio_connect(self):
960984
== expected_packet.encode()
961985
)
962986

987+
def test_handle_eio_connect_function(self):
988+
c = asyncio_client.AsyncClient()
989+
c.connection_namespaces = ['/', '/foo']
990+
c.connection_auth = lambda: 'auth'
991+
c._send_packet = AsyncMock()
992+
c.eio.sid = 'foo'
993+
assert c.sid is None
994+
_run(c._handle_eio_connect())
995+
assert c.sid == 'foo'
996+
assert c._send_packet.mock.call_count == 2
997+
expected_packet = packet.Packet(
998+
packet.CONNECT, data='auth', namespace='/')
999+
assert (
1000+
c._send_packet.mock.call_args_list[0][0][0].encode()
1001+
== expected_packet.encode()
1002+
)
1003+
expected_packet = packet.Packet(
1004+
packet.CONNECT, data='auth', namespace='/foo')
1005+
assert (
1006+
c._send_packet.mock.call_args_list[1][0][0].encode()
1007+
== expected_packet.encode()
1008+
)
1009+
9631010
def test_handle_eio_message(self):
9641011
c = asyncio_client.AsyncClient()
9651012
c._handle_connect = AsyncMock()

tests/common/test_client.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,25 @@ def test_connect(self):
173173
engineio_path='path',
174174
)
175175

176+
def test_connect_functions(self):
177+
c = client.Client()
178+
c.eio.connect = mock.MagicMock()
179+
c.connect(
180+
lambda: 'url',
181+
headers=lambda: 'headers',
182+
auth='auth',
183+
transports='transports',
184+
namespaces=['/foo', '/', '/bar'],
185+
socketio_path='path',
186+
wait=False,
187+
)
188+
c.eio.connect.assert_called_once_with(
189+
'url',
190+
headers='headers',
191+
transports='transports',
192+
engineio_path='path',
193+
)
194+
176195
def test_connect_one_namespace(self):
177196
c = client.Client()
178197
c.eio.connect = mock.MagicMock()
@@ -1030,6 +1049,29 @@ def test_handle_eio_connect(self):
10301049
== expected_packet.encode()
10311050
)
10321051

1052+
def test_handle_eio_connect_function(self):
1053+
c = client.Client()
1054+
c.connection_namespaces = ['/', '/foo']
1055+
c.connection_auth = lambda: 'auth'
1056+
c._send_packet = mock.MagicMock()
1057+
c.eio.sid = 'foo'
1058+
assert c.sid is None
1059+
c._handle_eio_connect()
1060+
assert c.sid == 'foo'
1061+
assert c._send_packet.call_count == 2
1062+
expected_packet = packet.Packet(
1063+
packet.CONNECT, data='auth', namespace='/')
1064+
assert (
1065+
c._send_packet.call_args_list[0][0][0].encode()
1066+
== expected_packet.encode()
1067+
)
1068+
expected_packet = packet.Packet(
1069+
packet.CONNECT, data='auth', namespace='/foo')
1070+
assert (
1071+
c._send_packet.call_args_list[1][0][0].encode()
1072+
== expected_packet.encode()
1073+
)
1074+
10331075
def test_handle_eio_message(self):
10341076
c = client.Client()
10351077
c._handle_connect = mock.MagicMock()

0 commit comments

Comments
 (0)