1+ from asyncio import Task , create_task , get_running_loop , sleep
12from datetime import datetime , timedelta , timezone
23from enum import IntEnum
3- from typing import TYPE_CHECKING , Any , Callable , List , Optional , Union
4+ from inspect import isawaitable
5+ from math import inf
6+ from typing import TYPE_CHECKING , Any , Awaitable , Callable , ContextManager , List , Optional , Union
7+ from warnings import warn
48
9+ from ...utils .abc .base_context_managers import BaseAsyncContextManager
10+ from ...utils .abc .base_iterators import DiscordPaginationIterator
511from ...utils .attrs_utils import (
612 ClientSerializerMixin ,
713 DictSerializerMixin ,
1824
1925if TYPE_CHECKING :
2026 from ...client .models .component import ActionRow , Button , SelectMenu
27+ from ..http .client import HTTPClient
2128 from .guild import Invite , InviteTargetType
2229 from .member import Member
2330 from .message import Attachment , Embed , Message , Sticker
2835 "Channel" ,
2936 "ThreadMember" ,
3037 "ThreadMetadata" ,
38+ "AsyncHistoryIterator" ,
3139)
3240
3341
@@ -95,6 +103,181 @@ class ThreadMember(ClientSerializerMixin):
95103 mute_config : Optional [Any ] = field (default = None ) # todo explore this, it isn't in the ddev docs
96104
97105
106+ class AsyncHistoryIterator (DiscordPaginationIterator ):
107+ """
108+ A class object that allows iterating through a channel's history.
109+
110+ :param _client: The HTTPClient of the bot
111+ :type _client: HTTPClient
112+ :param obj: The channel to get the history from
113+ :type obj: Union[int, str, Snowflake, Channel]
114+ :param start_at?: The message to begin getting the history from
115+ :type start_at?: Optional[Union[int, str, Snowflake, Message]]
116+ :param reverse?: Whether to only get newer message. Default False
117+ :type reverse?: Optional[bool]
118+ :param check?: A check to ignore certain messages
119+ :type check?: Optional[Callable[[Member], bool]]
120+ :param maximum?: A set maximum of messages to get before stopping the iteration
121+ :type maximum?: Optional[int]
122+ """
123+
124+ def __init__ (
125+ self ,
126+ _client : "HTTPClient" ,
127+ obj : Union [int , str , Snowflake , "Channel" ],
128+ maximum : Optional [int ] = inf ,
129+ start_at : Optional [Union [int , str , Snowflake , "Message" ]] = MISSING ,
130+ check : Optional [Callable [["Message" ], bool ]] = None ,
131+ reverse : Optional [bool ] = False ,
132+ ):
133+ super ().__init__ (obj , _client , maximum = maximum , start_at = start_at , check = check )
134+
135+ from .message import Message
136+
137+ if reverse and start_at is MISSING :
138+ raise LibraryException (
139+ code = 12 ,
140+ message = "A message to start from is required to go through the channel in reverse." ,
141+ )
142+
143+ if reverse :
144+ self .before = MISSING
145+ self .after = self .start_at
146+ else :
147+ self .before = self .start_at
148+ self .after = MISSING
149+
150+ self .objects : Optional [List [Message ]]
151+
152+ async def get_first_objects (self ) -> None :
153+ from .message import Message
154+
155+ limit = min (self .maximum , 100 )
156+
157+ if self .maximum == limit :
158+ self .__stop = True
159+
160+ if self .after is not MISSING :
161+ msgs = await self ._client .get_channel_messages (
162+ channel_id = self .object_id , after = self .after , limit = limit
163+ )
164+ msgs .reverse ()
165+ self .after = int (msgs [- 1 ]["id" ])
166+ else :
167+ msgs = await self ._client .get_channel_messages (
168+ channel_id = self .object_id , before = self .before , limit = limit
169+ )
170+ self .before = int (msgs [- 1 ]["id" ])
171+
172+ if len (msgs ) < 100 :
173+ # already all messages resolved with one operation
174+ self .__stop = True
175+
176+ self .object_count += limit
177+
178+ self .objects = [Message (** msg , _client = self ._client ) for msg in msgs ]
179+
180+ async def flatten (self ) -> List ["Message" ]:
181+ """returns all remaining items as list"""
182+ return [item async for item in self ]
183+
184+ async def get_objects (self ) -> None :
185+ from .message import Message
186+
187+ limit = min (50 , self .maximum - self .object_count )
188+
189+ if self .after is not MISSING :
190+ msgs = await self ._client .get_channel_messages (
191+ channel_id = self .object_id , after = self .after , limit = limit
192+ )
193+ msgs .reverse ()
194+ self .after = int (msgs [- 1 ]["id" ])
195+ else :
196+ msgs = await self ._client .get_channel_messages (
197+ channel_id = self .object_id , before = self .before , limit = limit
198+ )
199+ self .before = int (msgs [- 1 ]["id" ])
200+
201+ if len (msgs ) < limit or limit == self .maximum - self .object_count :
202+ # end of messages reached again
203+ self .__stop = True
204+
205+ self .object_count += limit
206+
207+ self .objects .extend ([Message (** msg , _client = self ._client ) for msg in msgs ])
208+
209+ async def __anext__ (self ) -> "Message" :
210+ if self .objects is None :
211+ await self .get_first_objects ()
212+
213+ try :
214+ obj = self .objects .pop (0 )
215+
216+ if self .check :
217+
218+ res = self .check (obj )
219+ _res = await res if isawaitable (res ) else res
220+ while not _res :
221+ if (
222+ not self .__stop
223+ and len (self .objects ) < 5
224+ and self .object_count >= self .maximum
225+ ):
226+ await self .get_objects ()
227+
228+ self .object_count -= 1
229+ obj = self .objects .pop (0 )
230+
231+ _res = self .check (obj )
232+
233+ if not self .__stop and len (self .objects ) < 5 and self .object_count <= self .maximum :
234+ await self .get_objects ()
235+ except IndexError :
236+ raise StopAsyncIteration
237+ else :
238+ return obj
239+
240+
241+ class AsyncTypingContextManager (BaseAsyncContextManager ):
242+ """
243+ An async context manager for triggering typing.
244+
245+ :param obj: The channel to trigger typing in.
246+ :type obj: Union[int, str, Snowflake, Channel]
247+ :param _client: The HTTPClient of the bot
248+ :type _client: HTTPClient
249+ """
250+
251+ def __init__ (
252+ self ,
253+ obj : Union [int , str , "Snowflake" , "Channel" ],
254+ _client : "HTTPClient" ,
255+ ):
256+
257+ try :
258+ self .loop = get_running_loop ()
259+ except RuntimeError as e :
260+ raise RuntimeError ("No running event loop detected!" ) from e
261+
262+ self .object_id = None if not obj else int (obj ) if not hasattr (obj , "id" ) else int (obj .id )
263+ self ._client = _client
264+ self .__task : Optional [Task ] = None
265+
266+ def __await__ (self ):
267+ return self ._client .trigger_typing (self .object_id ).__await__ ()
268+
269+ async def do_action (self ):
270+ while True :
271+ await self ._client .trigger_typing (self .object_id )
272+ await sleep (8 )
273+
274+ async def __aenter__ (self ):
275+ self .__task = create_task (self .do_action ())
276+
277+ async def __aexit__ (self , exc_type , exc_val , exc_tb ):
278+ self .__task .cancel ()
279+
280+
98281@define ()
99282class Channel (ClientSerializerMixin , IDMixin ):
100283 """
@@ -180,6 +363,16 @@ def __attrs_post_init__(self): # sourcery skip: last-if-guard
180363 def __repr__ (self ) -> str :
181364 return self .name
182365
366+ @property
367+ def typing (self ) -> Union [Awaitable , ContextManager ]:
368+ """
369+ Manages the typing of the channel. Use with `await` or `async with`
370+
371+ :return: A manager for typing
372+ :rtype: AsyncTypingContextManager
373+ """
374+ return AsyncTypingContextManager (self , self ._client )
375+
183376 @property
184377 def mention (self ) -> str :
185378 """
@@ -190,6 +383,33 @@ def mention(self) -> str:
190383 """
191384 return f"<#{ self .id } >"
192385
386+ def history (
387+ self ,
388+ start_at : Optional [Union [int , str , Snowflake , "Message" ]] = MISSING ,
389+ reverse : Optional [bool ] = False ,
390+ maximum : Optional [int ] = inf ,
391+ check : Optional [Callable [["Message" ], bool ]] = None ,
392+ ) -> AsyncHistoryIterator :
393+ """
394+ :param start_at?: The message to begin getting the history from
395+ :type start_at?: Optional[Union[int, str, Snowflake, Message]]
396+ :param reverse?: Whether to only get newer message. Default False
397+ :type reverse?: Optional[bool]
398+ :param maximum?: A set maximum of messages to get before stopping the iteration
399+ :type maximum?: Optional[int]
400+ :param check?: A custom check to ignore certain messages
401+ :type check?: Optional[Callable[[Message], bool]]
402+
403+ :return: An asynchronous iterator over the history of the channel
404+ :rtype: AsyncHistoryIterator
405+ """
406+ if not self ._client :
407+ raise LibraryException (code = 13 )
408+
409+ return AsyncHistoryIterator (
410+ self ._client , self , start_at = start_at , reverse = reverse , maximum = maximum , check = check
411+ )
412+
193413 async def send (
194414 self ,
195415 content : Optional [str ] = MISSING ,
@@ -1122,6 +1342,10 @@ async def get_history(self, limit: int = 100) -> Optional[List["Message"]]:
11221342 :rtype: List[Message]
11231343 """
11241344
1345+ warn (
1346+ "This method has been deprecated in favour of the 'history' method." , DeprecationWarning
1347+ )
1348+
11251349 if not self ._client :
11261350 raise LibraryException (code = 13 )
11271351
0 commit comments