Skip to content

Commit f0f857b

Browse files
committed
feat(http): Utilise per-bucket rate-limiting/locking.
1 parent 145d3c8 commit f0f857b

File tree

1 file changed

+43
-34
lines changed

1 file changed

+43
-34
lines changed

interactions/api/http.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -72,41 +72,43 @@ def __init__(self, method: str, path: str, **kwargs) -> None:
7272
self.channel_id = kwargs.get("channel_id")
7373
self.guild_id = kwargs.get("guild_id")
7474

75-
@property
76-
def bucket(self) -> str:
75+
def get_bucket(self, shared_bucket: Optional[str] = None) -> str:
7776
"""
78-
Returns the route's bucket.
77+
Returns the route's bucket. If shared_bucket is None, returns the path with major parameters.
78+
Otherwise, it relies on Discord's given bucket.
79+
80+
:param shared_bucket: The bucket that Discord provides, if available.
81+
:type shared_bucket: Optional[str]
7982
8083
:return: The route bucket.
8184
:rtype: str
8285
"""
83-
return f"{self.channel_id}:{self.guild_id}:{self.path}"
86+
return (
87+
f"{self.channel_id}:{self.guild_id}:{self.path}"
88+
if shared_bucket is None
89+
else f"{self.channel_id}:{self.guild_id}:{shared_bucket}"
90+
)
8491

8592
@property
86-
def hashbucket(self) -> str:
93+
def endpoint(self) -> str:
8794
"""
88-
Returns the route's full bucket, reproducible for paeudo-hashing.
89-
This contains both bucket properties, but also the METHOD attribute.
90-
Note, that this does NOT contain the hash.
95+
Returns the route's endpoint.
9196
92-
:return: The route bucket.
97+
:return: The route endpoint.
9398
:rtype: str
9499
"""
95-
96-
return f"{self.method}::{self.bucket}"
100+
return f"{self.method}:{self.path}"
97101

98102

99103
class Limiter:
100104
"""
101105
A class representing a limitation for an HTTP request.
102106
103107
:ivar Lock lock: The "lock" or controller of the request.
104-
:ivar List[str] hashes: The known hashes of the request.
105108
:ivar float reset_after: The remaining time before the request can be ran.
106109
"""
107110

108111
lock: Lock
109-
hashes: List[str]
110112
reset_after: float
111113

112114
def __init__(self, *, lock: Lock, reset_after: Optional[float] = MISSING) -> None:
@@ -118,7 +120,6 @@ def __init__(self, *, lock: Lock, reset_after: Optional[float] = MISSING) -> Non
118120
"""
119121
self.lock = lock
120122
self.reset_after = 0 if reset_after is MISSING else reset_after
121-
self.hashes = []
122123

123124
async def __aenter__(self) -> "Limiter":
124125
await self.lock.acquire()
@@ -135,6 +136,7 @@ class Request:
135136
:ivar str token: The current application token.
136137
:ivar AbstractEventLoop _loop: The current coroutine event loop.
137138
:ivar Dict[str, Limiter] ratelimits: The current per-route rate limiters from the API.
139+
:ivar Dict[str, str] buckets: The current endpoint to shared_bucket cache from the API.
138140
:ivar dict _headers: The current headers for an HTTP request.
139141
:ivar ClientSession _session: The current session for making requests.
140142
:ivar Limiter _global_lock: The global rate limiter.
@@ -144,13 +146,15 @@ class Request:
144146
"token",
145147
"_loop",
146148
"ratelimits",
149+
"buckets",
147150
"_headers",
148151
"_session",
149152
"_global_lock",
150153
)
151154
token: str
152155
_loop: AbstractEventLoop
153-
ratelimits: Dict[str, Limiter] # hashbucket: Limiter
156+
ratelimits: Dict[str, Limiter] # bucket: Limiter
157+
buckets: Dict[str, str] # endpoint: shared_bucket
154158
_headers: dict
155159
_session: ClientSession
156160
_global_lock: Limiter
@@ -163,6 +167,7 @@ def __init__(self, token: str) -> None:
163167
self.token = token
164168
self._loop = get_event_loop() if version_info < (3, 10) else get_running_loop()
165169
self.ratelimits = {}
170+
self.buckets = {}
166171
self._headers = {
167172
"Authorization": f"Bot {self.token}",
168173
"User-Agent": f"DiscordBot (https://github.com/goverfl0w/interactions.py {__version__} "
@@ -205,22 +210,25 @@ async def request(self, route: Route, **kwargs) -> Optional[Any]:
205210

206211
# Huge credit and thanks to LordOfPolls for the lock/retry logic.
207212

208-
# This section generates the bucket through the hashbucket attr,
209-
# which essentially contains path, method, and major params.
213+
bucket = route.get_bucket(
214+
self.buckets.get(route.endpoint)
215+
) # string returning path OR prioritised hash bucket metadata.
216+
217+
# The idea is that its regulated by the priority of Discord's bucket header and not just self-computation.
210218

211-
if self.ratelimits.get(route.hashbucket):
212-
bucket: Limiter = self.ratelimits.get(route.hashbucket)
213-
if bucket.lock.locked():
219+
if self.ratelimits.get(bucket):
220+
_limiter: Limiter = self.ratelimits.get(bucket)
221+
if _limiter.lock.locked():
214222
log.warning(
215-
f"The current bucket is still under a rate limit. Calling later in {bucket.reset_after} seconds."
223+
f"The current bucket is still under a rate limit. Calling later in {_limiter.reset_after} seconds."
216224
)
217-
self._loop.call_later(bucket.reset_after, bucket.lock.release)
218-
bucket.reset_after = 0
225+
self._loop.call_later(_limiter.reset_after, _limiter.lock.release)
226+
_limiter.reset_after = 0
219227
else:
220-
self.ratelimits.update({route.hashbucket: Limiter(lock=Lock(loop=self._loop))})
221-
bucket: Limiter = self.ratelimits.get(route.hashbucket)
228+
self.ratelimits.update({bucket: Limiter(lock=Lock(loop=self._loop))})
229+
_limiter: Limiter = self.ratelimits.get(bucket)
222230

223-
await bucket.lock.acquire()
231+
await _limiter.lock.acquire() # _limiter is the per shared bucket/route endpoint
224232

225233
# Implement retry logic. The common seems to be 5, so this is hardcoded, for the most part.
226234

@@ -243,8 +251,9 @@ async def request(self, route: Route, **kwargs) -> Optional[Any]:
243251

244252
log.debug(f"{route.method}: {route.__api__ + route.path}: {kwargs}")
245253

246-
if _bucket not in bucket.hashes:
247-
bucket.hashes.append(_bucket)
254+
if _bucket is not None:
255+
self.buckets[route.endpoint] = _bucket
256+
# real-time replacement/update/add if needed.
248257

249258
if isinstance(data, dict) and data.get("errors"):
250259
log.debug(
@@ -258,8 +267,8 @@ async def request(self, route: Route, **kwargs) -> Optional[Any]:
258267
log.warning(
259268
f"The HTTP client has encountered a per-route ratelimit. Locking down future requests for {reset_after} seconds."
260269
)
261-
bucket.reset_after = reset_after
262-
await asyncio.sleep(bucket.reset_after)
270+
_limiter.reset_after = reset_after
271+
await asyncio.sleep(_limiter.reset_after)
263272
continue
264273
elif is_global:
265274
log.warning(
@@ -279,22 +288,22 @@ async def request(self, route: Route, **kwargs) -> Optional[Any]:
279288
await asyncio.sleep(2 * tries + 1)
280289
continue
281290
try:
282-
bucket.lock.release()
291+
_limiter.lock.release()
283292
except RuntimeError:
284293
pass
285294
raise
286295

287296
# For generic exceptions we give a traceback for debug reasons.
288297
except Exception as e:
289298
try:
290-
bucket.lock.release()
299+
_limiter.lock.release()
291300
except RuntimeError:
292301
pass
293302
log.error("".join(traceback.format_exception(type(e), e, e.__traceback__)))
294303
break
295304

296-
if bucket.lock.locked():
297-
bucket.lock.release()
305+
if _limiter.lock.locked():
306+
_limiter.lock.release()
298307

299308
async def close(self) -> None:
300309
"""Closes the current session."""

0 commit comments

Comments
 (0)