Skip to content

Commit ba33538

Browse files
ganisbackyuvipanda
authored andcommitted
support stream
1 parent 44b5405 commit ba33538

File tree

1 file changed

+79
-2
lines changed

1 file changed

+79
-2
lines changed

jupyter_server_proxy/handlers.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Some original inspiration from https://github.com/senko/tornado-proxy
55
"""
66

7-
import os
7+
import os, json, re
88
import socket
99
from asyncio import Lock
1010
from copy import copy
@@ -287,7 +287,7 @@ def get_client_uri(self, protocol, host, port, proxied_path):
287287

288288
return client_uri
289289

290-
def _build_proxy_request(self, host, port, proxied_path, body):
290+
def _build_proxy_request(self, host, port, proxied_path, body, **extra_opts):
291291
headers = self.proxy_request_headers()
292292

293293
client_uri = self.get_client_uri("http", host, port, proxied_path)
@@ -307,6 +307,7 @@ def _build_proxy_request(self, host, port, proxied_path, body):
307307
decompress_response=False,
308308
headers=headers,
309309
**self.proxy_request_options(),
310+
**extra_opts,
310311
)
311312
return req
312313

@@ -365,7 +366,82 @@ async def proxy(self, host, port, proxied_path):
365366
body = b""
366367
else:
367368
body = None
369+
accept_type = self.request.headers.get('Accept')
370+
if accept_type == 'text/event-stream':
371+
return await self._proxy_progressive(host, port, proxied_path, body)
372+
else:
373+
return await self._proxy_normal(host, port, proxied_path, body)
374+
375+
async def _proxy_progressive(self, host, port, proxied_path, body):
376+
# Proxy in progressive flush mode, whenever chunks are received. Potentially slower but get results quicker for voila
377+
378+
client = httpclient.AsyncHTTPClient()
379+
380+
# Set up handlers so we can progressively flush result
381+
382+
headers_raw = []
383+
384+
def dump_headers(headers_raw):
385+
for line in headers_raw:
386+
r = re.match('^([a-zA-Z0-9\-_]+)\s*\:\s*([^\r\n]+)[\r\n]*$', line)
387+
if r:
388+
k,v = r.groups([1,2])
389+
if k not in ('Content-Length', 'Transfer-Encoding',
390+
'Content-Encoding', 'Connection'):
391+
# some header appear multiple times, eg 'Set-Cookie'
392+
self.set_header(k,v)
393+
else:
394+
r = re.match('^HTTP[^\s]* ([0-9]+)', line)
395+
if r:
396+
status_code = r.group(1)
397+
self.set_status(int(status_code))
398+
headers_raw.clear()
399+
400+
# clear tornado default header
401+
self._headers = httputil.HTTPHeaders()
402+
403+
def header_callback(line):
404+
headers_raw.append(line)
405+
406+
def streaming_callback(chunk):
407+
# Do this here, not in header_callback so we can be sure headers are out of the way first
408+
dump_headers(headers_raw) # array will be empty if this was already called before
409+
self.write(chunk)
410+
self.flush()
411+
412+
# Now make the request
413+
414+
req = self._build_proxy_request(host, port, proxied_path, body,
415+
streaming_callback=streaming_callback,
416+
header_callback=header_callback)
417+
418+
try:
419+
response = await client.fetch(req, raise_error=False)
420+
except httpclient.HTTPError as err:
421+
if err.code == 599:
422+
self._record_activity()
423+
self.set_status(599)
424+
self.write(str(err))
425+
return
426+
else:
427+
raise
428+
429+
# record activity at start and end of requests
430+
self._record_activity()
431+
432+
# For all non http errors...
433+
if response.error and type(response.error) is not httpclient.HTTPError:
434+
self.set_status(500)
435+
self.write(str(response.error))
436+
else:
437+
self.set_status(response.code, response.reason) # Should already have been set
438+
439+
dump_headers(headers_raw) # Should already have been emptied
368440

441+
if response.body: # Likewise, should already be chunked out and flushed
442+
self.write(response.body)
443+
444+
async def _proxy_normal(self, host, port, proxied_path, body):
369445
if self.unix_socket is not None:
370446
# Port points to a Unix domain socket
371447
self.log.debug("Making client for Unix socket %r", self.unix_socket)
@@ -458,6 +534,7 @@ def rewrite_pe(rewritable_response: RewritableResponse):
458534
if rewritten_response.body:
459535
self.write(rewritten_response.body)
460536

537+
461538
async def proxy_open(self, host, port, proxied_path=""):
462539
"""
463540
Called when a client opens a websocket connection.

0 commit comments

Comments
 (0)