11import json
22import asyncio
33import logging
4- from urllib .parse import urljoin
54from asyncio import Queue as AsyncQueue
65from threading import Event as ThreadEvent , Thread
76from queue import Queue as ThreadQueue
7+ from urllib .parse import parse_qs as parse_query_string
88from typing import Union , Tuple , Dict , Any , Optional , Callable , NamedTuple , Type , cast
99
1010from typing_extensions import TypedDict
11- from flask import Flask , Blueprint , send_from_directory , redirect , url_for
11+ from flask import Flask , Blueprint , send_from_directory , redirect , url_for , request
1212from flask_cors import CORS
1313from flask_sockets import Sockets
1414from geventwebsocket .websocket import WebSocket
@@ -77,7 +77,7 @@ def _setup_application(self, config: Config, app: Flask) -> None:
7777
7878 sockets = Sockets (app )
7979
80- @sockets .route (urljoin (config ["url_prefix" ], "/stream" )) # type: ignore
80+ @sockets .route (_join_url_paths (config ["url_prefix" ], "/stream" )) # type: ignore
8181 def model_stream (ws : WebSocket ) -> None :
8282 def send (value : Any ) -> None :
8383 ws .send (json .dumps (value ))
@@ -89,9 +89,14 @@ def recv() -> Optional[LayoutEvent]:
8989 else :
9090 return None
9191
92+ query_params = {
93+ k : v if len (v ) > 1 else v [0 ]
94+ for k , v in parse_query_string (ws .environ ["QUERY_STRING" ]).items ()
95+ }
96+
9297 run_dispatcher_in_thread (
9398 lambda : self ._dispatcher_type (
94- Layout (self ._root_component_constructor ())
99+ Layout (self ._root_component_constructor (** query_params ))
95100 ),
96101 send ,
97102 recv ,
@@ -109,7 +114,13 @@ def send_build_dir(path: str) -> Any:
109114
110115 @blueprint .route ("/" )
111116 def redirect_to_index () -> Any :
112- return redirect (url_for ("idom.send_build_dir" , path = "index.html" ))
117+ return redirect (
118+ url_for (
119+ "idom.send_build_dir" ,
120+ path = "index.html" ,
121+ ** request .args ,
122+ )
123+ )
113124
114125 def _setup_application_did_start_event (
115126 self , config : Config , app : Flask , event : ThreadEvent
@@ -261,3 +272,9 @@ def update_environ(self) -> None:
261272 super ().update_environ ()
262273 # BUG: for some reason coverage doesn't seem to think this line is covered
263274 self ._before_first_request_callback () # pragma: no cover
275+
276+
277+ def _join_url_paths (* args : str ) -> str :
278+ # urllib.parse.urljoin performs more logic than is needed. Thus we need a util func
279+ # to join paths as if they were POSIX paths.
280+ return "/" .join (map (lambda x : str (x ).rstrip ("/" ), filter (None , args )))
0 commit comments