1- from asyncio import AbstractEventLoop , get_event_loop
1+ from asyncio import AbstractEventLoop , Lock , get_event_loop , get_running_loop
22from json import dumps
33from logging import Logger , getLogger
44from sys import version_info
5- from threading import Event
65from typing import Any , ClassVar , Dict , List , Optional , Tuple , Union
76from urllib .parse import quote
87
98from aiohttp import ClientSession , FormData
109from aiohttp import __version__ as http_version
1110
1211import interactions .api .cache
12+ from interactions .models .misc import MISSING
1313
1414from ..api .cache import Cache , Item
1515from ..api .error import HTTPException
@@ -81,16 +81,48 @@ def bucket(self) -> str:
8181 return f"{ self .channel_id } :{ self .guild_id } :{ self .path } "
8282
8383
84+ class Limiter :
85+ """
86+ A class representing a limitation for an HTTP request.
87+
88+ :ivar Lock lock: The "lock" or controller of the request.
89+ :ivar List[str] hashes: The known hashes of the request.
90+ :ivar float reset_after: The remaining time before the request can be ran.
91+ """
92+
93+ lock : Lock
94+ hashes : List [str ]
95+ reset_after : float
96+
97+ def __init__ (self , * , lock : Lock , reset_after : Optional [float ] = MISSING ) -> None :
98+ """
99+ :param lock: The asynchronous lock to control limits for.
100+ :type lock: Lock
101+ :param reset_after: The remaining time to run the limited lock on. Defaults to ``0``.
102+ :type reset_after: Optional[float]
103+ """
104+ self .lock = lock
105+ self .reset_after = 0 if reset_after is MISSING else reset_after
106+ self .hashes = []
107+
108+ async def __aenter__ (self ) -> "Limiter" :
109+ await self .lock .acquire ()
110+ return self
111+
112+ async def __aexit__ (self , exc_type , exc_val , exc_tb ) -> None :
113+ return self .lock .release ()
114+
115+
84116class Request :
85117 """
86118 A class representing how HTTP requests are sent/read.
87119
88120 :ivar str token: The current application token.
89- :ivar AbstractEventLoop loop : The current coroutine event loop.
90- :ivar dict ratelimits: The current ratelimits from the Discord API.
91- :ivar dict headers : The current headers for an HTTP request.
92- :ivar ClientSession session : The current session for making requests.
93- :ivar Event lock : The ratelimit lock event .
121+ :ivar AbstractEventLoop _loop : The current coroutine event loop.
122+ :ivar Dict[Route, Limiter] ratelimits: The current per-route rate limiters from the API.
123+ :ivar dict _headers : The current headers for an HTTP request.
124+ :ivar ClientSession _session : The current session for making requests.
125+ :ivar Limiter _global_lock : The global rate limiter .
94126 """
95127
96128 __slots__ = (
@@ -100,23 +132,21 @@ class Request:
100132 "_headers" ,
101133 "_session" ,
102134 "_global_lock" ,
103- "_global_remaining" ,
104135 )
105136 token : str
106137 _loop : AbstractEventLoop
107- ratelimits : dict
138+ ratelimits : Dict [ Route , Limiter ]
108139 _headers : dict
109140 _session : ClientSession
110- _global_lock : Event
111- _global_remaining : float
141+ _global_lock : Limiter
112142
113143 def __init__ (self , token : str ) -> None :
114144 """
115145 :param token: The application token used for authorizing.
116146 :type token: str
117147 """
118148 self .token = token
119- self ._loop = get_event_loop ()
149+ self ._loop = get_event_loop () if version_info < ( 3 , 10 ) else get_running_loop ()
120150 self .ratelimits = {}
121151 self ._headers = {
122152 "Authorization" : f"Bot { self .token } " ,
@@ -125,8 +155,7 @@ def __init__(self, token: str) -> None:
125155 f"aiohttp/{ http_version } " ,
126156 }
127157 self ._session = _session
128- self ._global_lock = Event ()
129- self ._global_remaining = 0
158+ self ._global_lock = Limiter (lock = Lock (loop = self ._loop ))
130159
131160 def _check_session (self ) -> None :
132161 """Ensures that we have a valid connection session."""
@@ -135,10 +164,10 @@ def _check_session(self) -> None:
135164
136165 async def _check_lock (self ) -> None :
137166 """Checks the global lock for its current state."""
138- if self ._global_lock .is_set ():
167+ if self ._global_lock .lock . locked ():
139168 log .warning ("The HTTP client is still globally locked, waiting for it to clear." )
140- self ._global_lock .wait ( self . _global_remaining )
141- self ._global_lock .clear ()
169+ await self ._global_lock .lock . acquire ( )
170+ self ._global_lock .reset_after = 0
142171
143172 async def request (self , route : Route , ** kwargs ) -> Optional [Any ]:
144173 r"""
@@ -153,61 +182,69 @@ async def request(self, route: Route, **kwargs) -> Optional[Any]:
153182 """
154183 self ._check_session ()
155184 await self ._check_lock ()
156- bucket : str = route .bucket
157- ratelimit : Event = self .ratelimits .get (bucket )
158185
159- if ratelimit is None :
160- self .ratelimits [bucket ] = {"lock" : Event (), "remaining" : 0 }
186+ # This is the per-route check. We check BEFORE the request is made
187+ # to see if there's a rate limit for it. If there is, we'll call this
188+ # later in the event loop and reset the remaining time. Otherwise,
189+ # we'll set a "limiter" for it respective to that bucket. The hashes will
190+ # be checked later.
191+ if self .ratelimits .get (route ):
192+ bucket : Limiter = self .ratelimits .get (route )
193+ if bucket .lock .locked ():
194+ log .warning (
195+ f"The current bucket is still under a rate limit. Calling later in { bucket .reset_after } seconds."
196+ )
197+ self ._loop .call_later (bucket .reset_after , bucket .lock )
198+ await bucket .lock .acquire ()
199+ bucket .reset_after = 0
161200 else :
162- if ratelimit .is_set ():
163- log .warning ("The requested HTTP endpoint is still locked, waiting for it to clear." )
164- ratelimit ["lock" ].wait (ratelimit ["reset_after" ])
165- ratelimit ["lock" ].clear ()
166-
167- kwargs ["headers" ] = {** self ._headers , ** kwargs .get ("headers" , {})}
168- kwargs ["headers" ]["Content-Type" ] = "application/json"
169-
170- reason = kwargs .pop ("reason" , None )
171- if reason :
172- kwargs ["headers" ]["X-Audit-Log-Reason" ] = quote (reason , safe = "/ " )
173-
174- async with self ._session .request (
175- route .method , route .__api__ + route .path , ** kwargs
176- ) as response :
177- data = await response .json (content_type = None )
178- reset_after : str = response .headers .get ("X-Ratelimit-Reset-After" )
179- remaining : str = response .headers .get ("X-Ratelimit-Remaining" )
180- bucket : str = response .headers .get ("X-Ratelimit-Bucket" )
181- is_global : bool = (
182- True
183- if response .headers .get ("X-Ratelimit-Global" ) or bool (data .get ("global" ))
184- else False
185- )
186-
187- log .debug (f"{ route .method } : { route .__api__ + route .path } : { kwargs } " )
188- log .debug (f"RETURN { response .status } : { dumps (data , indent = 4 , sort_keys = True )} " )
189-
190- if data .get ("errors" ):
191- raise HTTPException (data ["code" ], message = data ["message" ])
192- elif remaining and not int (remaining ):
193- if response .status != 429 :
194- if bucket :
201+ self .ratelimits .update ({route : Limiter (lock = Lock (loop = self ._loop ))})
202+
203+ # We're controlling our HTTP request with the route as its own
204+ # separate lock here. This way, we can control the request of the
205+ # route as an asynchronous method. This way, if the event loop is to call on this later,
206+ # this will temporarily block but still allow to process the original request
207+ # we wanted to make.
208+ async with self .ratelimits .get (route ) as _lock :
209+ kwargs ["headers" ] = {** self ._headers , ** kwargs .get ("headers" , {})}
210+ kwargs ["headers" ]["Content-Type" ] = "application/json"
211+
212+ reason = kwargs .pop ("reason" , None )
213+ if reason :
214+ kwargs ["headers" ]["X-Audit-Log-Reason" ] = quote (reason , safe = "/ " )
215+
216+ async with self ._session .request (
217+ route .method , route .__api__ + route .path , ** kwargs
218+ ) as response :
219+ data = await response .json (content_type = None )
220+ reset_after : str = response .headers .get ("X-RateLimit-Reset-After" )
221+ remaining : str = response .headers .get ("X-RateLimit-Remaining" )
222+ bucket : str = response .headers .get ("X-RateLimit-Bucket" )
223+ is_global : bool = response .headers .get ("X-RateLimit-Global" , False )
224+
225+ log .debug (f"{ route .method } : { route .__api__ + route .path } : { kwargs } " )
226+ log .debug (f"RETURN { response .status } : { dumps (data , indent = 4 , sort_keys = True )} " )
227+
228+ if bucket not in _lock .hashes :
229+ _lock .hashes .append (bucket )
230+
231+ if isinstance (data , dict ) and data .get ("errors" ):
232+ raise HTTPException (data ["code" ], message = data ["message" ])
233+ elif remaining and not int (remaining ):
234+ if response .status == 429 :
195235 log .warning (
196- f"The requested HTTP endpoint is currently ratelimited. Waiting for { reset_after } seconds."
236+ f"The HTTP client has encountered a per-route ratelimit. Locking down future requests for { reset_after } seconds."
197237 )
198- self .ratelimits [bucket ].wait (float (reset_after ))
199- else :
238+ _lock .reset_after = reset_after
239+ self ._loop .call_later (_lock .reset_after , _lock .lock )
240+ elif is_global :
200241 log .warning (
201- f"The HTTP client has reached the maximum amount of requests. Cooling down for { reset_after } seconds."
242+ f"The HTTP client has encountered a global ratelimit. Locking down future requests for { reset_after } seconds."
202243 )
203- self ._global_lock .wait (float (reset_after ))
204- elif is_global :
205- log .warning (
206- f"The HTTP client has encountered a global ratelimit. Locking down future requests for { reset_after } seconds."
207- )
208- self ._global_lock .wait (float (reset_after ))
209-
210- return data
244+ self ._global_lock .reset_after = reset_after
245+ self ._loop .call_later (self ._global_lock .reset_after , self ._globl_lock .lock )
246+
247+ return data
211248
212249 async def close (self ) -> None :
213250 """Closes the current session."""
0 commit comments