8787
8888 T = TypeVar ("T" )
8989 BE = TypeVar ("BE" , bound = BaseException )
90- MU = TypeVar ("MU" , bound = "MaybeUnlock" )
9190 Response = Coroutine [Any , Any , T ]
9291
9392API_VERSION : int = 10
@@ -106,61 +105,92 @@ async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str
106105
107106
108107class Route :
109- API_BASE_URL : str = "https://discord.com/api/v{API_VERSION}"
110-
111- def __init__ (self , method : str , path : str , ** parameters : Any ) -> None :
112- self .path : str = path
113- self .method : str = method
114- url = self .base + self .path
115- if parameters :
116- url = url .format_map (
117- {
118- k : _uriquote (v ) if isinstance (v , str ) else v
119- for k , v in parameters .items ()
120- }
121- )
122- self .url : str = url
108+ def __init__ (
109+ self ,
110+ method : str ,
111+ path : str ,
112+ guild_id : str | None = None ,
113+ channel_id : str | None = None ,
114+ webhook_id : str | None = None ,
115+ webhook_token : str | None = None ,
116+ ** parameters : str | int ,
117+ ):
118+ self .method = method
119+ self .path = path
123120
124- # major parameters:
125- self .channel_id : Snowflake | None = parameters . get ( "channel_id" )
126- self .guild_id : Snowflake | None = parameters . get ( "guild_id" )
127- self .webhook_id : Snowflake | None = parameters . get ( " webhook_id" )
128- self .webhook_token : str | None = parameters . get ( " webhook_token" )
121+ # major parameters
122+ self .guild_id = guild_id
123+ self .channel_id = channel_id
124+ self .webhook_id = webhook_id
125+ self .webhook_token = webhook_token
129126
130- @property
131- def base (self ) -> str :
132- return self .API_BASE_URL .format (API_VERSION = API_VERSION )
127+ self .parameters = parameters
133128
134- @property
135- def bucket (self ) -> str :
136- # the bucket is just method + path w/ major parameters
137- return f"{ self .channel_id } :{ self .guild_id } :{ self .path } "
129+ def merge (self , url : str ):
130+ return url + self .path .format (
131+ guild_id = self .guild_id ,
132+ channel_id = self .channel_id ,
133+ webhook_id = self .webhook_id ,
134+ webhook_token = self .webhook_token ,
135+ ** self .parameters ,
136+ )
138137
138+ def __eq__ (self , route : 'Route' ) -> bool :
139+ return (
140+ route .channel_id == self .channel_id
141+ or route .guild_id == self .guild_id
142+ or route .webhook_id == self .webhook_id
143+ or route .webhook_token == self .webhook_token
144+ ) and route .method == self .method
139145
140- class MaybeUnlock :
141- def __init__ (self , lock : asyncio .Lock ) -> None :
142- self .lock : asyncio .Lock = lock
143- self ._unlock : bool = True
144146
145- def __enter__ (self : MU ) -> MU :
146- return self
147147
148- def defer (self ) -> None :
149- self ._unlock = False
148+ class Executor :
149+ def __init__ (self , route : Route ) -> None :
150+ self .route = route
151+ self .is_global : bool | None = None
152+ self ._request_queue : asyncio .Queue [asyncio .Event ] | None = None
153+ self .rate_limited : bool = False
150154
151- def __exit__ (
152- self ,
153- exc_type : type [BE ] | None ,
154- exc : BE | None ,
155- traceback : TracebackType | None ,
155+ async def executed (
156+ self , reset_after : int | float , limit : int , is_global : bool
156157 ) -> None :
157- if self ._unlock :
158- self .lock .release ()
158+ self .rate_limited = True
159+ self .is_global = is_global
160+ self ._reset_after = reset_after
161+ self ._request_queue = asyncio .Queue ()
162+
163+ await asyncio .sleep (reset_after )
159164
165+ self .is_global = False
160166
161- # For some reason, the Discord voice websocket expects this header to be
162- # completely lowercase while aiohttp respects spec and does it as case-insensitive
163- aiohttp .hdrs .WEBSOCKET = "websocket" # type: ignore
167+ # NOTE: This could break if someone did a second global rate limit somehow
168+ requests_passed : int = 0
169+ for _ in range (self ._request_queue .qsize () - 1 ):
170+ if requests_passed == limit :
171+ requests_passed = 0
172+ if not is_global :
173+ await asyncio .sleep (reset_after )
174+ else :
175+ await asyncio .sleep (5 )
176+
177+ requests_passed += 1
178+ e = await self ._request_queue .get ()
179+ e .set ()
180+
181+ async def wait (self ) -> None :
182+ if not self .rate_limited :
183+ return
184+
185+ event = asyncio .Event ()
186+
187+ if self ._request_queue :
188+ self ._request_queue .put_nowait (event )
189+ else :
190+ raise ValueError (
191+ 'Request queue does not exist, rate limit may have been solved.'
192+ )
193+ await event .wait ()
164194
165195
166196class HTTPClient :
@@ -174,20 +204,20 @@ def __init__(
174204 proxy_auth : aiohttp .BasicAuth | None = None ,
175205 loop : asyncio .AbstractEventLoop | None = None ,
176206 unsync_clock : bool = True ,
207+ discord_api_url : str = "https://discord.com/api/v10"
177208 ) -> None :
209+ self .api_url = discord_api_url
178210 self .loop : asyncio .AbstractEventLoop = (
179211 asyncio .get_event_loop () if loop is None else loop
180212 )
181213 self .connector = connector
182214 self .__session : aiohttp .ClientSession | utils .Undefined = MISSING # filled in static_login
183- self ._locks : weakref .WeakValueDictionary = weakref .WeakValueDictionary ()
184- self ._global_over : asyncio .Event = asyncio .Event ()
185- self ._global_over .set ()
186215 self .token : str | None = None
187216 self .bot_token : bool = False
188217 self .proxy : str | None = proxy
189218 self .proxy_auth : aiohttp .BasicAuth | None = proxy_auth
190219 self .use_clock : bool = not unsync_clock
220+ self ._executors : list [Executor ] = []
191221
192222 user_agent = (
193223 "DiscordBot (https://pycord.dev, {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
@@ -226,15 +256,9 @@ async def request(
226256 form : Iterable [dict [str , Any ]] | None = None ,
227257 ** kwargs : Any ,
228258 ) -> Any :
229- bucket = route .bucket
259+ bucket = route .merge ( self . api_url )
230260 method = route .method
231- url = route .url
232-
233- lock = self ._locks .get (bucket )
234- if lock is None :
235- lock = asyncio .Lock ()
236- if bucket is not None :
237- self ._locks [bucket ] = lock
261+ url = bucket
238262
239263 # header creation
240264 headers : dict [str , str ] = {
@@ -266,123 +290,97 @@ async def request(
266290 if self .proxy_auth is not None :
267291 kwargs ["proxy_auth" ] = self .proxy_auth
268292
269- if not self ._global_over .is_set ():
270- # wait until the global lock is complete
271- await self ._global_over .wait ()
272-
273293 response : aiohttp .ClientResponse | None = None
274294 data : dict [str , Any ] | str | None = None
275- await lock .acquire ()
276- with MaybeUnlock (lock ) as maybe_lock :
277- for tries in range (5 ):
278- if files :
279- for f in files :
280- f .reset (seek = tries )
281-
282- if form :
283- form_data = aiohttp .FormData (quote_fields = False )
284- for params in form :
285- form_data .add_field (** params )
286- kwargs ["data" ] = form_data
287-
288- try :
289- async with self .__session .request (
290- method , url , ** kwargs
291- ) as response :
292- _log .debug (
293- "%s %s with %s has returned %s" ,
294- method ,
295- url ,
296- kwargs .get ("data" ),
297- response .status ,
295+
296+ for executor in self ._executors :
297+ if executor .is_global or executor .route == route :
298+ _log .debug (f'Pausing request to { route } : Found rate limit executor' )
299+ await executor .wait ()
300+
301+ for tries in range (5 ):
302+ if files :
303+ for f in files :
304+ f .reset (seek = tries )
305+
306+ if form :
307+ form_data = aiohttp .FormData (quote_fields = False )
308+ for params in form :
309+ form_data .add_field (** params )
310+ kwargs ["data" ] = form_data
311+
312+ try :
313+ async with self .__session .request (
314+ method , url , ** kwargs
315+ ) as response :
316+ _log .debug (
317+ "%s %s with %s has returned %s" ,
318+ method ,
319+ url ,
320+ kwargs .get ("data" ),
321+ response .status ,
322+ )
323+
324+ # even errors have text involved in them so this is safe to call
325+ data = await json_or_text (response )
326+
327+ # check if we have rate limit header information
328+ remaining = response .headers .get ("X-Ratelimit-Remaining" )
329+ if remaining == "0" and response .status != 429 :
330+ _log .debug (f'Request to { route } failed: Request returned rate limit' )
331+ executor = Executor (route = route )
332+
333+ self ._executors .append (executor )
334+ await executor .executed (
335+ # NOTE: 5 is just a placeholder since this should always be present
336+ reset_after = float (response .headers .get ('X-RateLimit-Reset-After' , "5" )),
337+ is_global = response .headers .get ('X-RateLimit-Scope' ) == 'global' ,
338+ limit = int (response .headers .get ('X-RateLimit-Limit' , 10 )),
298339 )
340+ self ._executors .remove (executor )
341+ continue
299342
300- # even errors have text involved in them so this is safe to call
301- data = await json_or_text (response )
302-
303- # check if we have rate limit header information
304- remaining = response .headers .get ("X-Ratelimit-Remaining" )
305- if remaining == "0" and response .status != 429 :
306- # we've depleted our current bucket
307- delta = utils ._parse_ratelimit_header (
308- response , use_clock = self .use_clock
309- )
310- _log .debug (
311- (
312- "A rate limit bucket has been exhausted (bucket:"
313- " %s, retry: %s)."
314- ),
315- bucket ,
316- delta ,
317- )
318- maybe_lock .defer ()
319- self .loop .call_later (delta , lock .release )
320-
321- # the request was successful so just return the text/json
322- if 300 > response .status >= 200 :
323- _log .debug ("%s %s has received %s" , method , url , data )
324- return data
325-
326- # we are being rate limited
327- if response .status == 429 :
328- if not response .headers .get ("Via" ) or isinstance (data , str ):
329- # Banned by Cloudflare more than likely.
330- raise HTTPException (response , data )
331-
332- fmt = (
333- "We are being rate limited. Retrying in %.2f seconds."
334- ' Handled under the bucket "%s"'
335- )
336-
337- # sleep a bit
338- retry_after : float = data ["retry_after" ]
339- _log .warning (fmt , retry_after , bucket )
340-
341- # check if it's a global rate limit
342- is_global = data .get ("global" , False )
343- if is_global :
344- _log .warning (
345- (
346- "Global rate limit has been hit. Retrying in"
347- " %.2f seconds."
348- ),
349- retry_after ,
350- )
351- self ._global_over .clear ()
352-
353- await asyncio .sleep (retry_after )
354- _log .debug ("Done sleeping for the rate limit. Retrying..." )
355-
356- # release the global lock now that the
357- # global rate limit has passed
358- if is_global :
359- self ._global_over .set ()
360- _log .debug ("Global rate limit is now over." )
361-
362- continue
363-
364- # we've received a 500, 502, 503, or 504, unconditional retry
365- if response .status in {500 , 502 , 503 , 504 }:
366- await asyncio .sleep (1 + tries * 2 )
367- continue
368-
369- # the usual error cases
370- if response .status == 403 :
371- raise Forbidden (response , data )
372- elif response .status == 404 :
373- raise NotFound (response , data )
374- elif response .status >= 500 :
375- raise DiscordServerError (response , data )
376- else :
377- raise HTTPException (response , data )
378-
379- # This is handling exceptions from the request
380- except OSError as e :
381- # Connection reset by peer
382- if tries < 4 and e .errno in (54 , 10054 ):
343+ # the request was successful so just return the text/json
344+ if 300 > response .status >= 200 :
345+ _log .debug ("%s %s has received %s" , method , url , data )
346+ return data
347+
348+ # we are being rate limited
349+ if response .status == 429 :
350+ _log .debug (f'Request to { route } failed: Request returned rate limit' )
351+ executor = Executor (route = route )
352+
353+ self ._executors .append (executor )
354+ await executor .executed (
355+ reset_after = data ['retry_after' ],
356+ is_global = response .headers .get ('X-RateLimit-Scope' ) == 'global' ,
357+ limit = int (response .headers .get ('X-RateLimit-Limit' , 10 )),
358+ )
359+ self ._executors .remove (executor )
360+ continue
361+
362+ # we've received a 500, 502, 503, or 504, unconditional retry
363+ if response .status in {500 , 502 , 503 , 504 }:
383364 await asyncio .sleep (1 + tries * 2 )
384365 continue
385- raise
366+
367+ # the usual error cases
368+ if response .status == 403 :
369+ raise Forbidden (response , data )
370+ elif response .status == 404 :
371+ raise NotFound (response , data )
372+ elif response .status >= 500 :
373+ raise DiscordServerError (response , data )
374+ else :
375+ raise HTTPException (response , data )
376+
377+ # This is handling exceptions from the request
378+ except OSError as e :
379+ # Connection reset by peer
380+ if tries < 4 and e .errno in (54 , 10054 ):
381+ await asyncio .sleep (1 + tries * 2 )
382+ continue
383+ raise
386384
387385 if response is not None :
388386 # We've run out of retries, raise.
0 commit comments