Skip to content

Commit d09ca93

Browse files
authored
feat: Support dynamic query parameters on reconnect (#51)
The `ConnectStrategy` can be created with a `query_params` callable. This callable will return a set of parameters that should be used to update any static query parameters initially configured. This functionality enables FDv2 selector behavior where we want to resume from our last known checkpoint.
1 parent 638c403 commit d09ca93

File tree

4 files changed

+76
-5
lines changed

4 files changed

+76
-5
lines changed

ld_eventsource/config/connect_strategy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
from logging import Logger
45
from typing import Callable, Iterator, Optional, Union
56

67
from urllib3 import PoolManager
78

8-
from ld_eventsource.http import _HttpClientImpl, _HttpConnectParams
9+
from ld_eventsource.http import (DynamicQueryParams, _HttpClientImpl,
10+
_HttpConnectParams)
911

1012

1113
class ConnectStrategy:
@@ -38,6 +40,7 @@ def http(
3840
headers: Optional[dict] = None,
3941
pool: Optional[PoolManager] = None,
4042
urllib3_request_options: Optional[dict] = None,
43+
query_params: Optional[DynamicQueryParams] = None
4144
) -> ConnectStrategy:
4245
"""
4346
Creates the default HTTP implementation, specifying request parameters.
@@ -47,9 +50,11 @@ def http(
4750
:param pool: optional urllib3 ``PoolManager`` to provide an HTTP client
4851
:param urllib3_request_options: optional ``kwargs`` to add to the ``request`` call; these
4952
can include any parameters supported by ``urllib3``, such as ``timeout``
53+
:param query_params: optional callable that can be used to affect query parameters
54+
dynamically for each connection attempt
5055
"""
5156
return _HttpConnectStrategy(
52-
_HttpConnectParams(url, headers, pool, urllib3_request_options)
57+
_HttpConnectParams(url, headers, pool, urllib3_request_options, query_params)
5358
)
5459

5560

ld_eventsource/http.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from logging import Logger
22
from typing import Callable, Iterator, Optional, Tuple
3+
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
34

45
from urllib3 import PoolManager
56
from urllib3.exceptions import MaxRetryError
@@ -9,6 +10,12 @@
910

1011
_CHUNK_SIZE = 10000
1112

13+
DynamicQueryParams = Callable[[], dict[str, str]]
14+
"""
15+
A callable that returns a dictionary of query parameters to add to the URL.
16+
This can be used to modify query parameters dynamically for each connection attempt.
17+
"""
18+
1219

1320
class _HttpConnectParams:
1421
def __init__(
@@ -17,16 +24,22 @@ def __init__(
1724
headers: Optional[dict] = None,
1825
pool: Optional[PoolManager] = None,
1926
urllib3_request_options: Optional[dict] = None,
27+
query_params: Optional[DynamicQueryParams] = None
2028
):
2129
self.__url = url
2230
self.__headers = headers
2331
self.__pool = pool
2432
self.__urllib3_request_options = urllib3_request_options
33+
self.__query_params = query_params
2534

2635
@property
2736
def url(self) -> str:
2837
return self.__url
2938

39+
@property
40+
def query_params(self) -> Optional[DynamicQueryParams]:
41+
return self.__query_params
42+
3043
@property
3144
def headers(self) -> Optional[dict]:
3245
return self.__headers
@@ -48,7 +61,16 @@ def __init__(self, params: _HttpConnectParams, logger: Logger):
4861
self.__logger = logger
4962

5063
def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable]:
51-
self.__logger.info("Connecting to stream at %s" % self.__params.url)
64+
url = self.__params.url
65+
if self.__params.query_params is not None:
66+
qp = self.__params.query_params()
67+
if qp:
68+
url_parts = list(urlsplit(url))
69+
query = dict(parse_qsl(url_parts[3]))
70+
query.update(qp)
71+
url_parts[3] = urlencode(query)
72+
url = urlunsplit(url_parts)
73+
self.__logger.info("Connecting to stream at %s" % url)
5274

5375
headers = self.__params.headers.copy() if self.__params.headers else {}
5476
headers['Cache-Control'] = 'no-cache'
@@ -67,7 +89,7 @@ def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callab
6789
try:
6890
resp = self.__pool.request(
6991
'GET',
70-
self.__params.url,
92+
url,
7193
preload_content=False,
7294
retries=Retry(
7395
total=None, read=0, connect=0, status=0, other=0, redirect=3

ld_eventsource/testing/http_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def do_POST(self):
113113
def _do_request(self):
114114
server_wrapper = self.server.server_wrapper
115115
server_wrapper.requests.put(MockServerRequest(self))
116-
handler = server_wrapper.matchers.get(self.path)
116+
handler = server_wrapper.matchers.get(self.path.split("?")[0], None)
117117
if handler:
118118
handler.write(self)
119119
else:

ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from urllib.parse import parse_qsl
2+
13
from ld_eventsource import *
24
from ld_eventsource.config import *
35
from ld_eventsource.testing.helpers import *
@@ -56,6 +58,48 @@ def test_sse_client_reconnects_after_socket_closed():
5658
assert event2.data == 'data2'
5759

5860

61+
def test_sse_client_allows_modifying_query_params_dynamically():
62+
count = 0
63+
64+
def dynamic_query_params() -> dict[str, str]:
65+
nonlocal count
66+
count += 1
67+
params = {'count': str(count)}
68+
if count > 1:
69+
params['option'] = 'updated'
70+
71+
return params
72+
73+
with start_server() as server:
74+
with make_stream() as stream1:
75+
with make_stream() as stream2:
76+
server.for_path('/', SequentialHandler(stream1, stream2))
77+
stream1.push("event: a\ndata: data1\nid: id123\n\n")
78+
stream2.push("event: b\ndata: data2\n\n")
79+
with SSEClient(
80+
connect=ConnectStrategy.http(f"{server.uri}?basis=unchanging&option=initial", query_params=dynamic_query_params),
81+
error_strategy=ErrorStrategy.always_continue(),
82+
initial_retry_delay=0,
83+
) as client:
84+
client.start()
85+
next(client.events)
86+
stream1.close()
87+
next(client.events)
88+
r1 = server.await_request()
89+
r1_query_params = dict(parse_qsl(r1.path.split('?', 1)[1]))
90+
91+
# Ensure we can add, retain, and modify query parameters
92+
assert r1_query_params.get('count') == '1'
93+
assert r1_query_params.get('basis') == 'unchanging'
94+
assert r1_query_params.get('option') == 'initial'
95+
96+
r2 = server.await_request()
97+
r2_query_params = dict(parse_qsl(r2.path.split('?', 1)[1]))
98+
assert r2_query_params.get('count') == '2'
99+
assert r2_query_params.get('basis') == 'unchanging'
100+
assert r2_query_params.get('option') == 'updated'
101+
102+
59103
def test_sse_client_sends_last_event_id_on_reconnect():
60104
with start_server() as server:
61105
with make_stream() as stream1:

0 commit comments

Comments
 (0)