Skip to content

Commit 78d50b6

Browse files
feat: Add base iterators & context managers + helpers (#1057)
* fix: small but annoying bug * feat: Add channel history async iterator * docs: change docstring * feat: add `get_channel_history` utility method * fix: add client safeguard * feat: add maximum amount to get * refactor/feat: Create an ABC and update the channel history iterator accordingly * feat/refactor: add checks and move files * docs: fix strings * chore: remove print * fix: add check to channel history iter * fix: add check to channel history iter * feat: make the guild members iterator * fix: type annotation * fix: correct id key * fix: make `maximum` `inf` * fix: math is hard... * feat: Add normal `BaseIterator` * chore: make comment * refactor: change import paths to relative * fix: path * feat: add context managers * feat: add typing context manager * feat: add new attributes and `get_guild_members` utility method * fix: speling (typo intended) * fix: import context managers * Update utils.py * ci: correct from checks. * Update utils.py * Update base_iterators.py * ci: correct from checks. * Update base_iterators.py * Update channel.py * Update guild.py * Update guild.py * ci: correct from checks. * Update channel.py * ci: correct from checks. * Update base_iterators.py * Update base_context_managers.py * ci: correct from checks. * Update base_iterators.py * ci: correct from checks. * Update base_iterators.py * Update channel.py * Update guild.py * Update base_context_managers.py * Update channel.py * Update base_context_managers.py * ci: correct from checks. * Update channel.py * Update base_context_managers.py * Update channel.py * ci: correct from checks. * Update base_context_managers.py * ci: correct from checks. * Update base_context_managers.py * ci: correct from checks. * Update base_context_managers.py * Update base_iterators.py * ci: correct from checks. * Update channel.py * ci: correct from checks. * Update guild.py * Update guild.py * ci: correct from checks. * Update utils.py * ci: correct from checks. * Update base_iterators.py * Update utils.py * ci: correct from checks. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent baecb46 commit 78d50b6

File tree

7 files changed

+572
-4
lines changed

7 files changed

+572
-4
lines changed

interactions/api/gateway/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,4 +921,5 @@ async def close(self) -> None:
921921
"""
922922
if self._client:
923923
await self._client.close()
924+
924925
self.__closed.set()

interactions/api/models/channel.py

Lines changed: 225 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
from asyncio import Task, create_task, get_running_loop, sleep
12
from datetime import datetime, timedelta, timezone
23
from enum import IntEnum
3-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
4+
from inspect import isawaitable
5+
from math import inf
6+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, ContextManager, List, Optional, Union
7+
from warnings import warn
48

9+
from ...utils.abc.base_context_managers import BaseAsyncContextManager
10+
from ...utils.abc.base_iterators import DiscordPaginationIterator
511
from ...utils.attrs_utils import (
612
ClientSerializerMixin,
713
DictSerializerMixin,
@@ -18,6 +24,7 @@
1824

1925
if TYPE_CHECKING:
2026
from ...client.models.component import ActionRow, Button, SelectMenu
27+
from ..http.client import HTTPClient
2128
from .guild import Invite, InviteTargetType
2229
from .member import Member
2330
from .message import Attachment, Embed, Message, Sticker
@@ -28,6 +35,7 @@
2835
"Channel",
2936
"ThreadMember",
3037
"ThreadMetadata",
38+
"AsyncHistoryIterator",
3139
)
3240

3341

@@ -95,6 +103,181 @@ class ThreadMember(ClientSerializerMixin):
95103
mute_config: Optional[Any] = field(default=None) # todo explore this, it isn't in the ddev docs
96104

97105

106+
class AsyncHistoryIterator(DiscordPaginationIterator):
107+
"""
108+
A class object that allows iterating through a channel's history.
109+
110+
:param _client: The HTTPClient of the bot
111+
:type _client: HTTPClient
112+
:param obj: The channel to get the history from
113+
:type obj: Union[int, str, Snowflake, Channel]
114+
:param start_at?: The message to begin getting the history from
115+
:type start_at?: Optional[Union[int, str, Snowflake, Message]]
116+
:param reverse?: Whether to only get newer message. Default False
117+
:type reverse?: Optional[bool]
118+
:param check?: A check to ignore certain messages
119+
:type check?: Optional[Callable[[Member], bool]]
120+
:param maximum?: A set maximum of messages to get before stopping the iteration
121+
:type maximum?: Optional[int]
122+
"""
123+
124+
def __init__(
125+
self,
126+
_client: "HTTPClient",
127+
obj: Union[int, str, Snowflake, "Channel"],
128+
maximum: Optional[int] = inf,
129+
start_at: Optional[Union[int, str, Snowflake, "Message"]] = MISSING,
130+
check: Optional[Callable[["Message"], bool]] = None,
131+
reverse: Optional[bool] = False,
132+
):
133+
super().__init__(obj, _client, maximum=maximum, start_at=start_at, check=check)
134+
135+
from .message import Message
136+
137+
if reverse and start_at is MISSING:
138+
raise LibraryException(
139+
code=12,
140+
message="A message to start from is required to go through the channel in reverse.",
141+
)
142+
143+
if reverse:
144+
self.before = MISSING
145+
self.after = self.start_at
146+
else:
147+
self.before = self.start_at
148+
self.after = MISSING
149+
150+
self.objects: Optional[List[Message]]
151+
152+
async def get_first_objects(self) -> None:
153+
from .message import Message
154+
155+
limit = min(self.maximum, 100)
156+
157+
if self.maximum == limit:
158+
self.__stop = True
159+
160+
if self.after is not MISSING:
161+
msgs = await self._client.get_channel_messages(
162+
channel_id=self.object_id, after=self.after, limit=limit
163+
)
164+
msgs.reverse()
165+
self.after = int(msgs[-1]["id"])
166+
else:
167+
msgs = await self._client.get_channel_messages(
168+
channel_id=self.object_id, before=self.before, limit=limit
169+
)
170+
self.before = int(msgs[-1]["id"])
171+
172+
if len(msgs) < 100:
173+
# already all messages resolved with one operation
174+
self.__stop = True
175+
176+
self.object_count += limit
177+
178+
self.objects = [Message(**msg, _client=self._client) for msg in msgs]
179+
180+
async def flatten(self) -> List["Message"]:
181+
"""returns all remaining items as list"""
182+
return [item async for item in self]
183+
184+
async def get_objects(self) -> None:
185+
from .message import Message
186+
187+
limit = min(50, self.maximum - self.object_count)
188+
189+
if self.after is not MISSING:
190+
msgs = await self._client.get_channel_messages(
191+
channel_id=self.object_id, after=self.after, limit=limit
192+
)
193+
msgs.reverse()
194+
self.after = int(msgs[-1]["id"])
195+
else:
196+
msgs = await self._client.get_channel_messages(
197+
channel_id=self.object_id, before=self.before, limit=limit
198+
)
199+
self.before = int(msgs[-1]["id"])
200+
201+
if len(msgs) < limit or limit == self.maximum - self.object_count:
202+
# end of messages reached again
203+
self.__stop = True
204+
205+
self.object_count += limit
206+
207+
self.objects.extend([Message(**msg, _client=self._client) for msg in msgs])
208+
209+
async def __anext__(self) -> "Message":
210+
if self.objects is None:
211+
await self.get_first_objects()
212+
213+
try:
214+
obj = self.objects.pop(0)
215+
216+
if self.check:
217+
218+
res = self.check(obj)
219+
_res = await res if isawaitable(res) else res
220+
while not _res:
221+
if (
222+
not self.__stop
223+
and len(self.objects) < 5
224+
and self.object_count >= self.maximum
225+
):
226+
await self.get_objects()
227+
228+
self.object_count -= 1
229+
obj = self.objects.pop(0)
230+
231+
_res = self.check(obj)
232+
233+
if not self.__stop and len(self.objects) < 5 and self.object_count <= self.maximum:
234+
await self.get_objects()
235+
except IndexError:
236+
raise StopAsyncIteration
237+
else:
238+
return obj
239+
240+
241+
class AsyncTypingContextManager(BaseAsyncContextManager):
242+
"""
243+
An async context manager for triggering typing.
244+
245+
:param obj: The channel to trigger typing in.
246+
:type obj: Union[int, str, Snowflake, Channel]
247+
:param _client: The HTTPClient of the bot
248+
:type _client: HTTPClient
249+
"""
250+
251+
def __init__(
252+
self,
253+
obj: Union[int, str, "Snowflake", "Channel"],
254+
_client: "HTTPClient",
255+
):
256+
257+
try:
258+
self.loop = get_running_loop()
259+
except RuntimeError as e:
260+
raise RuntimeError("No running event loop detected!") from e
261+
262+
self.object_id = None if not obj else int(obj) if not hasattr(obj, "id") else int(obj.id)
263+
self._client = _client
264+
self.__task: Optional[Task] = None
265+
266+
def __await__(self):
267+
return self._client.trigger_typing(self.object_id).__await__()
268+
269+
async def do_action(self):
270+
while True:
271+
await self._client.trigger_typing(self.object_id)
272+
await sleep(8)
273+
274+
async def __aenter__(self):
275+
self.__task = create_task(self.do_action())
276+
277+
async def __aexit__(self, exc_type, exc_val, exc_tb):
278+
self.__task.cancel()
279+
280+
98281
@define()
99282
class Channel(ClientSerializerMixin, IDMixin):
100283
"""
@@ -180,6 +363,16 @@ def __attrs_post_init__(self): # sourcery skip: last-if-guard
180363
def __repr__(self) -> str:
181364
return self.name
182365

366+
@property
367+
def typing(self) -> Union[Awaitable, ContextManager]:
368+
"""
369+
Manages the typing of the channel. Use with `await` or `async with`
370+
371+
:return: A manager for typing
372+
:rtype: AsyncTypingContextManager
373+
"""
374+
return AsyncTypingContextManager(self, self._client)
375+
183376
@property
184377
def mention(self) -> str:
185378
"""
@@ -190,6 +383,33 @@ def mention(self) -> str:
190383
"""
191384
return f"<#{self.id}>"
192385

386+
def history(
387+
self,
388+
start_at: Optional[Union[int, str, Snowflake, "Message"]] = MISSING,
389+
reverse: Optional[bool] = False,
390+
maximum: Optional[int] = inf,
391+
check: Optional[Callable[["Message"], bool]] = None,
392+
) -> AsyncHistoryIterator:
393+
"""
394+
:param start_at?: The message to begin getting the history from
395+
:type start_at?: Optional[Union[int, str, Snowflake, Message]]
396+
:param reverse?: Whether to only get newer message. Default False
397+
:type reverse?: Optional[bool]
398+
:param maximum?: A set maximum of messages to get before stopping the iteration
399+
:type maximum?: Optional[int]
400+
:param check?: A custom check to ignore certain messages
401+
:type check?: Optional[Callable[[Message], bool]]
402+
403+
:return: An asynchronous iterator over the history of the channel
404+
:rtype: AsyncHistoryIterator
405+
"""
406+
if not self._client:
407+
raise LibraryException(code=13)
408+
409+
return AsyncHistoryIterator(
410+
self._client, self, start_at=start_at, reverse=reverse, maximum=maximum, check=check
411+
)
412+
193413
async def send(
194414
self,
195415
content: Optional[str] = MISSING,
@@ -1122,6 +1342,10 @@ async def get_history(self, limit: int = 100) -> Optional[List["Message"]]:
11221342
:rtype: List[Message]
11231343
"""
11241344

1345+
warn(
1346+
"This method has been deprecated in favour of the 'history' method.", DeprecationWarning
1347+
)
1348+
11251349
if not self._client:
11261350
raise LibraryException(code=13)
11271351

0 commit comments

Comments
 (0)