Skip to content

Commit 7cdd8d7

Browse files
committed
refactor(http): implement yet again another rate limiter.
1 parent 158df3a commit 7cdd8d7

File tree

3 files changed

+108
-67
lines changed

3 files changed

+108
-67
lines changed

interactions/api/http.py

Lines changed: 103 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from asyncio import AbstractEventLoop, get_event_loop
1+
from asyncio import AbstractEventLoop, Lock, get_event_loop, get_running_loop
22
from json import dumps
33
from logging import Logger, getLogger
44
from sys import version_info
5-
from threading import Event
65
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
76
from urllib.parse import quote
87

98
from aiohttp import ClientSession, FormData
109
from aiohttp import __version__ as http_version
1110

1211
import interactions.api.cache
12+
from interactions.models.misc import MISSING
1313

1414
from ..api.cache import Cache, Item
1515
from ..api.error import HTTPException
@@ -81,16 +81,48 @@ def bucket(self) -> str:
8181
return f"{self.channel_id}:{self.guild_id}:{self.path}"
8282

8383

84+
class Limiter:
85+
"""
86+
A class representing a limitation for an HTTP request.
87+
88+
:ivar Lock lock: The "lock" or controller of the request.
89+
:ivar List[str] hashes: The known hashes of the request.
90+
:ivar float reset_after: The remaining time before the request can be ran.
91+
"""
92+
93+
lock: Lock
94+
hashes: List[str]
95+
reset_after: float
96+
97+
def __init__(self, *, lock: Lock, reset_after: Optional[float] = MISSING) -> None:
98+
"""
99+
:param lock: The asynchronous lock to control limits for.
100+
:type lock: Lock
101+
:param reset_after: The remaining time to run the limited lock on. Defaults to ``0``.
102+
:type reset_after: Optional[float]
103+
"""
104+
self.lock = lock
105+
self.reset_after = 0 if reset_after is MISSING else reset_after
106+
self.hashes = []
107+
108+
async def __aenter__(self) -> "Limiter":
109+
await self.lock.acquire()
110+
return self
111+
112+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
113+
return self.lock.release()
114+
115+
84116
class Request:
85117
"""
86118
A class representing how HTTP requests are sent/read.
87119
88120
:ivar str token: The current application token.
89-
:ivar AbstractEventLoop loop: The current coroutine event loop.
90-
:ivar dict ratelimits: The current ratelimits from the Discord API.
91-
:ivar dict headers: The current headers for an HTTP request.
92-
:ivar ClientSession session: The current session for making requests.
93-
:ivar Event lock: The ratelimit lock event.
121+
:ivar AbstractEventLoop _loop: The current coroutine event loop.
122+
:ivar Dict[Route, Limiter] ratelimits: The current per-route rate limiters from the API.
123+
:ivar dict _headers: The current headers for an HTTP request.
124+
:ivar ClientSession _session: The current session for making requests.
125+
:ivar Limiter _global_lock: The global rate limiter.
94126
"""
95127

96128
__slots__ = (
@@ -100,23 +132,21 @@ class Request:
100132
"_headers",
101133
"_session",
102134
"_global_lock",
103-
"_global_remaining",
104135
)
105136
token: str
106137
_loop: AbstractEventLoop
107-
ratelimits: dict
138+
ratelimits: Dict[Route, Limiter]
108139
_headers: dict
109140
_session: ClientSession
110-
_global_lock: Event
111-
_global_remaining: float
141+
_global_lock: Limiter
112142

113143
def __init__(self, token: str) -> None:
114144
"""
115145
:param token: The application token used for authorizing.
116146
:type token: str
117147
"""
118148
self.token = token
119-
self._loop = get_event_loop()
149+
self._loop = get_event_loop() if version_info < (3, 10) else get_running_loop()
120150
self.ratelimits = {}
121151
self._headers = {
122152
"Authorization": f"Bot {self.token}",
@@ -125,8 +155,7 @@ def __init__(self, token: str) -> None:
125155
f"aiohttp/{http_version}",
126156
}
127157
self._session = _session
128-
self._global_lock = Event()
129-
self._global_remaining = 0
158+
self._global_lock = Limiter(lock=Lock(loop=self._loop))
130159

131160
def _check_session(self) -> None:
132161
"""Ensures that we have a valid connection session."""
@@ -135,10 +164,10 @@ def _check_session(self) -> None:
135164

136165
async def _check_lock(self) -> None:
137166
"""Checks the global lock for its current state."""
138-
if self._global_lock.is_set():
167+
if self._global_lock.lock.locked():
139168
log.warning("The HTTP client is still globally locked, waiting for it to clear.")
140-
self._global_lock.wait(self._global_remaining)
141-
self._global_lock.clear()
169+
await self._global_lock.lock.acquire()
170+
self._global_lock.reset_after = 0
142171

143172
async def request(self, route: Route, **kwargs) -> Optional[Any]:
144173
r"""
@@ -153,61 +182,69 @@ async def request(self, route: Route, **kwargs) -> Optional[Any]:
153182
"""
154183
self._check_session()
155184
await self._check_lock()
156-
bucket: str = route.bucket
157-
ratelimit: Event = self.ratelimits.get(bucket)
158185

159-
if ratelimit is None:
160-
self.ratelimits[bucket] = {"lock": Event(), "remaining": 0}
186+
# This is the per-route check. We check BEFORE the request is made
187+
# to see if there's a rate limit for it. If there is, we'll call this
188+
# later in the event loop and reset the remaining time. Otherwise,
189+
# we'll set a "limiter" for it respective to that bucket. The hashes will
190+
# be checked later.
191+
if self.ratelimits.get(route):
192+
bucket: Limiter = self.ratelimits.get(route)
193+
if bucket.lock.locked():
194+
log.warning(
195+
f"The current bucket is still under a rate limit. Calling later in {bucket.reset_after} seconds."
196+
)
197+
self._loop.call_later(bucket.reset_after, bucket.lock)
198+
await bucket.lock.acquire()
199+
bucket.reset_after = 0
161200
else:
162-
if ratelimit.is_set():
163-
log.warning("The requested HTTP endpoint is still locked, waiting for it to clear.")
164-
ratelimit["lock"].wait(ratelimit["reset_after"])
165-
ratelimit["lock"].clear()
166-
167-
kwargs["headers"] = {**self._headers, **kwargs.get("headers", {})}
168-
kwargs["headers"]["Content-Type"] = "application/json"
169-
170-
reason = kwargs.pop("reason", None)
171-
if reason:
172-
kwargs["headers"]["X-Audit-Log-Reason"] = quote(reason, safe="/ ")
173-
174-
async with self._session.request(
175-
route.method, route.__api__ + route.path, **kwargs
176-
) as response:
177-
data = await response.json(content_type=None)
178-
reset_after: str = response.headers.get("X-Ratelimit-Reset-After")
179-
remaining: str = response.headers.get("X-Ratelimit-Remaining")
180-
bucket: str = response.headers.get("X-Ratelimit-Bucket")
181-
is_global: bool = (
182-
True
183-
if response.headers.get("X-Ratelimit-Global") or bool(data.get("global"))
184-
else False
185-
)
186-
187-
log.debug(f"{route.method}: {route.__api__ + route.path}: {kwargs}")
188-
log.debug(f"RETURN {response.status}: {dumps(data, indent=4, sort_keys=True)}")
189-
190-
if data.get("errors"):
191-
raise HTTPException(data["code"], message=data["message"])
192-
elif remaining and not int(remaining):
193-
if response.status != 429:
194-
if bucket:
201+
self.ratelimits.update({route: Limiter(lock=Lock(loop=self._loop))})
202+
203+
# We're controlling our HTTP request with the route as its own
204+
# separate lock here. This way, we can control the request of the
205+
# route as an asynchronous method. This way, if the event loop is to call on this later,
206+
# this will temporarily block but still allow to process the original request
207+
# we wanted to make.
208+
async with self.ratelimits.get(route) as _lock:
209+
kwargs["headers"] = {**self._headers, **kwargs.get("headers", {})}
210+
kwargs["headers"]["Content-Type"] = "application/json"
211+
212+
reason = kwargs.pop("reason", None)
213+
if reason:
214+
kwargs["headers"]["X-Audit-Log-Reason"] = quote(reason, safe="/ ")
215+
216+
async with self._session.request(
217+
route.method, route.__api__ + route.path, **kwargs
218+
) as response:
219+
data = await response.json(content_type=None)
220+
reset_after: str = response.headers.get("X-RateLimit-Reset-After")
221+
remaining: str = response.headers.get("X-RateLimit-Remaining")
222+
bucket: str = response.headers.get("X-RateLimit-Bucket")
223+
is_global: bool = response.headers.get("X-RateLimit-Global", False)
224+
225+
log.debug(f"{route.method}: {route.__api__ + route.path}: {kwargs}")
226+
log.debug(f"RETURN {response.status}: {dumps(data, indent=4, sort_keys=True)}")
227+
228+
if bucket not in _lock.hashes:
229+
_lock.hashes.append(bucket)
230+
231+
if isinstance(data, dict) and data.get("errors"):
232+
raise HTTPException(data["code"], message=data["message"])
233+
elif remaining and not int(remaining):
234+
if response.status == 429:
195235
log.warning(
196-
f"The requested HTTP endpoint is currently ratelimited. Waiting for {reset_after} seconds."
236+
f"The HTTP client has encountered a per-route ratelimit. Locking down future requests for {reset_after} seconds."
197237
)
198-
self.ratelimits[bucket].wait(float(reset_after))
199-
else:
238+
_lock.reset_after = reset_after
239+
self._loop.call_later(_lock.reset_after, _lock.lock)
240+
elif is_global:
200241
log.warning(
201-
f"The HTTP client has reached the maximum amount of requests. Cooling down for {reset_after} seconds."
242+
f"The HTTP client has encountered a global ratelimit. Locking down future requests for {reset_after} seconds."
202243
)
203-
self._global_lock.wait(float(reset_after))
204-
elif is_global:
205-
log.warning(
206-
f"The HTTP client has encountered a global ratelimit. Locking down future requests for {reset_after} seconds."
207-
)
208-
self._global_lock.wait(float(reset_after))
209-
210-
return data
244+
self._global_lock.reset_after = reset_after
245+
self._loop.call_later(self._global_lock.reset_after, self._globl_lock.lock)
246+
247+
return data
211248

212249
async def close(self) -> None:
213250
"""Closes the current session."""

interactions/client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def __register_events(self) -> None:
103103
self._websocket.dispatch.register(self.__raw_socket_create)
104104
self._websocket.dispatch.register(self.__raw_channel_create, "on_channel_create")
105105
self._websocket.dispatch.register(self.__raw_message_create, "on_message_create")
106-
self._websocket.dispatch.register(self.__raw_message_create, "on_message_update")
107106
self._websocket.dispatch.register(self.__raw_guild_create, "on_guild_create")
108107

109108
async def __compare_sync(self, data: dict, pool: List[dict]) -> bool:

simple_bot.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ async def on_ready():
1313
print("bot is now online.")
1414

1515

16+
@bot.event
17+
async def on_message_create(message: interactions.Message):
18+
await bot._http.send_message(channel_id=852402668294766615, content=message.content)
19+
20+
1621
@bot.command(
1722
type=interactions.ApplicationCommandType.MESSAGE,
1823
name="simple testing command",

0 commit comments

Comments
 (0)