Skip to content

Commit 429625a

Browse files
authored
Merge pull request #180 from xhochy/proxy-ws-subprotocols
Proxy Websocket subprotocols
2 parents 80263c5 + 3764679 commit 429625a

File tree

5 files changed

+129
-4
lines changed

5 files changed

+129
-4
lines changed

jupyter_server_proxy/handlers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self, *args, **kwargs):
4444
self.proxy_base = ''
4545
self.absolute_url = kwargs.pop('absolute_url', False)
4646
self.host_whitelist = kwargs.pop('host_whitelist', ['localhost', '127.0.0.1'])
47+
self.subprotocols = None
4748
super().__init__(*args, **kwargs)
4849

4950
# Support all the methods that tornado does by default except for GET which
@@ -284,7 +285,8 @@ async def start_websocket_connection():
284285
self._record_activity()
285286
request = httpclient.HTTPRequest(url=client_uri, headers=headers)
286287
self.ws = await pingable_ws_connect(request=request,
287-
on_message_callback=message_cb, on_ping_callback=ping_cb)
288+
on_message_callback=message_cb, on_ping_callback=ping_cb,
289+
subprotocols=self.subprotocols)
288290
ws_connected.set_result(True)
289291
self._record_activity()
290292
self.log.info('Websocket connection established to {}'.format(client_uri))
@@ -316,6 +318,7 @@ def check_xsrf_cookie(self):
316318

317319
def select_subprotocol(self, subprotocols):
318320
'''Select a single Sec-WebSocket-Protocol during handshake.'''
321+
self.subprotocols = subprotocols
319322
if isinstance(subprotocols, list) and subprotocols:
320323
self.log.info('Client sent subprotocols: {}'.format(subprotocols))
321324
return subprotocols[0]

jupyter_server_proxy/websocket.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def on_ping(self, data):
3636
self._on_ping_callback(data)
3737

3838

39-
def pingable_ws_connect(request=None, on_message_callback=None,
40-
on_ping_callback=None):
39+
def pingable_ws_connect(request=None,on_message_callback=None,
40+
on_ping_callback=None, subprotocols=None):
4141
"""
4242
A variation on websocket_connect that returns a PingableWSClientConnection
4343
with on_ping_callback.
@@ -60,7 +60,8 @@ def pingable_ws_connect(request=None, on_message_callback=None,
6060
compression_options={},
6161
on_message_callback=on_message_callback,
6262
on_ping_callback=on_ping_callback,
63-
max_message_size=getattr(websocket, '_default_max_message_size', 10 * 1024 * 1024))
63+
max_message_size=getattr(websocket, '_default_max_message_size', 10 * 1024 * 1024),
64+
subprotocols=subprotocols)
6465

6566
return conn.connect_future
6667

tests/resources/jupyter_server_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def mappathf(path):
2424
'command': ['python3', './tests/resources/httpinfo.py', '{port}'],
2525
'mappath': mappathf,
2626
},
27+
'python-websocket' : {
28+
'command': ['python3', './tests/resources/websocket.py', '--port={port}'],
29+
}
2730
}
2831

2932
import sys

tests/resources/websocket.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#!/usr/bin/env python
2+
#
3+
# Based on the chat demo from https://github.com/tornadoweb/tornado/blob/d6819307ee050bbd8ec5deb623e9150ce2220ef9/demos/websocket/chatdemo.py#L1
4+
# Original License:
5+
#
6+
# Copyright 2009 Facebook
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
9+
# not use this file except in compliance with the License. You may obtain
10+
# a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
17+
# License for the specific language governing permissions and limitations
18+
# under the License.
19+
20+
21+
import logging
22+
import json
23+
import tornado.escape
24+
import tornado.ioloop
25+
import tornado.options
26+
import tornado.web
27+
import tornado.websocket
28+
import os.path
29+
import uuid
30+
31+
from tornado.options import define, options
32+
33+
define("port", default=8888, help="run on the given port", type=int)
34+
35+
36+
class Application(tornado.web.Application):
37+
def __init__(self):
38+
handlers = [
39+
(r"/", MainHandler),
40+
(r"/echosocket", EchoWebSocket),
41+
(r"/subprotocolsocket", SubprotocolWebSocket),
42+
]
43+
settings = dict(
44+
cookie_secret="__RANDOM_VALUE__",
45+
template_path=os.path.join(os.path.dirname(__file__), "templates"),
46+
static_path=os.path.join(os.path.dirname(__file__), "static"),
47+
xsrf_cookies=True,
48+
)
49+
super(Application, self).__init__(handlers, **settings)
50+
51+
52+
class MainHandler(tornado.web.RequestHandler):
53+
def get(self):
54+
self.write("Hello, world!")
55+
56+
57+
class EchoWebSocket(tornado.websocket.WebSocketHandler):
58+
def on_message(self, message):
59+
self.write_message(message)
60+
61+
62+
class SubprotocolWebSocket(tornado.websocket.WebSocketHandler):
63+
def __init__(self, *args, **kwargs):
64+
self._subprotocols = None
65+
super().__init__(*args, **kwargs)
66+
67+
def select_subprotocol(self, subprotocols):
68+
self._subprotocols = subprotocols
69+
return None
70+
71+
def on_message(self, message):
72+
self.write_message(json.dumps(self._subprotocols))
73+
74+
75+
def main():
76+
tornado.options.parse_command_line()
77+
app = Application()
78+
app.listen(options.port)
79+
tornado.ioloop.IOLoop.current().start()
80+
81+
82+
if __name__ == "__main__":
83+
main()

tests/test_proxies.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import asyncio
2+
import json
13
import os
24
from http.client import HTTPConnection
35
import pytest
6+
from tornado.websocket import websocket_connect
47

58
PORT = os.getenv('TEST_PORT', 8888)
69
TOKEN = os.getenv('JUPYTER_TOKEN', 'secret')
@@ -99,3 +102,35 @@ def test_server_proxy_mappath_callable(requestpath, expected):
99102
def test_server_proxy_remote():
100103
r = request_get(PORT, '/newproxy', TOKEN, host='127.0.0.1')
101104
assert r.code == 200
105+
106+
107+
@pytest.fixture(scope="module")
108+
def event_loop():
109+
loop = asyncio.get_event_loop()
110+
yield loop
111+
loop.close()
112+
113+
114+
async def _websocket_echo():
115+
url = "ws://localhost:{}/python-websocket/echosocket".format(PORT)
116+
conn = await websocket_connect(url)
117+
expected_msg = "Hello, world!"
118+
await conn.write_message(expected_msg)
119+
msg = await conn.read_message()
120+
assert msg == expected_msg
121+
122+
def test_server_proxy_websocket(event_loop):
123+
event_loop.run_until_complete(_websocket_echo())
124+
125+
126+
async def _websocket_subprotocols():
127+
url = "ws://localhost:{}/python-websocket/subprotocolsocket".format(PORT)
128+
conn = await websocket_connect(url, subprotocols=["protocol_1", "protocol_2"])
129+
await conn.write_message("Hello, world!")
130+
msg = await conn.read_message()
131+
assert json.loads(msg) == ["protocol_1", "protocol_2"]
132+
133+
134+
def test_server_proxy_websocket_subprotocols(event_loop):
135+
event_loop.run_until_complete(_websocket_subprotocols())
136+

0 commit comments

Comments
 (0)