@@ -142,27 +142,26 @@ class FrameNavigatedEvent(TypedDict):
142142Env = Dict [str , Union [str , float , bool ]]
143143
144144
145- class URLMatcher :
146- def __init__ (self , base_url : Union [str , None ], match : URLMatch ) -> None :
147- self ._callback : Optional [Callable [[str ], bool ]] = None
148- self ._regex_obj : Optional [Pattern [str ]] = None
149- if isinstance (match , str ):
150- if base_url and not match .startswith ("*" ):
151- match = urljoin (base_url , match )
152- regex = glob_to_regex (match )
153- self ._regex_obj = re .compile (regex )
154- elif isinstance (match , Pattern ):
155- self ._regex_obj = match
156- else :
157- self ._callback = match
158- self .match = match
159-
160- def matches (self , url : str ) -> bool :
161- if self ._callback :
162- return self ._callback (url )
163- if self ._regex_obj :
164- return cast (bool , self ._regex_obj .search (url ))
165- return False
145+ def url_matches (
146+ base_url : Optional [str ], url_string : str , match : Optional [URLMatch ]
147+ ) -> bool :
148+ if not match :
149+ return True
150+ if isinstance (match , str ) and match [0 ] != "*" :
151+ # Allow http(s) baseURL to match ws(s) urls.
152+ if (
153+ base_url
154+ and re .match (r"^https?://" , base_url )
155+ and re .match (r"^wss?://" , url_string )
156+ ):
157+ base_url = re .sub (r"^http" , "ws" , base_url )
158+ if base_url :
159+ match = urljoin (base_url , match )
160+ if isinstance (match , str ):
161+ match = glob_to_regex (match )
162+ if isinstance (match , Pattern ):
163+ return bool (match .search (url_string ))
164+ return match (url_string )
166165
167166
168167class HarLookupResult (TypedDict , total = False ):
@@ -271,12 +270,14 @@ def __init__(self, complete: "asyncio.Future", route: "Route") -> None:
271270class RouteHandler :
272271 def __init__ (
273272 self ,
274- matcher : URLMatcher ,
273+ base_url : Optional [str ],
274+ url : URLMatch ,
275275 handler : RouteHandlerCallback ,
276276 is_sync : bool ,
277277 times : Optional [int ] = None ,
278278 ):
279- self .matcher = matcher
279+ self ._base_url = base_url
280+ self .url = url
280281 self .handler = handler
281282 self ._times = times if times else math .inf
282283 self ._handled_count = 0
@@ -285,7 +286,7 @@ def __init__(
285286 self ._active_invocations : Set [RouteHandlerInvocation ] = set ()
286287
287288 def matches (self , request_url : str ) -> bool :
288- return self .matcher . matches ( request_url )
289+ return url_matches ( self ._base_url , request_url , self . url )
289290
290291 async def handle (self , route : "Route" ) -> bool :
291292 handler_invocation = RouteHandlerInvocation (
@@ -362,13 +363,13 @@ def prepare_interception_patterns(
362363 patterns = []
363364 all = False
364365 for handler in handlers :
365- if isinstance (handler .matcher . match , str ):
366- patterns .append ({"glob" : handler .matcher . match })
367- elif isinstance (handler .matcher . _regex_obj , re .Pattern ):
366+ if isinstance (handler .url , str ):
367+ patterns .append ({"glob" : handler .url })
368+ elif isinstance (handler .url , re .Pattern ):
368369 patterns .append (
369370 {
370- "regexSource" : handler .matcher . _regex_obj .pattern ,
371- "regexFlags" : escape_regex_flags (handler .matcher . _regex_obj ),
371+ "regexSource" : handler .url .pattern ,
372+ "regexFlags" : escape_regex_flags (handler .url ),
372373 }
373374 )
374375 else :
0 commit comments