4747 from_nullable_channel ,
4848)
4949from playwright ._impl ._event_context_manager import EventContextManagerImpl
50- from playwright ._impl ._helper import ContinueParameters , locals_to_params
50+ from playwright ._impl ._helper import FallbackOverrideParameters , locals_to_params
5151from playwright ._impl ._wait_helper import WaitHelper
5252
5353if TYPE_CHECKING : # pragma: no cover
5454 from playwright ._impl ._fetch import APIResponse
5555 from playwright ._impl ._frame import Frame
5656
5757
58+ def serialize_headers (headers : Dict [str , str ]) -> HeadersArray :
59+ return [
60+ {"name" : name , "value" : value }
61+ for name , value in headers .items ()
62+ if value is not None
63+ ]
64+
65+
5866class Request (ChannelOwner ):
5967 def __init__ (
6068 self , parent : ChannelOwner , type : str , guid : str , initializer : Dict
@@ -80,21 +88,31 @@ def __init__(
8088 }
8189 self ._provisional_headers = RawHeaders (self ._initializer ["headers" ])
8290 self ._all_headers_future : Optional [asyncio .Future [RawHeaders ]] = None
91+ self ._fallback_overrides : FallbackOverrideParameters = (
92+ FallbackOverrideParameters ()
93+ )
8394
8495 def __repr__ (self ) -> str :
8596 return f"<Request url={ self .url !r} method={ self .method !r} >"
8697
98+ def _apply_fallback_overrides (self , overrides : FallbackOverrideParameters ) -> None :
99+ self ._fallback_overrides = cast (
100+ FallbackOverrideParameters , {** self ._fallback_overrides , ** overrides }
101+ )
102+
87103 @property
88104 def url (self ) -> str :
89- return self ._initializer ["url" ]
105+ return cast ( str , self ._fallback_overrides . get ( "url" , self . _initializer ["url" ]))
90106
91107 @property
92108 def resource_type (self ) -> str :
93109 return self ._initializer ["resourceType" ]
94110
95111 @property
96112 def method (self ) -> str :
97- return self ._initializer ["method" ]
113+ return cast (
114+ str , self ._fallback_overrides .get ("method" , self ._initializer ["method" ])
115+ )
98116
99117 async def sizes (self ) -> RequestSizes :
100118 response = await self .response ()
@@ -104,10 +122,10 @@ async def sizes(self) -> RequestSizes:
104122
105123 @property
106124 def post_data (self ) -> Optional [str ]:
107- data = self .post_data_buffer
125+ data = self ._fallback_overrides . get ( "postData" , self . post_data_buffer )
108126 if not data :
109127 return None
110- return data .decode ()
128+ return data .decode () if isinstance ( data , bytes ) else data
111129
112130 @property
113131 def post_data_json (self ) -> Optional [Any ]:
@@ -124,6 +142,13 @@ def post_data_json(self) -> Optional[Any]:
124142
125143 @property
126144 def post_data_buffer (self ) -> Optional [bytes ]:
145+ override = self ._fallback_overrides .get ("post_data" )
146+ if override :
147+ return (
148+ override .encode ()
149+ if isinstance (override , str )
150+ else cast (bytes , override )
151+ )
127152 b64_content = self ._initializer .get ("postData" )
128153 if b64_content is None :
129154 return None
@@ -157,6 +182,9 @@ def timing(self) -> ResourceTiming:
157182
158183 @property
159184 def headers (self ) -> Headers :
185+ override = self ._fallback_overrides .get ("headers" )
186+ if override :
187+ return RawHeaders ._from_headers_dict_lossy (override ).headers ()
160188 return self ._provisional_headers .headers ()
161189
162190 async def all_headers (self ) -> Headers :
@@ -169,6 +197,9 @@ async def header_value(self, name: str) -> Optional[str]:
169197 return (await self ._actual_headers ()).get (name )
170198
171199 async def _actual_headers (self ) -> "RawHeaders" :
200+ override = self ._fallback_overrides .get ("headers" )
201+ if override :
202+ return RawHeaders (serialize_headers (override ))
172203 if not self ._all_headers_future :
173204 self ._all_headers_future = asyncio .Future ()
174205 headers = await self ._channel .send ("rawRequestHeaders" )
@@ -181,6 +212,21 @@ def __init__(
181212 self , parent : ChannelOwner , type : str , guid : str , initializer : Dict
182213 ) -> None :
183214 super ().__init__ (parent , type , guid , initializer )
215+ self ._handling_future : Optional [asyncio .Future ["bool" ]] = None
216+
217+ def _start_handling (self ) -> "asyncio.Future[bool]" :
218+ self ._handling_future = asyncio .Future ()
219+ return self ._handling_future
220+
221+ def _report_handled (self , done : bool ) -> None :
222+ chain = self ._handling_future
223+ assert chain
224+ self ._handling_future = None
225+ chain .set_result (done )
226+
227+ def _check_not_handled (self ) -> None :
228+ if not self ._handling_future :
229+ raise Error ("Route is already handled!" )
184230
185231 def __repr__ (self ) -> str :
186232 return f"<Route request={ self .request } >"
@@ -203,6 +249,7 @@ async def fulfill(
203249 contentType : str = None ,
204250 response : "APIResponse" = None ,
205251 ) -> None :
252+ self ._check_not_handled ()
206253 params = locals_to_params (locals ())
207254 if response :
208255 del params ["response" ]
@@ -247,37 +294,74 @@ async def fulfill(
247294 headers ["content-length" ] = str (length )
248295 params ["headers" ] = serialize_headers (headers )
249296 await self ._race_with_page_close (self ._channel .send ("fulfill" , params ))
297+ self ._report_handled (True )
250298
251- async def continue_ (
299+ async def fallback (
252300 self ,
253301 url : str = None ,
254302 method : str = None ,
255303 headers : Dict [str , str ] = None ,
256304 postData : Union [str , bytes ] = None ,
257305 ) -> None :
258- overrides : ContinueParameters = {}
259- if url :
260- overrides ["url" ] = url
261- if method :
262- overrides ["method" ] = method
263- if headers :
264- overrides ["headers" ] = serialize_headers (headers )
265- if isinstance (postData , str ):
266- overrides ["postData" ] = base64 .b64encode (postData .encode ()).decode ()
267- elif isinstance (postData , bytes ):
268- overrides ["postData" ] = base64 .b64encode (postData ).decode ()
269- await self ._race_with_page_close (
270- self ._channel .send ("continue" , cast (Any , overrides ))
271- )
306+ overrides = cast (FallbackOverrideParameters , locals_to_params (locals ()))
307+ self ._check_not_handled ()
308+ self .request ._apply_fallback_overrides (overrides )
309+ self ._report_handled (False )
272310
273- def _internal_continue (self ) -> None :
311+ async def continue_ (
312+ self ,
313+ url : str = None ,
314+ method : str = None ,
315+ headers : Dict [str , str ] = None ,
316+ postData : Union [str , bytes ] = None ,
317+ ) -> None :
318+ overrides = cast (FallbackOverrideParameters , locals_to_params (locals ()))
319+ self ._check_not_handled ()
320+ self .request ._apply_fallback_overrides (overrides )
321+ await self ._internal_continue ()
322+ self ._report_handled (True )
323+
324+ def _internal_continue (
325+ self , is_internal : bool = False
326+ ) -> Coroutine [Any , Any , None ]:
274327 async def continue_route () -> None :
275328 try :
276- await self .continue_ ()
277- except Exception :
278- pass
279-
280- asyncio .create_task (continue_route ())
329+ post_data_for_wire : Optional [str ] = None
330+ post_data_from_overrides = self .request ._fallback_overrides .get (
331+ "postData"
332+ )
333+ if post_data_from_overrides is not None :
334+ post_data_for_wire = (
335+ base64 .b64encode (post_data_from_overrides .encode ()).decode ()
336+ if isinstance (post_data_from_overrides , str )
337+ else base64 .b64encode (post_data_from_overrides ).decode ()
338+ )
339+ params = locals_to_params (
340+ cast (Dict [str , str ], self .request ._fallback_overrides )
341+ )
342+ if "headers" in params :
343+ params ["headers" ] = serialize_headers (params ["headers" ])
344+ if post_data_for_wire is not None :
345+ params ["postData" ] = post_data_for_wire
346+ await self ._race_with_page_close (
347+ self ._channel .send (
348+ "continue" ,
349+ params ,
350+ )
351+ )
352+ except Exception as e :
353+ if not is_internal :
354+ raise e
355+
356+ return continue_route ()
357+
358+ # FIXME: Port corresponding tests, and call this method
359+ async def _redirected_navigation_request (self , url : str ) -> None :
360+ self ._check_not_handled ()
361+ await self ._race_with_page_close (
362+ self ._channel .send ("redirectNavigationRequest" , {"url" : url })
363+ )
364+ self ._report_handled (True )
281365
282366 async def _race_with_page_close (self , future : Coroutine ) -> None :
283367 if hasattr (self .request .frame , "_page" ):
@@ -484,17 +568,17 @@ def _on_close(self) -> None:
484568 self .emit (WebSocket .Events .Close , self )
485569
486570
487- def serialize_headers (headers : Dict [str , str ]) -> HeadersArray :
488- return [{"name" : name , "value" : value } for name , value in headers .items ()]
489-
490-
491571class RawHeaders :
492572 def __init__ (self , headers : HeadersArray ) -> None :
493573 self ._headers_array = headers
494574 self ._headers_map : Dict [str , Dict [str , bool ]] = defaultdict (dict )
495575 for header in headers :
496576 self ._headers_map [header ["name" ].lower ()][header ["value" ]] = True
497577
578+ @staticmethod
579+ def _from_headers_dict_lossy (headers : Dict [str , str ]) -> "RawHeaders" :
580+ return RawHeaders (serialize_headers (headers ))
581+
498582 def get (self , name : str ) -> Optional [str ]:
499583 values = self .get_all (name )
500584 if not values :
0 commit comments