Skip to content

Commit 1b9f84b

Browse files
authored
Merge pull request #458 from consideRatio/test-subprotocols
Ensure no blank `Sec-Websocket-Protocol` headers and warn if websocket subprotocol edge case occur
2 parents 288c74b + eda6136 commit 1b9f84b

File tree

5 files changed

+139
-46
lines changed

5 files changed

+139
-46
lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ jobs:
9595
pip-install-constraints: >-
9696
jupyter-server==1.0
9797
simpervisor==1.0
98-
tornado==5.0
98+
tornado==5.1
9999
traitlets==4.2.1
100100
101101
steps:

jupyter_server_proxy/handlers.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def __init__(self, *args, **kwargs):
116116
"rewrite_response",
117117
tuple(),
118118
)
119-
self.subprotocols = None
120119
super().__init__(*args, **kwargs)
121120

122121
# Support/use jupyter_server config arguments allow_origin and allow_origin_pat
@@ -489,15 +488,28 @@ async def start_websocket_connection():
489488
self.log.info(f"Trying to establish websocket connection to {client_uri}")
490489
self._record_activity()
491490
request = httpclient.HTTPRequest(url=client_uri, headers=headers)
491+
subprotocols = (
492+
[self.selected_subprotocol] if self.selected_subprotocol else None
493+
)
492494
self.ws = await pingable_ws_connect(
493495
request=request,
494496
on_message_callback=message_cb,
495497
on_ping_callback=ping_cb,
496-
subprotocols=self.subprotocols,
498+
subprotocols=subprotocols,
497499
resolver=resolver,
498500
)
499501
self._record_activity()
500502
self.log.info(f"Websocket connection established to {client_uri}")
503+
if (
504+
subprotocols
505+
and self.ws.selected_subprotocol != self.selected_subprotocol
506+
):
507+
self.log.warn(
508+
f"Websocket subprotocol between proxy/server ({self.ws.selected_subprotocol}) "
509+
f"became different than for client/proxy ({self.selected_subprotocol}) "
510+
"due to https://github.com/jupyterhub/jupyter-server-proxy/issues/459. "
511+
f"Requested subprotocols were {subprotocols}."
512+
)
501513

502514
# Wait for the WebSocket to be connected before resolving.
503515
# Otherwise, messages sent by the client before the
@@ -531,12 +543,25 @@ def check_xsrf_cookie(self):
531543
"""
532544

533545
def select_subprotocol(self, subprotocols):
534-
"""Select a single Sec-WebSocket-Protocol during handshake."""
535-
self.subprotocols = subprotocols
536-
if isinstance(subprotocols, list) and subprotocols:
537-
self.log.debug(f"Client sent subprotocols: {subprotocols}")
546+
"""
547+
Select a single Sec-WebSocket-Protocol during handshake.
548+
549+
Note that this subprotocol selection should really be delegated to the
550+
server we proxy to, but we don't! For this to happen, we would need to
551+
delay accepting the handshake with the client until we have successfully
552+
handshaked with the server. This issue is tracked via
553+
https://github.com/jupyterhub/jupyter-server-proxy/issues/459.
554+
555+
Overrides `tornado.websocket.WebSocketHandler.select_subprotocol` that
556+
includes an informative docstring:
557+
https://github.com/tornadoweb/tornado/blob/v6.4.0/tornado/websocket.py#L337-L360.
558+
"""
559+
if subprotocols:
560+
self.log.debug(
561+
f"Client sent subprotocols: {subprotocols}, selecting the first"
562+
)
538563
return subprotocols[0]
539-
return super().select_subprotocol(subprotocols)
564+
return None
540565

541566

542567
class LocalProxyHandler(ProxyHandler):

pyproject.toml

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,14 @@ dependencies = [
5050
"importlib_metadata >=4.8.3 ; python_version<\"3.10\"",
5151
"jupyter-server >=1.0",
5252
"simpervisor >=1.0",
53-
"tornado >=5.0",
53+
"tornado >=5.1",
5454
"traitlets >= 4.2.1",
5555
]
5656

5757
[project.optional-dependencies]
5858
test = [
5959
"pytest",
60+
"pytest-asyncio",
6061
"pytest-cov",
6162
"pytest-html",
6263
]
@@ -195,21 +196,33 @@ src = "pyproject.toml"
195196
[[tool.tbump.file]]
196197
src = "labextension/package.json"
197198

199+
200+
# pytest is used for running Python based tests
201+
#
202+
# ref: https://docs.pytest.org/en/stable/
203+
#
198204
[tool.pytest.ini_options]
199-
cache_dir = "build/.cache/pytest"
200-
testpaths = ["tests"]
201205
addopts = [
202-
"-vv",
206+
"--verbose",
207+
"--durations=10",
208+
"--color=yes",
203209
"--cov=jupyter_server_proxy",
204210
"--cov-branch",
205211
"--cov-context=test",
206212
"--cov-report=term-missing:skip-covered",
207213
"--cov-report=html:build/coverage",
208214
"--no-cov-on-fail",
209215
"--html=build/pytest/index.html",
210-
"--color=yes",
211216
]
217+
asyncio_mode = "auto"
218+
testpaths = ["tests"]
219+
cache_dir = "build/.cache/pytest"
212220

221+
222+
# pytest-cov / coverage is used to measure code coverage of tests
223+
#
224+
# ref: https://coverage.readthedocs.io/en/stable/config.html
225+
#
213226
[tool.coverage.run]
214227
data_file = "build/.coverage"
215228
concurrency = [

tests/resources/websocket.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,48 @@ def get(self):
5454

5555

5656
class EchoWebSocket(tornado.websocket.WebSocketHandler):
57+
"""Echoes back received messages."""
58+
5759
def on_message(self, message):
5860
self.write_message(message)
5961

6062

6163
class HeadersWebSocket(tornado.websocket.WebSocketHandler):
64+
"""Echoes back incoming request headers."""
65+
6266
def on_message(self, message):
6367
self.write_message(json.dumps(dict(self.request.headers)))
6468

6569

6670
class SubprotocolWebSocket(tornado.websocket.WebSocketHandler):
71+
"""
72+
Echoes back requested subprotocols and selected subprotocol as a JSON
73+
encoded message, and selects subprotocols in a very particular way to help
74+
us test things.
75+
"""
76+
6777
def __init__(self, *args, **kwargs):
68-
self._subprotocols = None
78+
self._requested_subprotocols = None
6979
super().__init__(*args, **kwargs)
7080

7181
def select_subprotocol(self, subprotocols):
72-
self._subprotocols = subprotocols
73-
return None
82+
self._requested_subprotocols = subprotocols if subprotocols else None
83+
84+
if not subprotocols:
85+
return None
86+
if "please_select_no_protocol" in subprotocols:
87+
return None
88+
if "favored" in subprotocols:
89+
return "favored"
90+
else:
91+
return subprotocols[0]
7492

7593
def on_message(self, message):
76-
self.write_message(json.dumps(self._subprotocols))
94+
response = {
95+
"requested_subprotocols": self._requested_subprotocols,
96+
"selected_subprotocol": self.selected_subprotocol,
97+
}
98+
self.write_message(json.dumps(response))
7799

78100

79101
def main():

tests/test_proxies.py

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import gzip
32
import json
43
import sys
@@ -332,14 +331,9 @@ def test_server_content_encoding_header(
332331
assert f.read() == b"this is a test"
333332

334333

335-
@pytest.fixture(scope="module")
336-
def event_loop():
337-
loop = asyncio.get_event_loop()
338-
yield loop
339-
loop.close()
340-
341-
342-
async def _websocket_echo(a_server_port_and_token: Tuple[int, str]) -> None:
334+
async def test_server_proxy_websocket_messages(
335+
a_server_port_and_token: Tuple[int, str]
336+
) -> None:
343337
PORT = a_server_port_and_token[0]
344338
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/echosocket"
345339
conn = await websocket_connect(url)
@@ -349,13 +343,7 @@ async def _websocket_echo(a_server_port_and_token: Tuple[int, str]) -> None:
349343
assert msg == expected_msg
350344

351345

352-
def test_server_proxy_websocket(
353-
event_loop, a_server_port_and_token: Tuple[int, str]
354-
) -> None:
355-
event_loop.run_until_complete(_websocket_echo(a_server_port_and_token))
356-
357-
358-
async def _websocket_headers(a_server_port_and_token: Tuple[int, str]) -> None:
346+
async def test_server_proxy_websocket_headers(a_server_port_and_token: Tuple[int, str]):
359347
PORT = a_server_port_and_token[0]
360348
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/headerssocket"
361349
conn = await websocket_connect(url)
@@ -366,25 +354,68 @@ async def _websocket_headers(a_server_port_and_token: Tuple[int, str]) -> None:
366354
assert headers["X-Custom-Header"] == "pytest-23456"
367355

368356

369-
def test_server_proxy_websocket_headers(
370-
event_loop, a_server_port_and_token: Tuple[int, str]
357+
@pytest.mark.parametrize(
358+
"client_requested,server_received,server_responded,proxy_responded",
359+
[
360+
(None, None, None, None),
361+
(["first"], ["first"], "first", "first"),
362+
# IMPORTANT: The tests below verify current bugged behavior, and the
363+
# commented out tests is what we want to succeed!
364+
#
365+
# The proxy websocket should actually respond the handshake
366+
# with a subprotocol based on a the server handshake
367+
# response, but we are finalizing the client/proxy handshake
368+
# before the proxy/server handshake, and that makes it
369+
# impossible. We currently instead just pick the first
370+
# requested protocol no matter what what subprotocol the
371+
# server picks.
372+
#
373+
# Bug 1 - server wasn't passed all subprotocols:
374+
(["first", "second"], ["first"], "first", "first"),
375+
# (["first", "second"], ["first", "second"], "first", "first"),
376+
#
377+
# Bug 2 - server_responded doesn't match proxy_responded:
378+
(["first", "favored"], ["first"], "first", "first"),
379+
# (["first", "favored"], ["first", "favored"], "favored", "favored"),
380+
(
381+
["please_select_no_protocol"],
382+
["please_select_no_protocol"],
383+
None,
384+
"please_select_no_protocol",
385+
),
386+
# (["please_select_no_protocol"], ["please_select_no_protocol"], None, None),
387+
],
388+
)
389+
async def test_server_proxy_websocket_subprotocols(
390+
a_server_port_and_token: Tuple[int, str],
391+
client_requested,
392+
server_received,
393+
server_responded,
394+
proxy_responded,
371395
):
372-
event_loop.run_until_complete(_websocket_headers(a_server_port_and_token))
373-
374-
375-
async def _websocket_subprotocols(a_server_port_and_token: Tuple[int, str]) -> None:
376396
PORT, TOKEN = a_server_port_and_token
377397
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/subprotocolsocket"
378-
conn = await websocket_connect(url, subprotocols=["protocol_1", "protocol_2"])
398+
conn = await websocket_connect(url, subprotocols=client_requested)
379399
await conn.write_message("Hello, world!")
400+
401+
# verify understanding of websocket_connect that this test relies on
402+
if client_requested:
403+
assert "Sec-Websocket-Protocol" in conn.request.headers
404+
else:
405+
assert "Sec-Websocket-Protocol" not in conn.request.headers
406+
380407
msg = await conn.read_message()
381-
assert json.loads(msg) == ["protocol_1", "protocol_2"]
408+
info = json.loads(msg)
382409

410+
assert info["requested_subprotocols"] == server_received
411+
assert info["selected_subprotocol"] == server_responded
412+
assert conn.selected_subprotocol == proxy_responded
383413

384-
def test_server_proxy_websocket_subprotocols(
385-
event_loop, a_server_port_and_token: Tuple[int, str]
386-
):
387-
event_loop.run_until_complete(_websocket_subprotocols(a_server_port_and_token))
414+
# verify proxy response headers directly
415+
if proxy_responded is None:
416+
assert "Sec-Websocket-Protocol" not in conn.headers
417+
else:
418+
assert "Sec-Websocket-Protocol" in conn.headers
388419

389420

390421
@pytest.mark.parametrize(
@@ -410,7 +441,9 @@ def test_bad_server_proxy_url(
410441
assert "X-ProxyContextPath" not in r.headers
411442

412443

413-
def test_callable_environment_formatting(a_server_port_and_token: Tuple[int, str]) -> None:
444+
def test_callable_environment_formatting(
445+
a_server_port_and_token: Tuple[int, str]
446+
) -> None:
414447
PORT, TOKEN = a_server_port_and_token
415448
r = request_get(PORT, "/python-http-callable-env/test", TOKEN)
416449
assert r.code == 200

0 commit comments

Comments
 (0)