diff --git a/discord/__init__.py b/discord/__init__.py index bcffff183b..8e74431d94 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -72,7 +72,7 @@ from .sticker import * from .team import * from .template import * -from .threads import * +from .channel.thread import * from .user import * from .voice_client import * from .webhook import * diff --git a/discord/abc.py b/discord/abc.py index 546deea7fd..49a247283a 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -26,11 +26,9 @@ from __future__ import annotations import asyncio -import copy import time from typing import ( TYPE_CHECKING, - Any, Callable, Iterable, Protocol, @@ -43,17 +41,11 @@ from . import utils from .context_managers import Typing -from .enums import ChannelType from .errors import ClientException, InvalidArgument from .file import File, VoiceMessage -from .flags import ChannelFlags, MessageFlags -from .invite import Invite +from .flags import MessageFlags from .iterators import HistoryIterator, MessagePinIterator from .mentions import AllowedMentions -from .partial_emoji import PartialEmoji, _EmojiTag -from .permissions import PermissionOverwrite, Permissions -from .role import Role -from .scheduled_events import ScheduledEvent from .sticker import GuildSticker, StickerItem from .utils.private import warn_deprecated from .voice_client import VoiceClient, VoiceProtocol @@ -62,7 +54,6 @@ "Snowflake", "User", "PrivateChannel", - "GuildChannel", "Messageable", "Connectable", "Mentionable", @@ -76,7 +67,6 @@ from .app.state import ConnectionState from .asset import Asset from .channel import ( - CategoryChannel, DMChannel, GroupChannel, PartialMessageable, @@ -86,14 +76,9 @@ ) from .client import Client from .embeds import Embed - from .enums import InviteTarget - from .guild import Guild - from .member import Member from .message import Message, MessageReference, PartialMessage from .poll import Poll - from .threads import Thread - from .types.channel import Channel as ChannelPayload - from .types.channel import GuildChannel as GuildChannelPayload + from .channel.thread import Thread from .types.channel import OverwriteType from .types.channel import PermissionOverwrite as PermissionOverwritePayload from .ui.view import View diff --git a/discord/app/cache.py b/discord/app/cache.py index 545e1198ee..7f97b5f837 100644 --- a/discord/app/cache.py +++ b/discord/app/cache.py @@ -28,6 +28,7 @@ from discord import utils from discord.member import Member from discord.message import Message +from discord.soundboard import SoundboardSound from ..channel import DMChannel from ..emoji import AppEmoji, GuildEmoji @@ -142,6 +143,8 @@ async def store_private_channel(self, channel: "PrivateChannel") -> None: ... async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message: ... + async def store_built_message(self, message: Message) -> None: ... + async def upsert_message(self, message: Message) -> None: ... async def delete_message(self, message_id: int) -> None: ... @@ -166,6 +169,14 @@ async def get_all_members(self) -> list[Member]: ... async def clear(self, views: bool = True) -> None: ... + async def store_sound(self, sound: SoundboardSound) -> None: ... + + async def get_sound(self, sound_id: int) -> SoundboardSound | None: ... + + async def get_all_sounds(self) -> list[SoundboardSound]: ... + + async def delete_sound(self, sound_id: int) -> None: ... + class MemoryCache(Cache): def __init__(self, max_messages: int | None = None) -> None: @@ -177,6 +188,7 @@ def __init__(self, max_messages: int | None = None) -> None: self._stickers: dict[int, list[GuildSticker]] = {} self._views: dict[str, View] = {} self._modals: dict[str, Modal] = {} + self._sounds: dict[int, SoundboardSound] = {} self._messages: Deque[Message] = deque(maxlen=self.max_messages) self._emojis: dict[int, list[GuildEmoji | AppEmoji]] = {} @@ -362,9 +374,15 @@ async def upsert_message(self, message: Message) -> None: async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message: msg = await Message._from_data(state=self._state, channel=channel, data=message) - self._messages.append(msg) + self.store_built_message(msg) return msg + async def store_built_message(self, message: Message) -> None: + self._messages.append(message) + + async def delete_message(self, message_id: int) -> None: + self._messages.remove(utils.find(lambda m: m.id == message_id, reversed(self._messages))) + async def get_message(self, message_id: int) -> Message | None: return utils.find(lambda m: m.id == message_id, reversed(self._messages)) @@ -393,3 +411,15 @@ async def get_guild_members(self, guild_id: int) -> list[Member]: async def get_all_members(self) -> list[Member]: return self._flatten([list(members.values()) for members in self._guild_members.values()]) + + async def store_sound(self, sound: SoundboardSound) -> None: + self._sounds[sound.id] = sound + + async def get_sound(self, sound_id: int) -> SoundboardSound | None: + return self._sounds.get(sound_id) + + async def get_all_sounds(self) -> list[SoundboardSound]: + return list(self._sounds.values()) + + async def delete_sound(self, sound_id: int) -> None: + self._sounds.pop(sound_id, None) diff --git a/discord/app/event_emitter.py b/discord/app/event_emitter.py index bdcfb25f43..d669419956 100644 --- a/discord/app/event_emitter.py +++ b/discord/app/event_emitter.py @@ -24,9 +24,9 @@ import asyncio from abc import ABC, abstractmethod -from asyncio import Future from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, TypeVar +from collections.abc import Awaitable, Coroutine +from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeAlias, TypeVar from typing_extensions import Self @@ -43,60 +43,81 @@ class Event(ABC): @abstractmethod async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: ... + def _populate_from_slots(self, obj: Any) -> None: + """ + Populate this event instance with attributes from another object. + + Handles both __slots__ and __dict__ based objects. + + Parameters + ---------- + obj: Any + The object to copy attributes from. + """ + # Collect all slots from the object's class hierarchy + slots = set() + for klass in type(obj).__mro__: + if hasattr(klass, "__slots__"): + slots.update(klass.__slots__) + + # Copy slot attributes + for slot in slots: + if hasattr(obj, slot): + try: + setattr(self, slot, getattr(obj, slot)) + except AttributeError: + # Some slots might be read-only or not settable + pass + + # Also copy __dict__ if it exists + if hasattr(obj, "__dict__"): + for key, value in obj.__dict__.items(): + try: + setattr(self, key, value) + except AttributeError: + pass + + +ListenerCallback: TypeAlias = Callable[[Event], Any] + + +class EventReciever(Protocol): + def __call__(self, event: Event) -> Awaitable[Any]: ... + class EventEmitter: def __init__(self, state: "ConnectionState") -> None: - self._listeners: dict[type[Event], list[Callable]] = {} - self._events: dict[str, list[type[Event]]] - self._wait_fors: dict[type[Event], list[Future]] = defaultdict(list) - self._state = state + self._receivers: list[EventReciever] = [] + self._events: dict[str, list[type[Event]]] = defaultdict(list) + self._state: ConnectionState = state + + from ..events import ALL_EVENTS # noqa: PLC0415 + + for event_cls in ALL_EVENTS: + self.add_event(event_cls) def add_event(self, event: type[Event]) -> None: - try: - self._events[event.__event_name__].append(event) - except KeyError: - self._events[event.__event_name__] = [event] + self._events[event.__event_name__].append(event) def remove_event(self, event: type[Event]) -> list[type[Event]] | None: return self._events.pop(event.__event_name__, None) - def add_listener(self, event: type[Event], listener: Callable) -> None: - try: - self._listeners[event].append(listener) - except KeyError: - self.add_event(event) - self._listeners[event] = [listener] - - def remove_listener(self, event: type[Event], listener: Callable) -> None: - self._listeners[event].remove(listener) - - def add_wait_for(self, event: type[T]) -> Future[T]: - fut = Future() + def add_receiver(self, receiver: EventReciever) -> None: + self._receivers.append(receiver) - self._wait_fors[event].append(fut) - - return fut - - def remove_wait_for(self, event: type[Event], fut: Future) -> None: - self._wait_fors[event].remove(fut) + def remove_receiver(self, receiver: EventReciever) -> None: + self._receivers.remove(receiver) async def emit(self, event_str: str, data: Any) -> None: events = self._events.get(event_str, []) - for event in events: - eve = await event.__load__(data=data, state=self._state) + coros: list[Awaitable[None]] = [] + for event_cls in events: + event = await event_cls.__load__(data=data, state=self._state) - if eve is None: + if event is None: continue - funcs = self._listeners.get(event, []) - - for func in funcs: - asyncio.create_task(func(eve)) - - wait_fors = self._wait_fors.get(event) + coros.extend(receiver(event) for receiver in self._receivers) - if wait_fors is not None: - for wait_for in wait_fors: - wait_for.set_result(eve) - self._wait_fors.pop(event) + await asyncio.gather(*coros) diff --git a/discord/app/state.py b/discord/app/state.py index b57ad5e776..7d4055bc21 100644 --- a/discord/app/state.py +++ b/discord/app/state.py @@ -44,6 +44,8 @@ cast, ) +from discord.soundboard import SoundboardSound + from .. import utils from ..activity import BaseActivity from ..automod import AutoModRule @@ -65,7 +67,7 @@ from ..raw_models import * from ..role import Role from ..sticker import GuildSticker -from ..threads import Thread, ThreadMember +from ..channel.thread import Thread, ThreadMember from ..ui.modal import Modal from ..ui.view import View from ..user import ClientUser, User @@ -237,12 +239,12 @@ def __init__( self._voice_clients: dict[int, VoiceClient] = {} if not intents.members or cache_flags._empty: - self.store_user = self.create_user # type: ignore + self.store_user = self.create_user_async # type: ignore self.deref_user = self.deref_user_no_intents # type: ignore self.cache_app_emojis: bool = options.get("cache_app_emojis", False) - self.emitter = EventEmitter(self) + self.emitter: EventEmitter = EventEmitter(self) self.cache: Cache = cache self.cache._state = self @@ -320,6 +322,9 @@ async def deref_user(self, user_id: int) -> None: def create_user(self, data: UserPayload) -> User: return User(state=self, data=data) + async def create_user_async(self, data: UserPayload) -> User: + return User(state=self, data=data) + def deref_user_no_intents(self, user_id: int) -> None: return @@ -373,6 +378,21 @@ async def _remove_guild(self, guild: Guild) -> None: del guild + async def _add_default_sounds(self) -> None: + default_sounds = await self.http.get_default_sounds() + for default_sound in default_sounds: + sound = SoundboardSound(state=self, http=self.http, data=default_sound) + await self._add_sound(sound) + + async def _add_sound(self, sound: SoundboardSound) -> None: + await self.cache.store_sound(sound) + + async def _remove_sound(self, sound: SoundboardSound) -> None: + await self.cache.delete_sound(sound.id) + + async def get_sounds(self) -> list[SoundboardSound]: + return list(await self.cache.get_all_sounds()) + async def get_emojis(self) -> list[GuildEmoji | AppEmoji]: return await self.cache.get_all_emojis() @@ -431,7 +451,7 @@ async def _get_guild_channel( # guild_id is in data guild = await self._get_guild(int(guild_id or data["guild_id"])) # type: ignore except KeyError: - channel = DMChannel._from_message(self, channel_id) + channel = DMChannel(id=channel_id, state=self) guild = None else: channel = guild and guild._resolve_channel(channel_id) @@ -487,15 +507,15 @@ async def query_members( ) raise - def _get_create_guild(self, data): + async def _get_create_guild(self, data): if data.get("unavailable") is False: # GUILD_CREATE with unavailable in the response # usually means that the guild has become available # and is therefore in the cache - guild = self._get_guild(int(data["id"])) + guild = await self._get_guild(int(data["id"])) if guild is not None: guild.unavailable = False - guild._from_data(data) + await guild._from_data(data, self) return guild return self._add_guild_from_data(data) @@ -660,6 +680,7 @@ async def _delay_ready(self) -> None: future = asyncio.ensure_future(self.chunk_guild(guild)) current_bucket.append(future) else: + await self._add_default_sounds() future = self.loop.create_future() future.set_result([]) diff --git a/discord/asset.py b/discord/asset.py index afdd47f3aa..684dcd1dd6 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -34,12 +34,15 @@ from . import utils from .errors import DiscordException, InvalidArgument +if TYPE_CHECKING: + from .app.state import ConnectionState + __all__ = ("Asset",) if TYPE_CHECKING: ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"] ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"] - from .state import ConnectionState + from .app.state import ConnectionState VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) @@ -172,7 +175,7 @@ def __init__(self, state, *, url: str, key: str, animated: bool = False): self._key = key @classmethod - def _from_default_avatar(cls, state, index: int) -> Asset: + def _from_default_avatar(cls, state: ConnectionState, index: int) -> Asset: return cls( state, url=f"{cls.BASE}/embed/avatars/{index}.png", @@ -181,7 +184,7 @@ def _from_default_avatar(cls, state, index: int) -> Asset: ) @classmethod - def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: + def _from_avatar(cls, state: ConnectionState, user_id: int, avatar: str) -> Asset: animated = avatar.startswith("a_") format = "gif" if animated else "png" return cls( @@ -192,7 +195,7 @@ def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: ) @classmethod - def _from_avatar_decoration(cls, state, user_id: int, avatar_decoration: str) -> Asset: + def _from_avatar_decoration(cls, state: ConnectionState, user_id: int, avatar_decoration: str) -> Asset: animated = avatar_decoration.startswith("a_") endpoint = ( "avatar-decoration-presets" @@ -232,7 +235,7 @@ def _from_user_primary_guild_tag(cls, state: ConnectionState, identity_guild_id: ) @classmethod - def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: + def _from_guild_avatar(cls, state: ConnectionState, guild_id: int, member_id: int, avatar: str) -> Asset: animated = avatar.startswith("a_") format = "gif" if animated else "png" return cls( @@ -243,7 +246,7 @@ def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) - ) @classmethod - def _from_guild_banner(cls, state, guild_id: int, member_id: int, banner: str) -> Asset: + def _from_guild_banner(cls, state: ConnectionState, guild_id: int, member_id: int, banner: str) -> Asset: animated = banner.startswith("a_") format = "gif" if animated else "png" return cls( @@ -254,7 +257,7 @@ def _from_guild_banner(cls, state, guild_id: int, member_id: int, banner: str) - ) @classmethod - def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: + def _from_icon(cls, state: ConnectionState, object_id: int, icon_hash: str, path: str) -> Asset: return cls( state, url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024", @@ -263,7 +266,7 @@ def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: ) @classmethod - def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: + def _from_cover_image(cls, state: ConnectionState, object_id: int, cover_image_hash: str) -> Asset: return cls( state, url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024", @@ -282,7 +285,7 @@ def _from_collectible(cls, state: ConnectionState, asset: str, animated: bool = ) @classmethod - def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: + def _from_guild_image(cls, state: ConnectionState, guild_id: int, image: str, path: str) -> Asset: animated = False format = "png" if path == "banners": @@ -297,7 +300,7 @@ def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset ) @classmethod - def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: + def _from_guild_icon(cls, state: ConnectionState, guild_id: int, icon_hash: str) -> Asset: animated = icon_hash.startswith("a_") format = "gif" if animated else "png" return cls( @@ -308,7 +311,7 @@ def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: ) @classmethod - def _from_sticker_banner(cls, state, banner: int) -> Asset: + def _from_sticker_banner(cls, state: ConnectionState, banner: int) -> Asset: return cls( state, url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png", @@ -317,7 +320,7 @@ def _from_sticker_banner(cls, state, banner: int) -> Asset: ) @classmethod - def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: + def _from_user_banner(cls, state: ConnectionState, user_id: int, banner_hash: str) -> Asset: animated = banner_hash.startswith("a_") format = "gif" if animated else "png" return cls( @@ -328,7 +331,7 @@ def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: ) @classmethod - def _from_scheduled_event_image(cls, state, event_id: int, cover_hash: str) -> Asset: + def _from_scheduled_event_image(cls, state: ConnectionState, event_id: int, cover_hash: str) -> Asset: return cls( state, url=f"{cls.BASE}/guild-events/{event_id}/{cover_hash}.png", @@ -337,7 +340,7 @@ def _from_scheduled_event_image(cls, state, event_id: int, cover_hash: str) -> A ) @classmethod - def _from_soundboard_sound(cls, state, sound_id: int) -> Asset: + def _from_soundboard_sound(cls, state: ConnectionState, sound_id: int) -> Asset: return cls( state, url=f"{cls.BASE}/soundboard-sounds/{sound_id}", diff --git a/discord/audit_logs.py b/discord/audit_logs.py index e5d0b7a268..53615b7780 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -48,8 +48,8 @@ if TYPE_CHECKING: - from . import abc from .app.state import ConnectionState + from .channel.base import GuildChannel from .emoji import GuildEmoji from .guild import Guild from .member import Member @@ -57,7 +57,7 @@ from .scheduled_events import ScheduledEvent from .stage_instance import StageInstance from .sticker import GuildSticker - from .threads import Thread + from .channel.thread import Thread from .types.audit_log import AuditLogChange as AuditLogChangePayload from .types.audit_log import AuditLogEntry as AuditLogEntryPayload from .types.automod import AutoModAction as AutoModActionPayload @@ -80,13 +80,13 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int: return int(data) -def _transform_channel(entry: AuditLogEntry, data: Snowflake | None) -> abc.GuildChannel | Object | None: +def _transform_channel(entry: AuditLogEntry, data: Snowflake | None) -> GuildChannel | Object | None: if data is None: return None return entry.guild.get_channel(int(data)) or Object(id=data) -def _transform_channels(entry: AuditLogEntry, data: list[Snowflake] | None) -> list[abc.GuildChannel | Object] | None: +def _transform_channels(entry: AuditLogEntry, data: list[Snowflake] | None) -> list[GuildChannel | Object] | None: if data is None: return None return [_transform_channel(entry, channel) for channel in data] @@ -438,7 +438,7 @@ class _AuditLogProxyMemberPrune: class _AuditLogProxyMemberMoveOrMessageDelete: - channel: abc.GuildChannel + channel: GuildChannel count: int @@ -447,12 +447,12 @@ class _AuditLogProxyMemberDisconnect: class _AuditLogProxyPinAction: - channel: abc.GuildChannel + channel: GuildChannel message_id: int class _AuditLogProxyStageInstanceAction: - channel: abc.GuildChannel + channel: GuildChannel class AuditLogEntry(Hashable): @@ -593,7 +593,7 @@ async def get_target( self, ) -> ( Guild - | abc.GuildChannel + | GuildChannel | Member | User | Role @@ -639,7 +639,7 @@ def after(self) -> AuditLogDiff: def _convert_target_guild(self, target_id: int) -> Guild: return self.guild - def _convert_target_channel(self, target_id: int) -> abc.GuildChannel | Object: + def _convert_target_channel(self, target_id: int) -> GuildChannel | Object: return self.guild.get_channel(target_id) or Object(id=target_id) async def _convert_target_user(self, target_id: int) -> Member | User | None: diff --git a/discord/bot.py b/discord/bot.py index 41b126674a..bc9c683a43 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -46,7 +46,6 @@ ) from .client import Client -from .cog import CogMixin from .commands import ( ApplicationCommand, ApplicationContext, @@ -59,6 +58,7 @@ ) from .enums import IntegrationType, InteractionContextType, InteractionType from .errors import CheckFailure, DiscordException +from .events import InteractionCreate from .interactions import Interaction from .shard import AutoShardedClient from .types import interactions @@ -1082,7 +1082,7 @@ async def invoke_application_command(self, ctx: ApplicationContext) -> None: ctx: :class:`.ApplicationCommand` The invocation context to invoke. """ - self._bot.dispatch("application_command", ctx) + # self._bot.dispatch("application_command", ctx) # TODO: Remove when moving away from ApplicationContext try: if await self._bot.can_run(ctx, call_once=True): await ctx.command.invoke(ctx) @@ -1091,14 +1091,15 @@ async def invoke_application_command(self, ctx: ApplicationContext) -> None: except DiscordException as exc: await ctx.command.dispatch_error(ctx, exc) else: - self._bot.dispatch("application_command_completion", ctx) + # self._bot.dispatch("application_command_completion", ctx) # TODO: Remove when moving away from ApplicationContext + pass @property @abstractmethod def _bot(self) -> Bot | AutoShardedBot: ... -class BotBase(ApplicationCommandMixin, CogMixin, ABC): +class BotBase(ApplicationCommandMixin, ABC): _supports_prefixed_commands = False def __init__(self, description=None, *args, **options): @@ -1152,11 +1153,13 @@ def __init__(self, description=None, *args, **options): self._before_invoke = None self._after_invoke = None + self._bot.add_listener(self.on_interaction, event=InteractionCreate) + async def on_connect(self): if self.auto_sync_commands: await self.sync_commands() - async def on_interaction(self, interaction): + async def on_interaction(self, interaction: InteractionCreate): await self.process_application_commands(interaction) async def on_application_command_error(self, context: ApplicationContext, exception: DiscordException) -> None: diff --git a/discord/channel/__init__.py b/discord/channel/__init__.py new file mode 100644 index 0000000000..abaa5da9f5 --- /dev/null +++ b/discord/channel/__init__.py @@ -0,0 +1,94 @@ +from ..enums import ChannelType, try_enum +from .base import ( + BaseChannel, + GuildChannel, + GuildMessageableChannel, + GuildPostableChannel, + GuildThreadableChannel, + GuildTopLevelChannel, +) +from .category import CategoryChannel +from .dm import DMChannel +from .dm import GroupDMChannel as GroupChannel +from .forum import ForumChannel +from .media import MediaChannel +from .news import NewsChannel +from .partial import PartialMessageable +from .stage import StageChannel +from .text import TextChannel +from .thread import Thread +from .voice import VoiceChannel + +__all__ = ( + "BaseChannel", + "CategoryChannel", + "DMChannel", + "ForumChannel", + "GroupChannel", + "GuildChannel", + "GuildMessageableChannel", + "GuildPostableChannel", + "GuildThreadableChannel", + "GuildTopLevelChannel", + "MediaChannel", + "NewsChannel", + "PartialMessageable", + "StageChannel", + "TextChannel", + "Thread", + "VoiceChannel", +) + + +def _guild_channel_factory(channel_type: int): + value = try_enum(ChannelType, channel_type) + if value is ChannelType.text: + return TextChannel, value + elif value is ChannelType.voice: + return VoiceChannel, value + elif value is ChannelType.category: + return CategoryChannel, value + elif value is ChannelType.news: + return NewsChannel, value + elif value is ChannelType.stage_voice: + return StageChannel, value + elif value is ChannelType.directory: + return None, value # todo: Add DirectoryChannel when applicable + elif value is ChannelType.forum: + return ForumChannel, value + elif value is ChannelType.media: + return MediaChannel, value + else: + return None, value + + +def _channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value is ChannelType.private: + return DMChannel, value + elif value is ChannelType.group: + return GroupChannel, value + else: + return cls, value + + +def _threaded_channel_factory(channel_type: int): + cls, value = _channel_factory(channel_type) + if value in ( + ChannelType.private_thread, + ChannelType.public_thread, + ChannelType.news_thread, + ): + return Thread, value + return cls, value + + +def _threaded_guild_channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value in ( + ChannelType.private_thread, + ChannelType.public_thread, + ChannelType.news_thread, + ): + return Thread, value + return cls, value diff --git a/discord/channel/base.py b/discord/channel/base.py new file mode 100644 index 0000000000..919735210a --- /dev/null +++ b/discord/channel/base.py @@ -0,0 +1,2021 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import copy +import datetime +import logging +from abc import ABC, abstractmethod +from collections.abc import Collection, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload + +from typing_extensions import Self, TypeVar, override + +from ..abc import Messageable, Snowflake, User, _Overwrites +from ..emoji import GuildEmoji, PartialEmoji +from ..enums import ChannelType, InviteTarget, SortOrder, try_enum +from ..flags import ChannelFlags, MessageFlags +from ..iterators import ArchivedThreadIterator +from ..mixins import Hashable +from ..utils import MISSING, Undefined, find, snowflake_time +from ..utils.private import copy_doc, get_as_snowflake + +if TYPE_CHECKING: + from ..embeds import Embed + from ..errors import InvalidArgument + from ..file import File + from ..guild import Guild + from ..invite import Invite + from ..member import Member + from ..mentions import AllowedMentions + from ..message import EmojiInputType, Message + from ..object import Object + from ..partial_emoji import _EmojiTag + from ..permissions import PermissionOverwrite, Permissions + from ..role import Role + from ..scheduled_events import ScheduledEvent + from ..sticker import GuildSticker, StickerItem + from ..types.channel import CategoryChannel as CategoryChannelPayload + from ..types.channel import Channel as ChannelPayload + from ..types.channel import ForumChannel as ForumChannelPayload + from ..types.channel import ForumTag as ForumTagPayload + from ..types.channel import GuildChannel as GuildChannelPayload + from ..types.channel import MediaChannel as MediaChannelPayload + from ..types.channel import NewsChannel as NewsChannelPayload + from ..types.channel import StageChannel as StageChannelPayload + from ..types.channel import TextChannel as TextChannelPayload + from ..types.channel import VoiceChannel as VoiceChannelPayload + from ..types.guild import ChannelPositionUpdate as ChannelPositionUpdatePayload + from ..ui.view import View + from .category import CategoryChannel + from .channel import ForumTag + from .thread import Thread + +_log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from ..app.state import ConnectionState + + +P = TypeVar("P", bound="ChannelPayload") + + +class BaseChannel(ABC, Generic[P]): + __slots__: tuple[str, ...] = ("id", "_type", "_state", "_data") # pyright: ignore [reportIncompatibleUnannotatedOverride] + + def __init__(self, id: int, state: ConnectionState): + self.id: int = id + self._state: ConnectionState = state + self._data: P = {} # type: ignore + + async def _update(self, data: P) -> None: + self._type: int = data["type"] + self._data = self._data | data # type: ignore + + @classmethod + async def _from_data(cls, *, data: P, state: ConnectionState, **kwargs) -> Self: + if kwargs: + _log.warning("Unexpected keyword arguments passed to %s._from_data: %r", cls.__name__, kwargs) + self = cls(int(data["id"]), state) + await self._update(data) + return self + + @property + def type(self) -> ChannelType: + """The channel's Discord channel type.""" + return try_enum(ChannelType, self._type) + + async def _get_channel(self) -> Self: + return self + + @property + def created_at(self) -> datetime.datetime: + """The channel's creation time in UTC.""" + return snowflake_time(self.id) + + @abstractmethod + @override + def __repr__(self) -> str: ... + + @property + @abstractmethod + def jump_url(self) -> str: ... + + +P_guild = TypeVar( + "P_guild", + bound="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | CategoryChannelPayload | StageChannelPayload | ForumChannelPayload", + default="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | CategoryChannelPayload | StageChannelPayload | ForumChannelPayload", +) + + +class GuildChannel(BaseChannel[P_guild], ABC, Generic[P_guild]): + """Represents a Discord guild channel.""" + + """An ABC that details the common operations on a Discord guild channel. + + The following implement this ABC: + + - :class:`~discord.TextChannel` + - :class:`~discord.VoiceChannel` + - :class:`~discord.CategoryChannel` + - :class:`~discord.StageChannel` + - :class:`~discord.ForumChannel` + + This ABC must also implement :class:`~discord.abc.Snowflake`. + + Attributes + ---------- + name: :class:`str` + The channel name. + guild: :class:`~discord.Guild` + The guild the channel belongs to. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + e.g. the top channel is position 0. + """ + + __slots__: tuple[str, ...] = ("name", "guild", "category_id", "flags", "_overwrites") + + @override + def __init__(self, id: int, *, guild: Guild, state: ConnectionState) -> None: + self.guild: Guild = guild + super().__init__(id, state) + + @classmethod + @override + async def _from_data(cls, *, data: P_guild, state: ConnectionState, guild: Guild, **kwargs) -> Self: + if kwargs: + _log.warning("Unexpected keyword arguments passed to %s._from_data: %r", cls.__name__, kwargs) + self = cls(int(data["id"]), guild=guild, state=state) + await self._update(data) + return self + + @override + async def _update(self, data: P_guild) -> None: + await super()._update(data) + self.name: str = data["name"] + self.category_id: int | None = get_as_snowflake(data, "parent_id") or getattr(self, "category_id", None) + if flags_value := data.get("flags", 0): + self.flags: ChannelFlags = ChannelFlags._from_value(flags_value) + self._fill_overwrites(data) + + @override + def __str__(self) -> str: + return self.name + + async def _edit(self, options: dict[str, Any], reason: str | None) -> ChannelPayload | None: + try: + parent = options.pop("category") + except KeyError: + parent_id = MISSING + else: + parent_id = parent and parent.id + + try: + options["rate_limit_per_user"] = options.pop("slowmode_delay") + except KeyError: + pass + + try: + options["default_thread_rate_limit_per_user"] = options.pop("default_thread_slowmode_delay") + except KeyError: + pass + + try: + options["flags"] = options.pop("flags").value + except KeyError: + pass + + try: + options["available_tags"] = [tag.to_dict() for tag in options.pop("available_tags")] + except KeyError: + pass + + try: + rtc_region = options.pop("rtc_region") + except KeyError: + pass + else: + options["rtc_region"] = None if rtc_region is None else str(rtc_region) + + try: + video_quality_mode = options.pop("video_quality_mode") + except KeyError: + pass + else: + options["video_quality_mode"] = int(video_quality_mode) + + lock_permissions = options.pop("sync_permissions", False) + + try: + position = options.pop("position") + except KeyError: + if parent_id is not MISSING: + if lock_permissions: + category = self.guild.get_channel(parent_id) + if category: + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] + options["parent_id"] = parent_id + elif lock_permissions and self.category_id is not None: + # if we're syncing permissions on a pre-existing channel category without changing it + # we need to update the permissions to point to the pre-existing category + category = self.guild.get_channel(self.category_id) + if category: + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] + else: + await self._move( + position, + parent_id=parent_id, + lock_permissions=lock_permissions, + reason=reason, + ) + + overwrites = options.get("overwrites") + if overwrites is not None: + perms = [] + for target, perm in overwrites.items(): + if not isinstance(perm, PermissionOverwrite): + raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}") + + allow, deny = perm.pair() + payload = { + "allow": allow.value, + "deny": deny.value, + "id": target.id, + "type": (_Overwrites.ROLE if isinstance(target, Role) else _Overwrites.MEMBER), + } + + perms.append(payload) + options["permission_overwrites"] = perms + + try: + ch_type = options["type"] + except KeyError: + pass + else: + if not isinstance(ch_type, ChannelType): + raise InvalidArgument("type field must be of type ChannelType") + options["type"] = ch_type.value + + try: + default_reaction_emoji = options["default_reaction_emoji"] + except KeyError: + pass + else: + if isinstance(default_reaction_emoji, _EmojiTag): # GuildEmoji, PartialEmoji + default_reaction_emoji = default_reaction_emoji._to_partial() + elif isinstance(default_reaction_emoji, int): + default_reaction_emoji = PartialEmoji(name=None, id=default_reaction_emoji) + elif isinstance(default_reaction_emoji, str): + default_reaction_emoji = PartialEmoji.from_str(default_reaction_emoji) + elif default_reaction_emoji is None: + pass + else: + raise InvalidArgument("default_reaction_emoji must be of type: GuildEmoji | int | str | None") + + options["default_reaction_emoji"] = ( + default_reaction_emoji._to_forum_reaction_payload() if default_reaction_emoji else None + ) + + if options: + return await self._state.http.edit_channel(self.id, reason=reason, **options) + + def _fill_overwrites(self, data: GuildChannelPayload) -> None: + self._overwrites: list[_Overwrites] = [] + everyone_index = 0 + everyone_id = self.guild.id + + for index, overridden in enumerate(data.get("permission_overwrites", [])): + overwrite = _Overwrites(overridden) + self._overwrites.append(overwrite) + + if overwrite.type == _Overwrites.MEMBER: + continue + + if overwrite.id == everyone_id: + # the @everyone role is not guaranteed to be the first one + # in the list of permission overwrites, however the permission + # resolution code kind of requires that it is the first one in + # the list since it is special. So we need the index so we can + # swap it to be the first one. + everyone_index = index + + # do the swap + tmp = self._overwrites + if tmp: + tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] + + @property + def changed_roles(self) -> list[Role]: + """Returns a list of roles that have been overridden from + their default values in the :attr:`~discord.Guild.roles` attribute. + """ + ret = [] + g = self.guild + for overwrite in filter(lambda o: o.is_role(), self._overwrites): + role = g.get_role(overwrite.id) + if role is None: + continue + + role = copy.copy(role) + role.permissions.handle_overwrite(overwrite.allow, overwrite.deny) + ret.append(role) + return ret + + @property + def mention(self) -> str: + """The string that allows you to mention the channel.""" + return f"<#{self.id}>" + + @property + @override + def jump_url(self) -> str: + """Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.0 + """ + return f"https://discord.com/channels/{self.guild.id}/{self.id}" + + def overwrites_for(self, obj: Role | User) -> PermissionOverwrite: + """Returns the channel-specific overwrites for a member or a role. + + Parameters + ---------- + obj: Union[:class:`~discord.Role`, :class:`~discord.abc.User`] + The role or user denoting + whose overwrite to get. + + Returns + ------- + :class:`~discord.PermissionOverwrite` + The permission overwrites for this object. + """ + + if isinstance(obj, User): + predicate: Callable[[Any], bool] = lambda p: p.is_member() + elif isinstance(obj, Role): + predicate = lambda p: p.is_role() + else: + predicate = lambda p: True + + for overwrite in filter(predicate, self._overwrites): + if overwrite.id == obj.id: + allow = Permissions(overwrite.allow) + deny = Permissions(overwrite.deny) + return PermissionOverwrite.from_pair(allow, deny) + + return PermissionOverwrite() + + async def get_overwrites(self) -> dict[Role | Member | Object, PermissionOverwrite]: + """Returns all of the channel's overwrites. + + This is returned as a dictionary where the key contains the target which + can be either a :class:`~discord.Role` or a :class:`~discord.Member` and the value is the + overwrite as a :class:`~discord.PermissionOverwrite`. + + Returns + ------- + Dict[Union[:class:`~discord.Role`, :class:`~discord.Member`, :class:`~discord.Object`], :class:`~discord.PermissionOverwrite`] + The channel's permission overwrites. + """ + ret: dict[Role | Member | Object, PermissionOverwrite] = {} + for ow in self._overwrites: + allow = Permissions(ow.allow) + deny = Permissions(ow.deny) + overwrite = PermissionOverwrite.from_pair(allow, deny) + target = None + + if ow.is_role(): + target = self.guild.get_role(ow.id) + elif ow.is_member(): + target = await self.guild.get_member(ow.id) + + if target is not None: + ret[target] = overwrite + else: + ret[Object(id=ow.id)] = overwrite + return ret + + @property + def category(self) -> CategoryChannel | None: + """The category this channel belongs to. + + If there is no category then this is ``None``. + """ + return cast("CategoryChannel | None", self.guild.get_channel(self.category_id)) if self.category_id else None + + @property + def members(self) -> Collection[Member]: + """Returns all members that can view this channel. + + This is calculated based on the channel's permission overwrites and + the members' roles. + + Returns + ------- + Collection[:class:`Member`] + All members who have permission to view this channel. + """ + return [m for m in self.guild.members if self.permissions_for(m).read_messages] + + async def permissions_are_synced(self) -> bool: + """Whether the permissions for this channel are synced with the + category it belongs to. + + If there is no category then this is ``False``. + + .. versionadded:: 3.0 + """ + if self.category_id is None: + return False + + category: CategoryChannel | None = cast("CategoryChannel | None", self.guild.get_channel(self.category_id)) + return bool(category and await category.get_overwrites() == await self.get_overwrites()) + + def permissions_for(self, obj: Member | Role, /) -> Permissions: + """Handles permission resolution for the :class:`~discord.Member` + or :class:`~discord.Role`. + + This function takes into consideration the following cases: + + - Guild owner + - Guild roles + - Channel overrides + - Member overrides + + If a :class:`~discord.Role` is passed, then it checks the permissions + someone with that role would have, which is essentially: + + - The default role permissions + - The permissions of the role used as a parameter + - The default role permission overwrites + - The permission overwrites of the role used as a parameter + + .. versionchanged:: 2.0 + The object passed in can now be a role object. + + Parameters + ---------- + obj: Union[:class:`~discord.Member`, :class:`~discord.Role`] + The object to resolve permissions for. This could be either + a member or a role. If it's a role then member overwrites + are not computed. + + Returns + ------- + :class:`~discord.Permissions` + The resolved permissions for the member or role. + """ + + # The current cases can be explained as: + # Guild owner get all permissions -- no questions asked. Otherwise... + # The @everyone role gets the first application. + # After that, the applied roles that the user has in the channel + # (or otherwise) are then OR'd together. + # After the role permissions are resolved, the member permissions + # have to take into effect. + # After all that is done, you have to do the following: + + # If manage permissions is True, then all permissions are set to True. + + # The operation first takes into consideration the denied + # and then the allowed. + + if self.guild.owner_id == obj.id: + return Permissions.all() + + default = self.guild.default_role + base = Permissions(default.permissions.value if default else 0) + + # Handle the role case first + if isinstance(obj, Role): + base.value |= obj._permissions + + if base.administrator: + return Permissions.all() + + # Apply @everyone allow/deny first since it's special + try: + maybe_everyone = self._overwrites[0] + if maybe_everyone.id == self.guild.id: + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) + except IndexError: + pass + + if obj.is_default(): + return base + + overwrite = find(lambda o: o.type == _Overwrites.ROLE and o.id == obj.id, self._overwrites) + if overwrite is not None: + base.handle_overwrite(overwrite.allow, overwrite.deny) + + return base + + roles = obj._roles + get_role = self.guild.get_role + + # Apply guild roles that the member has. + for role_id in roles: + role = get_role(role_id) + if role is not None: + base.value |= role._permissions + + # Guild-wide Administrator -> True for everything + # Bypass all channel-specific overrides + if base.administrator: + return Permissions.all() + + # Apply @everyone allow/deny first since it's special + try: + maybe_everyone = self._overwrites[0] + if maybe_everyone.id == self.guild.id: + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) + remaining_overwrites = self._overwrites[1:] + else: + remaining_overwrites = self._overwrites + except IndexError: + remaining_overwrites = self._overwrites + + denies = 0 + allows = 0 + + # Apply channel specific role permission overwrites + for overwrite in remaining_overwrites: + if overwrite.is_role() and roles.has(overwrite.id): + denies |= overwrite.deny + allows |= overwrite.allow + + base.handle_overwrite(allow=allows, deny=denies) + + # Apply member specific permission overwrites + for overwrite in remaining_overwrites: + if overwrite.is_member() and overwrite.id == obj.id: + base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny) + break + + # if you can't send a message in a channel then you can't have certain + # permissions as well + if not base.send_messages: + base.send_tts_messages = False + base.mention_everyone = False + base.embed_links = False + base.attach_files = False + + # if you can't read a channel then you have no permissions there + if not base.read_messages: + denied = Permissions.all_channel() + base.value &= ~denied.value + + return base + + async def delete(self, *, reason: str | None = None) -> None: + """|coro| + + Deletes the channel. + + You must have :attr:`~discord.Permissions.manage_channels` permission to use this. + + Parameters + ---------- + reason: Optional[:class:`str`] + The reason for deleting this channel. + Shows up on the audit log. + + Raises + ------ + ~discord.Forbidden + You do not have proper permissions to delete the channel. + ~discord.NotFound + The channel was not found or was already deleted. + ~discord.HTTPException + Deleting the channel failed. + """ + await self._state.http.delete_channel(self.id, reason=reason) + + @overload + async def set_permissions( + self, + target: Member | Role, + *, + overwrite: PermissionOverwrite | None = ..., + reason: str | None = ..., + ) -> None: ... + + @overload + async def set_permissions( + self, + target: Member | Role, + *, + overwrite: Undefined = MISSING, + reason: str | None = ..., + **permissions: bool, + ) -> None: ... + + async def set_permissions( + self, + target: Member | Role, + *, + overwrite: PermissionOverwrite | None | Undefined = MISSING, + reason: str | None = None, + **permissions: bool, + ) -> None: + r"""|coro| + + Sets the channel specific permission overwrites for a target in the + channel. + + The ``target`` parameter should either be a :class:`~discord.Member` or a + :class:`~discord.Role` that belongs to guild. + + The ``overwrite`` parameter, if given, must either be ``None`` or + :class:`~discord.PermissionOverwrite`. For convenience, you can pass in + keyword arguments denoting :class:`~discord.Permissions` attributes. If this is + done, then you cannot mix the keyword arguments with the ``overwrite`` + parameter. + + If the ``overwrite`` parameter is ``None``, then the permission + overwrites are deleted. + + You must have the :attr:`~discord.Permissions.manage_roles` permission to use this. + + .. note:: + + This method *replaces* the old overwrites with the ones given. + + Examples + ---------- + + Setting allow and deny: :: + + await message.channel.set_permissions(message.author, read_messages=True, send_messages=False) + + Deleting overwrites :: + + await channel.set_permissions(member, overwrite=None) + + Using :class:`~discord.PermissionOverwrite` :: + + overwrite = discord.PermissionOverwrite() + overwrite.send_messages = False + overwrite.read_messages = True + await channel.set_permissions(member, overwrite=overwrite) + + Parameters + ----------- + target: Union[:class:`~discord.Member`, :class:`~discord.Role`] + The member or role to overwrite permissions for. + overwrite: Optional[:class:`~discord.PermissionOverwrite`] + The permissions to allow and deny to the target, or ``None`` to + delete the overwrite. + \*\*permissions + A keyword argument list of permissions to set for ease of use. + Cannot be mixed with ``overwrite``. + reason: Optional[:class:`str`] + The reason for doing this action. Shows up on the audit log. + + Raises + ------- + ~discord.Forbidden + You do not have permissions to edit channel specific permissions. + ~discord.HTTPException + Editing channel specific permissions failed. + ~discord.NotFound + The role or member being edited is not part of the guild. + ~discord.InvalidArgument + The overwrite parameter invalid or the target type was not + :class:`~discord.Role` or :class:`~discord.Member`. + """ + + http = self._state.http + + if isinstance(target, User): + perm_type = _Overwrites.MEMBER + elif isinstance(target, Role): + perm_type = _Overwrites.ROLE + else: + raise InvalidArgument("target parameter must be either Member or Role") + + if overwrite is MISSING: + if len(permissions) == 0: + raise InvalidArgument("No overwrite provided.") + try: + overwrite = PermissionOverwrite(**permissions) + except (ValueError, TypeError) as e: + raise InvalidArgument("Invalid permissions given to keyword arguments.") from e + elif len(permissions) > 0: + raise InvalidArgument("Cannot mix overwrite and keyword arguments.") + + # TODO: wait for event + + if overwrite is None: + await http.delete_channel_permissions(self.id, target.id, reason=reason) + elif isinstance(overwrite, PermissionOverwrite): + (allow, deny) = overwrite.pair() + await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) + else: + raise InvalidArgument("Invalid overwrite type provided.") + + async def _clone_impl( + self, + base_attrs: dict[str, Any], + *, + name: str | None = None, + reason: str | None = None, + ) -> Self: + base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites] + base_attrs["parent_id"] = self.category_id + base_attrs["name"] = name or self.name + guild_id = self.guild.id + cls = self.__class__ + data: P_guild = cast( + "P_guild", await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) + ) + clone = cls(id=int(data["id"]), guild=self.guild, state=self._state) + await clone._update(data) + + self.guild._channels[clone.id] = clone + return clone + + async def clone(self, *, name: str | None = None, reason: str | None = None) -> Self: + """|coro| + + Clones this channel. This creates a channel with the same properties + as this channel. + + You must have the :attr:`~discord.Permissions.manage_channels` permission to + do this. + + .. versionadded:: 1.1 + + Parameters + ---------- + name: Optional[:class:`str`] + The name of the new channel. If not provided, defaults to this + channel name. + reason: Optional[:class:`str`] + The reason for cloning this channel. Shows up on the audit log. + + Returns + ------- + :class:`.abc.GuildChannel` + The channel that was created. + + Raises + ------ + ~discord.Forbidden + You do not have the proper permissions to create this channel. + ~discord.HTTPException + Creating the channel failed. + """ + raise NotImplementedError + + async def create_invite( + self, + *, + reason: str | None = None, + max_age: int = 0, + max_uses: int = 0, + temporary: bool = False, + unique: bool = True, + target_event: ScheduledEvent | None = None, + target_type: InviteTarget | None = None, + target_user: User | None = None, + target_application_id: int | None = None, + ) -> Invite: + """|coro| + + Creates an instant invite from a text or voice channel. + + You must have the :attr:`~discord.Permissions.create_instant_invite` permission to + do this. + + Parameters + ---------- + max_age: :class:`int` + How long the invite should last in seconds. If it's 0 then the invite + doesn't expire. Defaults to ``0``. + max_uses: :class:`int` + How many uses the invite could be used for. If it's 0 then there + are unlimited uses. Defaults to ``0``. + temporary: :class:`bool` + Denotes that the invite grants temporary membership + (i.e. they get kicked after they disconnect). Defaults to ``False``. + unique: :class:`bool` + Indicates if a unique invite URL should be created. Defaults to True. + If this is set to ``False`` then it will return a previously created + invite. + reason: Optional[:class:`str`] + The reason for creating this invite. Shows up on the audit log. + target_type: Optional[:class:`.InviteTarget`] + The type of target for the voice channel invite, if any. + + .. versionadded:: 2.0 + + target_user: Optional[:class:`User`] + The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. + The user must be streaming in the channel. + + .. versionadded:: 2.0 + + target_application_id: Optional[:class:`int`] + The id of the embedded application for the invite, required if `target_type` is + `TargetType.embedded_application`. + + .. versionadded:: 2.0 + + target_event: Optional[:class:`.ScheduledEvent`] + The scheduled event object to link to the event. + Shortcut to :meth:`.Invite.set_scheduled_event` + + See :meth:`.Invite.set_scheduled_event` for more + info on event invite linking. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`~discord.Invite` + The invite that was created. + + Raises + ------ + ~discord.HTTPException + Invite creation failed. + + ~discord.NotFound + The channel that was passed is a category or an invalid channel. + """ + if target_type is InviteTarget.unknown: + raise TypeError("target_type cannot be unknown") + + data = await self._state.http.create_invite( + self.id, + reason=reason, + max_age=max_age, + max_uses=max_uses, + temporary=temporary, + unique=unique, + target_type=target_type.value if target_type else None, + target_user_id=target_user.id if target_user else None, + target_application_id=target_application_id, + ) + invite = await Invite.from_incomplete(data=data, state=self._state) + if target_event: + invite.set_scheduled_event(target_event) + return invite + + async def invites(self) -> list[Invite]: + """|coro| + + Returns a list of all active instant invites from this channel. + + You must have :attr:`~discord.Permissions.manage_channels` to get this information. + + Returns + ------- + List[:class:`~discord.Invite`] + The list of invites that are currently active. + + Raises + ------ + ~discord.Forbidden + You do not have proper permissions to get the information. + ~discord.HTTPException + An error occurred while fetching the information. + """ + + data = await self._state.http.invites_from_channel(self.id) + guild = self.guild + return [Invite(state=self._state, data=invite, channel=self, guild=guild) for invite in data] + + +P_guild_top_level = TypeVar( + "P_guild_top_level", + bound="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | CategoryChannelPayload | StageChannelPayload | ForumChannelPayload", + default="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | CategoryChannelPayload | StageChannelPayload | ForumChannelPayload", +) + + +class GuildTopLevelChannel(GuildChannel[P_guild_top_level], ABC, Generic[P_guild_top_level]): + """An ABC for guild channels that can be positioned in the channel list. + + This includes categories and all channels that appear in the channel sidebar + (text, voice, news, stage, forum, media channels). Threads do not inherit from + this class as they are not positioned in the main channel list. + + .. versionadded:: 3.0 + + Attributes + ---------- + position: int + The position in the channel list. This is a number that starts at 0. + e.g. the top channel is position 0. + """ + + __slots__: tuple[str, ...] = ("position",) + + @override + async def _update(self, data: P_guild_top_level) -> None: + await super()._update(data) + self.position: int = data.get("position", 0) + + @property + @abstractmethod + def _sorting_bucket(self) -> int: + """Returns the bucket for sorting channels by type.""" + raise NotImplementedError + + async def _move( + self, + position: int, + parent_id: Any | None = None, + lock_permissions: bool = False, + *, + reason: str | None, + ) -> None: + """Internal method to move a channel to a specific position. + + Parameters + ---------- + position: int + The new position for the channel. + parent_id: Any | None + The parent category ID, if moving to a category. + lock_permissions: bool + Whether to sync permissions with the category. + reason: str | None + The reason for moving the channel. + + Raises + ------ + InvalidArgument + The position is less than 0. + """ + if position < 0: + raise InvalidArgument("Channel position cannot be less than 0.") + + bucket = self._sorting_bucket + channels: list[Self] = [c for c in self.guild.channels if c._sorting_bucket == bucket] + + channels.sort(key=lambda c: c.position) + + try: + # remove ourselves from the channel list + channels.remove(self) + except ValueError: + # not there somehow lol + return + else: + index = next( + (i for i, c in enumerate(channels) if c.position >= position), + len(channels), + ) + # add ourselves at our designated position + channels.insert(index, self) + + payload: list[ChannelPositionUpdatePayload] = [] + for index, c in enumerate(channels): + d: ChannelPositionUpdatePayload = {"id": c.id, "position": index} + if parent_id is not MISSING and c.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) + + @overload + async def move( + self, + *, + beginning: bool, + offset: int | Undefined = MISSING, + category: Snowflake | None | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + reason: str | None | Undefined = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + end: bool, + offset: int | Undefined = MISSING, + category: Snowflake | None | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + reason: str | Undefined = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + before: Snowflake, + offset: int | Undefined = MISSING, + category: Snowflake | None | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + reason: str | Undefined = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + after: Snowflake, + offset: int | Undefined = MISSING, + category: Snowflake | None | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + reason: str | Undefined = MISSING, + ) -> None: ... + + async def move(self, **kwargs: Any) -> None: + """|coro| + + A rich interface to help move a channel relative to other channels. + + If exact position movement is required, ``edit`` should be used instead. + + You must have :attr:`~discord.Permissions.manage_channels` permission to + do this. + + .. note:: + + Voice channels will always be sorted below text channels. + This is a Discord limitation. + + .. versionadded:: 1.7 + + Parameters + ---------- + beginning: bool + Whether to move the channel to the beginning of the + channel list (or category if given). + This is mutually exclusive with ``end``, ``before``, and ``after``. + end: bool + Whether to move the channel to the end of the + channel list (or category if given). + This is mutually exclusive with ``beginning``, ``before``, and ``after``. + before: ~discord.abc.Snowflake + The channel that should be before our current channel. + This is mutually exclusive with ``beginning``, ``end``, and ``after``. + after: ~discord.abc.Snowflake + The channel that should be after our current channel. + This is mutually exclusive with ``beginning``, ``end``, and ``before``. + offset: int + The number of channels to offset the move by. For example, + an offset of ``2`` with ``beginning=True`` would move + it 2 after the beginning. A positive number moves it below + while a negative number moves it above. Note that this + number is relative and computed after the ``beginning``, + ``end``, ``before``, and ``after`` parameters. + category: ~discord.abc.Snowflake | None + The category to move this channel under. + If ``None`` is given then it moves it out of the category. + This parameter is ignored if moving a category channel. + sync_permissions: bool + Whether to sync the permissions with the category (if given). + reason: str | None + The reason for the move. + + Raises + ------ + InvalidArgument + An invalid position was given or a bad mix of arguments was passed. + Forbidden + You do not have permissions to move the channel. + HTTPException + Moving the channel failed. + """ + + if not kwargs: + return + + beginning, end = kwargs.get("beginning"), kwargs.get("end") + before, after = kwargs.get("before"), kwargs.get("after") + offset = kwargs.get("offset", 0) + if sum(bool(a) for a in (beginning, end, before, after)) > 1: + raise InvalidArgument("Only one of [before, after, end, beginning] can be used.") + + bucket = self._sorting_bucket + parent_id = kwargs.get("category", MISSING) + channels: list[GuildChannel] + if parent_id not in (MISSING, None): + parent_id = parent_id.id + channels = [ + ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == parent_id + ] + else: + channels = [ + ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == self.category_id + ] + + channels.sort(key=lambda c: (c.position, c.id)) + + try: + # Try to remove ourselves from the channel list + channels.remove(self) + except ValueError: + # If we're not there then it's probably due to not being in the category + pass + + index = None + if beginning: + index = 0 + elif end: + index = len(channels) + elif before: + index = next((i for i, c in enumerate(channels) if c.id == before.id), None) + elif after: + index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) + + if index is None: + raise InvalidArgument("Could not resolve appropriate move position") + # TODO: This could use self._move to avoid code duplication + channels.insert(max((index + offset), 0), self) + payload: list[ChannelPositionUpdatePayload] = [] + lock_permissions = kwargs.get("sync_permissions", False) + reason = kwargs.get("reason") + for index, channel in enumerate(channels): + d: ChannelPositionUpdatePayload = {"id": channel.id, "position": index} # pyright: ignore[reportAssignmentType] + if parent_id is not MISSING and channel.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) + + +P_guild_threadable = TypeVar( + "P_guild_threadable", + bound="TextChannelPayload | NewsChannelPayload | ForumChannelPayload | MediaChannelPayload", + default="TextChannelPayload | NewsChannelPayload | ForumChannelPayload | MediaChannelPayload", +) + + +class GuildThreadableChannel(ABC): + """An ABC for guild channels that support thread creation. + + This includes text, news, forum, and media channels. + Voice, stage, and category channels do not support threads. + + This is a mixin class that adds threading capabilities to guild channels. + + .. versionadded:: 3.0 + + Attributes + ---------- + default_auto_archive_duration: int + The default auto archive duration in minutes for threads created in this channel. + default_thread_slowmode_delay: int | None + The initial slowmode delay to set on newly created threads in this channel. + """ + + __slots__ = () # Mixin class - slots defined in concrete classes + + # Type hints for attributes that this mixin expects from the inheriting class + if TYPE_CHECKING: + id: int + guild: Guild + default_auto_archive_duration: int + default_thread_slowmode_delay: int | None + + async def _update(self, data) -> None: + """Update threadable channel attributes.""" + await super()._update(data) # Call next in MRO + self.default_auto_archive_duration: int = data.get("default_auto_archive_duration", 1440) + self.default_thread_slowmode_delay: int | None = data.get("default_thread_rate_limit_per_user") + + @property + def threads(self) -> list[Thread]: + """Returns all the threads that you can see in this channel. + + .. versionadded:: 2.0 + + Returns + ------- + list[:class:`Thread`] + All active threads in this channel. + """ + return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id] + + def get_thread(self, thread_id: int, /) -> Thread | None: + """Returns a thread with the given ID. + + .. versionadded:: 2.0 + + Parameters + ---------- + thread_id: int + The ID to search for. + + Returns + ------- + Thread | None + The returned thread or ``None`` if not found. + """ + return self.guild.get_thread(thread_id) + + def archived_threads( + self, + *, + private: bool = False, + joined: bool = False, + limit: int | None = 50, + before: Snowflake | datetime.datetime | None = None, + ) -> ArchivedThreadIterator: + """Returns an iterator that iterates over all archived threads in the channel. + + You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads + then :attr:`~Permissions.manage_threads` is also required. + + .. versionadded:: 2.0 + + Parameters + ---------- + limit: int | None + The number of threads to retrieve. + If ``None``, retrieves every archived thread in the channel. Note, however, + that this would make it a slow operation. + before: Snowflake | datetime.datetime | None + Retrieve archived channels before the given date or ID. + private: bool + Whether to retrieve private archived threads. + joined: bool + Whether to retrieve private archived threads that you've joined. + You cannot set ``joined`` to ``True`` and ``private`` to ``False``. + + Yields + ------ + :class:`Thread` + The archived threads. + + Raises + ------ + Forbidden + You do not have permissions to get archived threads. + HTTPException + The request to get the archived threads failed. + """ + return ArchivedThreadIterator( + self.id, + self.guild, + limit=limit, + joined=joined, + private=private, + before=before, + ) + + +P_guild_postable = TypeVar( + "P_guild_postable", + bound="ForumChannelPayload | MediaChannelPayload", + default="ForumChannelPayload | MediaChannelPayload", +) + + +class ForumTag(Hashable): + """Represents a forum tag that can be added to a thread inside a :class:`ForumChannel` + . + .. versionadded:: 2.3 + + .. container:: operations + + .. describe:: x == y + + Checks if two forum tags are equal. + + .. describe:: x != y + + Checks if two forum tags are not equal. + + .. describe:: hash(x) + + Returns the forum tag's hash. + + .. describe:: str(x) + + Returns the forum tag's name. + + Attributes + ---------- + id: :class:`int` + The tag ID. + Note that if the object was created manually then this will be ``0``. + name: :class:`str` + The name of the tag. Can only be up to 20 characters. + moderated: :class:`bool` + Whether this tag can only be added or removed by a moderator with + the :attr:`~Permissions.manage_threads` permission. + emoji: :class:`PartialEmoji` + The emoji that is used to represent this tag. + Note that if the emoji is a custom emoji, it will *not* have name information. + """ + + __slots__ = ("name", "id", "moderated", "emoji") + + def __init__(self, *, name: str, emoji: EmojiInputType, moderated: bool = False) -> None: + self.name: str = name + self.id: int = 0 + self.moderated: bool = moderated + self.emoji: PartialEmoji + if isinstance(emoji, _EmojiTag): + self.emoji = emoji._to_partial() + elif isinstance(emoji, str): + self.emoji = PartialEmoji.from_str(emoji) + else: + raise TypeError(f"emoji must be a GuildEmoji, PartialEmoji, or str and not {emoji.__class__!r}") + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.name + + @classmethod + def from_data(cls, *, state: ConnectionState, data: ForumTagPayload) -> ForumTag: + self = cls.__new__(cls) + self.name = data["name"] + self.id = int(data["id"]) + self.moderated = data.get("moderated", False) + + emoji_name = data["emoji_name"] or "" + emoji_id = get_as_snowflake(data, "emoji_id") or None + self.emoji = PartialEmoji.with_state(state=state, name=emoji_name, id=emoji_id) + return self + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "name": self.name, + "moderated": self.moderated, + } | self.emoji._to_forum_reaction_payload() + + if self.id: + payload["id"] = self.id + + return payload + + +class GuildPostableChannel( + GuildTopLevelChannel[P_guild_postable], GuildThreadableChannel, ABC, Generic[P_guild_postable] +): + """An ABC for guild channels that support posts (threads with tags). + + This is a common base for forum and media channels. These channels don't support + direct messaging, but users create posts (which are threads) with associated tags. + + .. versionadded:: 3.0 + + Attributes + ---------- + topic: str | None + The channel's topic/guidelines. ``None`` if it doesn't exist. + nsfw: bool + Whether the channel is marked as NSFW. + slowmode_delay: int + The number of seconds a member must wait between creating posts + in this channel. A value of ``0`` denotes that it is disabled. + last_message_id: int | None + The ID of the last message sent in this channel. It may not always point to an existing or valid message. + available_tags: list[ForumTag] + The set of tags that can be used in this channel. + default_sort_order: SortOrder | None + The default sort order type used to order posts in this channel. + default_reaction_emoji: str | GuildEmoji | None + The default reaction emoji for posts in this channel. + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "default_auto_archive_duration", + "default_thread_slowmode_delay", + "available_tags", + "default_sort_order", + "default_reaction_emoji", + ) + + @override + async def _update(self, data: P_guild_postable) -> None: + await super()._update(data) + if not data.pop("_invoke_flag", False): + self.topic: str | None = data.get("topic") + self.nsfw: bool = data.get("nsfw", False) + self.slowmode_delay: int = data.get("rate_limit_per_user", 0) + self.last_message_id: int | None = get_as_snowflake(data, "last_message_id") + + self.available_tags: list[ForumTag] = [ + ForumTag.from_data(state=self._state, data=tag) for tag in (data.get("available_tags") or []) + ] + self.default_sort_order: SortOrder | None = data.get("default_sort_order", None) + if self.default_sort_order is not None: + self.default_sort_order = try_enum(SortOrder, self.default_sort_order) + + self.default_reaction_emoji = None + reaction_emoji_ctx: dict = data.get("default_reaction_emoji") + if reaction_emoji_ctx is not None: + emoji_name = reaction_emoji_ctx.get("emoji_name") + if emoji_name is not None: + self.default_reaction_emoji = reaction_emoji_ctx["emoji_name"] + else: + emoji_id = get_as_snowflake(reaction_emoji_ctx, "emoji_id") + if emoji_id: + self.default_reaction_emoji = await self._state.get_emoji(emoji_id) + + @property + def guidelines(self) -> str | None: + """The channel's guidelines. An alias of :attr:`topic`.""" + return self.topic + + @property + def requires_tag(self) -> bool: + """Whether a tag is required to be specified when creating a post in this channel. + + .. versionadded:: 2.3 + """ + return self.flags.require_tag + + def get_tag(self, id: int, /) -> ForumTag | None: + """Returns the :class:`ForumTag` from this channel with the given ID, if any. + + .. versionadded:: 2.3 + """ + return find(lambda t: t.id == id, self.available_tags) + + async def create_thread( + self, + name: str, + content: str | None = None, + *, + embed: Embed | None = None, + embeds: list[Embed] | None = None, + file: File | None = None, + files: list[File] | None = None, + stickers: Sequence[GuildSticker | StickerItem] | None = None, + delete_message_after: float | None = None, + nonce: int | str | None = None, + allowed_mentions: AllowedMentions | None = None, + view: View | None = None, + applied_tags: list[ForumTag] | None = None, + suppress: bool = False, + silent: bool = False, + auto_archive_duration: int | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + reason: str | None = None, + ) -> Thread: + """|coro| + + Creates a post (thread with initial message) in this forum or media channel. + + To create a post, you must have :attr:`~discord.Permissions.create_public_threads` or + :attr:`~discord.Permissions.send_messages` permission. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: :class:`str` + The name of the post/thread. + content: :class:`str` + The content of the initial message. + embed: :class:`~discord.Embed` + The rich embed for the content. + embeds: list[:class:`~discord.Embed`] + A list of embeds to upload. Must be a maximum of 10. + file: :class:`~discord.File` + The file to upload. + files: list[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. + stickers: Sequence[:class:`~discord.GuildSticker` | :class:`~discord.StickerItem`] + A list of stickers to upload. Must be a maximum of 3. + delete_message_after: :class:`float` + The time in seconds to wait before deleting the initial message. + nonce: :class:`str` | :class:`int` + The nonce to use for sending this message. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. + view: :class:`discord.ui.View` + A Discord UI View to add to the message. + applied_tags: list[:class:`ForumTag`] + A list of tags to apply to the new post. + suppress: :class:`bool` + Whether to suppress embeds in the initial message. + silent: :class:`bool` + Whether to send the message without triggering a notification. + auto_archive_duration: :class:`int` + The duration in minutes before the post is automatically archived for inactivity. + If not provided, the channel's default auto archive duration is used. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages in the new post. + If not provided, the channel's default slowmode is used. + reason: :class:`str` + The reason for creating the post. Shows up on the audit log. + + Returns + ------- + :class:`Thread` + The created post/thread. + + Raises + ------ + Forbidden + You do not have permissions to create a post. + HTTPException + Creating the post failed. + InvalidArgument + You provided invalid arguments. + """ + from ..errors import InvalidArgument + from ..file import File + from ..flags import MessageFlags + + state = self._state + message_content = str(content) if content is not None else None + + if embed is not None and embeds is not None: + raise InvalidArgument("cannot pass both embed and embeds parameter to create_thread()") + + if embed is not None: + embed = embed.to_dict() + + elif embeds is not None: + if len(embeds) > 10: + raise InvalidArgument("embeds parameter must be a list of up to 10 elements") + embeds = [e.to_dict() for e in embeds] + + if stickers is not None: + stickers = [sticker.id for sticker in stickers] + + if allowed_mentions is None: + allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() + elif state.allowed_mentions is not None: + allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() + else: + allowed_mentions = allowed_mentions.to_dict() + + flags = MessageFlags( + suppress_embeds=bool(suppress), + suppress_notifications=bool(silent), + ) + + if view: + if not hasattr(view, "__discord_ui_view__"): + raise InvalidArgument(f"view parameter must be View not {view.__class__!r}") + + components = view.to_components() + if view.is_components_v2(): + if embeds or content: + raise TypeError("cannot send embeds or content with a view using v2 component logic") + flags.is_components_v2 = True + else: + components = None + + if applied_tags is not None: + applied_tags = [str(tag.id) for tag in applied_tags] + + if file is not None and files is not None: + raise InvalidArgument("cannot pass both file and files parameter to create_thread()") + + if files is not None: + if len(files) > 10: + raise InvalidArgument("files parameter must be a list of up to 10 elements") + elif not all(isinstance(f, File) for f in files): + raise InvalidArgument("files parameter must be a list of File") + + if file is not None: + if not isinstance(file, File): + raise InvalidArgument("file parameter must be File") + files = [file] + + try: + data = await state.http.start_forum_thread( + self.id, + content=message_content, + name=name, + files=files, + embed=embed, + embeds=embeds, + nonce=nonce, + allowed_mentions=allowed_mentions, + stickers=stickers, + components=components, + auto_archive_duration=auto_archive_duration + if auto_archive_duration is not MISSING + else self.default_auto_archive_duration, + rate_limit_per_user=slowmode_delay + if slowmode_delay is not MISSING + else self.default_thread_slowmode_delay, + applied_tags=applied_tags, + flags=flags.value, + reason=reason, + ) + finally: + if files is not None: + for f in files: + f.close() + + from .thread import Thread + + ret = Thread(guild=self.guild, state=self._state, data=data) + msg = ret.get_partial_message(int(data["last_message_id"])) + if view and view.is_dispatchable(): + await state.store_view(view, msg.id) + + if delete_message_after is not None: + await msg.delete(delay=delete_message_after) + return ret + + +P_guild_messageable = TypeVar( + "P_guild_messageable", + bound="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | StageChannelPayload | ForumChannelPayload", + default="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | StageChannelPayload | ForumChannelPayload", +) + + +class GuildMessageableChannel(Messageable, ABC): + """An ABC mixin for guild channels that support messaging. + + This includes text and news channels, as well as threads. Voice and stage channels + do not support direct messaging (though they can have threads). + + This is a mixin class that adds messaging capabilities to guild channels. + + .. versionadded:: 3.0 + + Attributes + ---------- + topic: str | None + The channel's topic. ``None`` if it doesn't exist. + nsfw: bool + Whether the channel is marked as NSFW. + slowmode_delay: int + The number of seconds a member must wait between sending messages + in this channel. A value of ``0`` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + last_message_id: int | None + The ID of the last message sent in this channel. It may not always point to an existing or valid message. + """ + + __slots__ = () # Mixin class - slots defined in concrete classes + + # Attributes expected from inheriting classes + id: int + guild: Guild + _state: ConnectionState + topic: str | None + nsfw: bool + slowmode_delay: int + last_message_id: int | None + + async def _update(self, data) -> None: + """Update mutable attributes from API payload.""" + await super()._update(data) + # This data may be missing depending on how this object is being created/updated + if not data.pop("_invoke_flag", False): + self.topic = data.get("topic") + self.nsfw = data.get("nsfw", False) + # Does this need coercion into `int`? No idea yet. + self.slowmode_delay = data.get("rate_limit_per_user", 0) + self.last_message_id = get_as_snowflake(data, "last_message_id") + + @copy_doc(GuildChannel.permissions_for) + @override + def permissions_for(self, obj: Member | Role, /) -> Permissions: + base = super().permissions_for(obj) + + # text channels do not have voice related permissions + denied = Permissions.voice() + base.value &= ~denied.value + return base + + async def get_members(self) -> list[Member]: + """Returns all members that can see this channel.""" + return [m for m in await self.guild.get_members() if self.permissions_for(m).read_messages] + + async def get_last_message(self) -> Message | None: + """Fetches the last message from this channel in cache. + + The message might not be valid or point to an existing message. + + .. admonition:: Reliable Fetching + :class: helpful + + For a slightly more reliable method of fetching the + last message, consider using either :meth:`history` + or :meth:`fetch_message` with the :attr:`last_message_id` + attribute. + + Returns + ------- + Optional[:class:`Message`] + The last message in this channel or ``None`` if not found. + """ + return await self._state._get_message(self.last_message_id) if self.last_message_id else None + + async def edit(self, **options) -> _TextChannel: + """Edits the channel.""" + raise NotImplementedError + + @copy_doc(GuildChannel.clone) + @override + async def clone(self, *, name: str | None = None, reason: str | None = None) -> Self: + return await self._clone_impl( + { + "topic": self.topic, + "nsfw": self.nsfw, + "rate_limit_per_user": self.slowmode_delay, + }, + name=name, + reason=reason, + ) + + async def delete_messages(self, messages: Iterable[Snowflake], *, reason: str | None = None) -> None: + """|coro| + + Deletes a list of messages. This is similar to :meth:`Message.delete` + except it bulk deletes multiple messages. + + As a special case, if the number of messages is 0, then nothing + is done. If the number of messages is 1 then single message + delete is done. If it's more than two, then bulk delete is used. + + You cannot bulk delete more than 100 messages or messages that + are older than 14 days old. + + You must have the :attr:`~Permissions.manage_messages` permission to + use this. + + Parameters + ---------- + messages: Iterable[:class:`abc.Snowflake`] + An iterable of messages denoting which ones to bulk delete. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Raises + ------ + ClientException + The number of messages to delete was more than 100. + Forbidden + You do not have proper permissions to delete the messages. + NotFound + If single delete, then the message was already deleted. + HTTPException + Deleting the messages failed. + """ + if not isinstance(messages, (list, tuple)): + messages = list(messages) + + if len(messages) == 0: + return # do nothing + + if len(messages) == 1: + message_id: int = messages[0].id + await self._state.http.delete_message(self.id, message_id, reason=reason) + return + + if len(messages) > 100: + raise ClientException("Can only bulk delete messages up to 100 messages") + + message_ids: SnowflakeList = [m.id for m in messages] + await self._state.http.delete_messages(self.id, message_ids, reason=reason) + + async def purge( + self, + *, + limit: int | None = 100, + check: Callable[[Message], bool] | utils.Undefined = MISSING, + before: SnowflakeTime | None = None, + after: SnowflakeTime | None = None, + around: SnowflakeTime | None = None, + oldest_first: bool | None = False, + bulk: bool = True, + reason: str | None = None, + ) -> list[Message]: + """|coro| + + Purges a list of messages that meet the criteria given by the predicate + ``check``. If a ``check`` is not provided then all messages are deleted + without discrimination. + + You must have the :attr:`~Permissions.manage_messages` permission to + delete messages even if they are your own. + The :attr:`~Permissions.read_message_history` permission is + also needed to retrieve message history. + + Parameters + ---------- + limit: Optional[:class:`int`] + The number of messages to search through. This is not the number + of messages that will be deleted, though it can be. + check: Callable[[:class:`Message`], :class:`bool`] + The function used to check if a message should be deleted. + It must take a :class:`Message` as its sole parameter. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``before`` in :meth:`history`. + after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``after`` in :meth:`history`. + around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``around`` in :meth:`history`. + oldest_first: Optional[:class:`bool`] + Same as ``oldest_first`` in :meth:`history`. + bulk: :class:`bool` + If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting + a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will + fall back to single delete if messages are older than two weeks. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Returns + ------- + List[:class:`.Message`] + The list of messages that were deleted. + + Raises + ------ + Forbidden + You do not have proper permissions to do the actions required. + HTTPException + Purging the messages failed. + + Examples + -------- + + Deleting bot's messages :: + + def is_me(m): + return m.author == client.user + + + deleted = await channel.purge(limit=100, check=is_me) + await channel.send(f"Deleted {len(deleted)} message(s)") + """ + return await discord.abc._purge_messages_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + bulk=bulk, + reason=reason, + ) + + async def webhooks(self) -> list[Webhook]: + """|coro| + + Gets the list of webhooks from this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + Returns + ------- + List[:class:`Webhook`] + The webhooks for this channel. + + Raises + ------ + Forbidden + You don't have permissions to get the webhooks. + """ + + from .webhook import Webhook + + data = await self._state.http.channel_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] + + async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason: str | None = None) -> Webhook: + """|coro| + + Creates a webhook for this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + .. versionchanged:: 1.1 + Added the ``reason`` keyword-only parameter. + + Parameters + ---------- + name: :class:`str` + The webhook's name. + avatar: Optional[:class:`bytes`] + A :term:`py:bytes-like object` representing the webhook's default avatar. + This operates similarly to :meth:`~ClientUser.edit`. + reason: Optional[:class:`str`] + The reason for creating this webhook. Shows up in the audit logs. + + Returns + ------- + :class:`Webhook` + The created webhook. + + Raises + ------ + HTTPException + Creating the webhook failed. + Forbidden + You do not have permissions to create a webhook. + """ + + from .webhook import Webhook + + if avatar is not None: + avatar = bytes_to_base64_data(avatar) # type: ignore + + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) + return Webhook.from_state(data, state=self._state) + + async def follow(self, *, destination: TextChannel, reason: str | None = None) -> Webhook: + """ + Follows a channel using a webhook. + + Only news channels can be followed. + + .. note:: + + The webhook returned will not provide a token to do webhook + actions, as Discord does not provide it. + + .. versionadded:: 1.3 + + Parameters + ---------- + destination: :class:`TextChannel` + The channel you would like to follow from. + reason: Optional[:class:`str`] + The reason for following the channel. Shows up on the destination guild's audit log. + + .. versionadded:: 1.4 + + Returns + ------- + :class:`Webhook` + The created webhook. + + Raises + ------ + HTTPException + Following the channel failed. + Forbidden + You do not have the permissions to create a webhook. + """ + + from .news import NewsChannel + + if not isinstance(self, NewsChannel): + raise ClientException("The channel must be a news channel.") + + if not isinstance(destination, TextChannel): + raise InvalidArgument(f"Expected TextChannel received {destination.__class__.__name__}") + + from .webhook import Webhook + + data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason) + return Webhook._as_follower(data, channel=destination, user=self._state.user) + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 1.6 + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) diff --git a/discord/channel/category.py b/discord/channel/category.py new file mode 100644 index 0000000000..e95e8b167a --- /dev/null +++ b/discord/channel/category.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, overload + +from typing_extensions import override + +if TYPE_CHECKING: + from collections.abc import Mapping + + from ..app.state import ConnectionState + from ..guild import Guild + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from . import ForumChannel, StageChannel, TextChannel, VoiceChannel + +from ..enums import ChannelType, try_enum +from ..flags import ChannelFlags +from ..types.channel import CategoryChannel as CategoryChannelPayload +from ..utils.private import copy_doc +from .base import GuildChannel, GuildTopLevelChannel + + +def comparator(channel: GuildChannel): + # Sorts channels so voice channels (VoiceChannel, StageChannel) appear below non-voice channels + return isinstance(channel, (VoiceChannel, StageChannel)), (channel.position or -1) + + +class CategoryChannel(GuildTopLevelChannel[CategoryChannelPayload]): + """Represents a Discord channel category. + + These are useful to group channels to logical compartments. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the category's hash. + + .. describe:: str(x) + + Returns the category's name. + + Attributes + ---------- + name: str + The category name. + guild: Guild + The guild the category belongs to. + id: int + The category channel ID. + position: int + The position in the category list. This is a number that starts at 0. e.g. the + top category is position 0. + flags: ChannelFlags + Extra features of the channel. + + .. versionadded:: 2.0 + """ + + __slots__: tuple[str, ...] = () + + @override + def __repr__(self) -> str: + return f"" + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.category.value + + @property + def type(self) -> ChannelType: + """The channel's Discord type.""" + return try_enum(ChannelType, self._type) + + @copy_doc(GuildChannel.clone) + async def clone(self, *, name: str | None = None, reason: str | None = None) -> CategoryChannel: + return await self._clone_impl({}, name=name, reason=reason) + + @overload + async def edit( + self, + *, + name: str = ..., + position: int = ..., + overwrites: Mapping[Role | Member, PermissionOverwrite] = ..., + reason: str | None = ..., + ) -> CategoryChannel | None: ... + + @overload + async def edit(self) -> CategoryChannel | None: ... + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + Parameters + ---------- + name: :class:`str` + The new category's name. + position: :class:`int` + The new category's position. + reason: Optional[:class:`str`] + The reason for editing this category. Shows up on the audit log. + overwrites: Dict[Union[:class:`Role`, :class:`Member`, :class:`~discord.abc.Snowflake`], :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. Useful for creating secret channels. + + Returns + ------- + Optional[:class:`.CategoryChannel`] + The newly edited category channel. If the edit was only positional + then ``None`` is returned instead. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of categories. + Forbidden + You do not have permissions to edit the category. + HTTPException + Editing the category failed. + """ + + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + @copy_doc(GuildTopLevelChannel.move) + async def move(self, **kwargs): + kwargs.pop("category", None) + await super().move(**kwargs) + + @property + def channels(self) -> list[GuildChannelType]: + """Returns the channels that are under this category. + + These are sorted by the official Discord UI, which places voice channels below the text channels. + """ + + ret = [c for c in self.guild.channels if c.category_id == self.id] + ret.sort(key=comparator) + return ret + + @property + def text_channels(self) -> list[TextChannel]: + """Returns the text channels that are under this category.""" + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def voice_channels(self) -> list[VoiceChannel]: + """Returns the voice channels that are under this category.""" + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def stage_channels(self) -> list[StageChannel]: + """Returns the stage channels that are under this category. + + .. versionadded:: 1.7 + """ + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def forum_channels(self) -> list[ForumChannel]: + """Returns the forum channels that are under this category. + + .. versionadded:: 2.0 + """ + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, ForumChannel)] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + async def create_text_channel(self, name: str, **options: Any) -> TextChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_text_channel` to create a :class:`TextChannel` in the category. + + Returns + ------- + :class:`TextChannel` + The channel that was just created. + """ + return await self.guild.create_text_channel(name, category=self, **options) + + async def create_voice_channel(self, name: str, **options: Any) -> VoiceChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_voice_channel` to create a :class:`VoiceChannel` in the category. + + Returns + ------- + :class:`VoiceChannel` + The channel that was just created. + """ + return await self.guild.create_voice_channel(name, category=self, **options) + + async def create_stage_channel(self, name: str, **options: Any) -> StageChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_stage_channel` to create a :class:`StageChannel` in the category. + + .. versionadded:: 1.7 + + Returns + ------- + :class:`StageChannel` + The channel that was just created. + """ + return await self.guild.create_stage_channel(name, category=self, **options) + + async def create_forum_channel(self, name: str, **options: Any) -> ForumChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_forum_channel` to create a :class:`ForumChannel` in the category. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`ForumChannel` + The channel that was just created. + """ + return await self.guild.create_forum_channel(name, category=self, **options) diff --git a/discord/channel.py b/discord/channel/channel.py.old similarity index 93% rename from discord/channel.py rename to discord/channel/channel.py.old index 57e894b8c4..00b149b748 100644 --- a/discord/channel.py +++ b/discord/channel/channel.py.old @@ -39,12 +39,14 @@ overload, ) +from typing_extensions import override + import discord.abc -from . import utils -from .asset import Asset -from .emoji import GuildEmoji -from .enums import ( +from .. import utils +from ..asset import Asset +from ..emoji import GuildEmoji +from ..enums import ( ChannelType, EmbeddedActivity, InviteTarget, @@ -55,21 +57,21 @@ VoiceRegion, try_enum, ) -from .enums import ThreadArchiveDuration as ThreadArchiveDurationEnum -from .errors import ClientException, InvalidArgument -from .file import File -from .flags import ChannelFlags, MessageFlags -from .invite import Invite -from .iterators import ArchivedThreadIterator -from .mixins import Hashable -from .object import Object -from .partial_emoji import PartialEmoji, _EmojiTag -from .permissions import PermissionOverwrite, Permissions -from .soundboard import PartialSoundboardSound, SoundboardSound -from .stage_instance import StageInstance -from .threads import Thread -from .utils import MISSING -from .utils.private import bytes_to_base64_data, copy_doc, get_as_snowflake +from ..enums import ThreadArchiveDuration as ThreadArchiveDurationEnum +from ..errors import ClientException, InvalidArgument +from ..file import File +from ..flags import ChannelFlags, MessageFlags +from ..invite import Invite +from ..iterators import ArchivedThreadIterator +from ..mixins import Hashable +from ..object import Object +from ..partial_emoji import PartialEmoji, _EmojiTag +from ..permissions import PermissionOverwrite, Permissions +from ..soundboard import PartialSoundboardSound, SoundboardSound +from ..stage_instance import StageInstance +from .thread import Thread +from ..utils import MISSING +from ..utils.private import bytes_to_base64_data, copy_doc, get_as_snowflake __all__ = ( "TextChannel", @@ -86,112 +88,30 @@ ) if TYPE_CHECKING: - from .abc import Snowflake, SnowflakeTime - from .app.state import ConnectionState - from .embeds import Embed - from .guild import Guild - from .guild import GuildChannel as GuildChannelType - from .member import Member, VoiceState - from .mentions import AllowedMentions - from .message import EmojiInputType, Message, PartialMessage - from .role import Role - from .sticker import GuildSticker, StickerItem - from .types.channel import CategoryChannel as CategoryChannelPayload - from .types.channel import DMChannel as DMChannelPayload - from .types.channel import ForumChannel as ForumChannelPayload - from .types.channel import ForumTag as ForumTagPayload - from .types.channel import GroupDMChannel as GroupChannelPayload - from .types.channel import StageChannel as StageChannelPayload - from .types.channel import TextChannel as TextChannelPayload - from .types.channel import VoiceChannel as VoiceChannelPayload - from .types.channel import VoiceChannelEffectSendEvent as VoiceChannelEffectSend - from .types.snowflake import SnowflakeList - from .types.threads import ThreadArchiveDuration - from .ui.view import View - from .user import BaseUser, ClientUser, User - from .webhook import Webhook - - -class ForumTag(Hashable): - """Represents a forum tag that can be added to a thread inside a :class:`ForumChannel` - . - .. versionadded:: 2.3 - - .. container:: operations - - .. describe:: x == y - - Checks if two forum tags are equal. - - .. describe:: x != y - - Checks if two forum tags are not equal. - - .. describe:: hash(x) - - Returns the forum tag's hash. - - .. describe:: str(x) - - Returns the forum tag's name. - - Attributes - ---------- - id: :class:`int` - The tag ID. - Note that if the object was created manually then this will be ``0``. - name: :class:`str` - The name of the tag. Can only be up to 20 characters. - moderated: :class:`bool` - Whether this tag can only be added or removed by a moderator with - the :attr:`~Permissions.manage_threads` permission. - emoji: :class:`PartialEmoji` - The emoji that is used to represent this tag. - Note that if the emoji is a custom emoji, it will *not* have name information. - """ - - __slots__ = ("name", "id", "moderated", "emoji") - - def __init__(self, *, name: str, emoji: EmojiInputType, moderated: bool = False) -> None: - self.name: str = name - self.id: int = 0 - self.moderated: bool = moderated - self.emoji: PartialEmoji - if isinstance(emoji, _EmojiTag): - self.emoji = emoji._to_partial() - elif isinstance(emoji, str): - self.emoji = PartialEmoji.from_str(emoji) - else: - raise TypeError(f"emoji must be a GuildEmoji, PartialEmoji, or str and not {emoji.__class__!r}") - - def __repr__(self) -> str: - return f"" - - def __str__(self) -> str: - return self.name - - @classmethod - def from_data(cls, *, state: ConnectionState, data: ForumTagPayload) -> ForumTag: - self = cls.__new__(cls) - self.name = data["name"] - self.id = int(data["id"]) - self.moderated = data.get("moderated", False) - - emoji_name = data["emoji_name"] or "" - emoji_id = get_as_snowflake(data, "emoji_id") or None - self.emoji = PartialEmoji.with_state(state=state, name=emoji_name, id=emoji_id) - return self - - def to_dict(self) -> dict[str, Any]: - payload: dict[str, Any] = { - "name": self.name, - "moderated": self.moderated, - } | self.emoji._to_forum_reaction_payload() - - if self.id: - payload["id"] = self.id - - return payload + from ..abc import Snowflake, SnowflakeTime + from ..app.state import ConnectionState + from ..embeds import Embed + from ..guild import Guild + from ..guild import GuildChannel as GuildChannelType + from ..member import Member, VoiceState + from ..mentions import AllowedMentions + from ..message import EmojiInputType, Message, PartialMessage + from ..role import Role + from ..sticker import GuildSticker, StickerItem + from ..types.channel import CategoryChannel as CategoryChannelPayload + from ..types.channel import DMChannel as DMChannelPayload + from ..types.channel import ForumChannel as ForumChannelPayload + from ..types.channel import ForumTag as ForumTagPayload + from ..types.channel import GroupDMChannel as GroupChannelPayload + from ..types.channel import StageChannel as StageChannelPayload + from ..types.channel import TextChannel as TextChannelPayload + from ..types.channel import VoiceChannel as VoiceChannelPayload + from ..types.channel import VoiceChannelEffectSendEvent as VoiceChannelEffectSend + from ..types.snowflake import SnowflakeList + from ..types.threads import ThreadArchiveDuration + from ..ui.view import View + from ..user import BaseUser, ClientUser, User + from ..webhook import Webhook class _TextChannel(discord.abc.GuildChannel, Hashable): @@ -219,13 +139,31 @@ class _TextChannel(discord.abc.GuildChannel, Hashable): def __init__( self, *, - state: ConnectionState, + id: int, guild: Guild, - data: TextChannelPayload | ForumChannelPayload, + state: ConnectionState, ): + """Initialize with permanent attributes only.""" self._state: ConnectionState = state - self.id: int = int(data["id"]) - self.guild = guild + self.id: int = id + self.guild: Guild = guild + + @classmethod + async def _from_data( + cls, + *, + data: TextChannelPayload | ForumChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self @property def _repr_attrs(self) -> tuple[str, ...]: @@ -237,6 +175,7 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} {joined}>" async def _update(self, data: TextChannelPayload | ForumChannelPayload) -> None: + """Update mutable attributes from API payload.""" # This data will always exist self.name: str = data["name"] self.category_id: int | None = get_as_snowflake(data, "parent_id") @@ -659,7 +598,7 @@ def archived_threads( ) -class TextChannel(discord.abc.Messageable, _TextChannel): +class TextChannel(discord.abc.Messageable, ForumChannel): """Represents a Discord text channel. .. container:: operations @@ -723,15 +662,33 @@ class TextChannel(discord.abc.Messageable, _TextChannel): .. versionadded:: 2.3 """ - def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): - super().__init__(state=state, guild=guild, data=data) + def __init__(self, *, id: int, guild: Guild, state: ConnectionState): + """Initialize with permanent attributes only.""" + super().__init__(id=id, guild=guild, state=state) + + @classmethod + async def _from_data( + cls, + *, + data: TextChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self @property def _repr_attrs(self) -> tuple[str, ...]: return super()._repr_attrs + ("news",) async def _update(self, data: TextChannelPayload) -> None: - super()._update(data) + await super()._update(data) async def _get_channel(self) -> TextChannel: return self @@ -837,7 +794,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore async def create_thread( self, @@ -1002,11 +959,31 @@ class ForumChannel(_TextChannel): .. versionadded:: 2.5 """ - def __init__(self, *, state: ConnectionState, guild: Guild, data: ForumChannelPayload): - super().__init__(state=state, guild=guild, data=data) + def __init__(self, *, id: int, guild: Guild, state: ConnectionState): + """Initialize with permanent attributes only.""" + super().__init__(id=id, guild=guild, state=state) + @classmethod + @override + async def _from_data( + cls, + *, + data: ForumChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self + + @override async def _update(self, data: ForumChannelPayload) -> None: - super()._update(data) + await super()._update(data) self.available_tags: list[ForumTag] = [ ForumTag.from_data(state=self._state, data=tag) for tag in (data.get("available_tags") or []) ] @@ -1154,7 +1131,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore async def create_thread( self, @@ -1520,7 +1497,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): @@ -1545,14 +1522,34 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha def __init__( self, *, - state: ConnectionState, + id: int, guild: Guild, - data: VoiceChannelPayload | StageChannelPayload, + state: ConnectionState, + type: int | ChannelType, ): + """Initialize with permanent attributes only.""" self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id: int = id self.guild = guild - self._update(data) + self._type: int = int(type) + + @classmethod + async def _from_data( + cls, + *, + data: VoiceChannelPayload | StageChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + type=data["type"], + ) + await self._update(data) + return self def _get_voice_client_key(self) -> tuple[int, str]: return self.guild.id, "guild_id" @@ -1704,15 +1701,35 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel): def __init__( self, *, - state: ConnectionState, + id: int, guild: Guild, - data: VoiceChannelPayload, + state: ConnectionState, + type: int | ChannelType, ): + """Initialize with permanent attributes only.""" + super().__init__(id=id, guild=guild, state=state, type=type) self.status: str | None = None - super().__init__(state=state, guild=guild, data=data) + + @classmethod + async def _from_data( + cls, + *, + data: VoiceChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + type=data["type"], + ) + await self._update(data) + return self async def _update(self, data: VoiceChannelPayload): - super()._update(data) + await super()._update(data) if data.get("status"): self.status = data.get("status") @@ -2084,7 +2101,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore async def create_activity_invite(self, activity: EmbeddedActivity | int, **kwargs) -> Invite: """|coro| @@ -2252,7 +2269,7 @@ class StageChannel(discord.abc.Messageable, VocalGuildChannel): __slots__ = ("topic",) async def _update(self, data: StageChannelPayload) -> None: - super()._update(data) + await super()._update(data) self.topic = data.get("topic") def __repr__(self) -> str: @@ -2734,7 +2751,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore class CategoryChannel(discord.abc.GuildChannel, Hashable): @@ -2778,22 +2795,30 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): .. versionadded:: 2.0 """ - __slots__ = ( - "name", - "id", - "guild", - "_state", - "position", - "_overwrites", - "category_id", - "flags", - ) + __slots__ = ("name", "id", "guild", "_state", "position", "_overwrites", "category_id", "flags", "_type") - def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): + def __init__(self, *, id: int, guild: Guild, state: ConnectionState) -> None: + """Initialize with permanent attributes only.""" self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id: int = id self.guild = guild - self._update(data) + + @classmethod + async def _from_data( + cls, + *, + data: CategoryChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self def __repr__(self) -> str: return f"" @@ -2801,6 +2826,7 @@ def __repr__(self) -> str: async def _update(self, data: CategoryChannelPayload) -> None: # This data will always exist self.name: str = data["name"] + self._type: int = data["type"] self.category_id: int | None = get_as_snowflake(data, "parent_id") # This data may be missing depending on how this object is being created/updated @@ -2816,7 +2842,7 @@ def _sorting_bucket(self) -> int: @property def type(self) -> ChannelType: """The channel's Discord type.""" - return ChannelType.category + return try_enum(ChannelType, self._type) @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> CategoryChannel: @@ -2879,7 +2905,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore @copy_doc(discord.abc.GuildChannel.move) async def move(self, **kwargs): @@ -3023,21 +3049,28 @@ class DMChannel(discord.abc.Messageable, Hashable): The direct message channel ID. """ - __slots__ = ("id", "recipient", "me", "_state") + __slots__ = ("id", "recipient", "me", "_state", "_type") - def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): + def __init__(self, *, me: ClientUser, state: ConnectionState, id: int) -> None: + """Initialize with permanent attributes only.""" self._state: ConnectionState = state - self._recipients = data.get("recipients") self.recipient: User | None = None self.me: ClientUser = me - self.id: int = int(data["id"]) - # there shouldn't be any point in time where a DM channel - # is made without the event loop having started - asyncio.create_task(self._load()) + self.id: int = id + + @classmethod + async def _from_data(cls, *, data: DMChannelPayload, state: ConnectionState, me: ClientUser) -> DMChannel: + """Create channel instance from API payload.""" + self = cls(me=me, state=state, id=int(data["id"])) + await self._update(data) + return self - async def _load(self) -> None: - if r := self._recipients: - self.recipient = await self._state.store_user(r[0]) + async def _update(self, data: DMChannelPayload) -> None: + """Update mutable attributes from API payload.""" + recipients = data.get("recipients", []) + self._type = data["type"] + if recipients: + self.recipient = await self._state.store_user(recipients[0]) async def _get_channel(self): return self @@ -3063,7 +3096,7 @@ def _from_message(cls: type[DMC], state: ConnectionState, channel_id: int) -> DM @property def type(self) -> ChannelType: """The channel's Discord type.""" - return ChannelType.private + return try_enum(ChannelType, self._type) @property def jump_url(self) -> str: @@ -3327,7 +3360,7 @@ class PartialMessageable(discord.abc.Messageable, Hashable): The channel type associated with this partial messageable, if given. """ - def __init__(self, state: ConnectionState, id: int, type: ChannelType | None = None): + def __init__(self, state: ConnectionState, id: int): self._state: ConnectionState = state self._channel: Object = Object(id=id) self.id: int = id @@ -3437,57 +3470,3 @@ def __init__( else None ) self.data = data - - -def _guild_channel_factory(channel_type: int): - value = try_enum(ChannelType, channel_type) - if value is ChannelType.text: - return TextChannel, value - elif value is ChannelType.voice: - return VoiceChannel, value - elif value is ChannelType.category: - return CategoryChannel, value - elif value is ChannelType.news: - return TextChannel, value - elif value is ChannelType.stage_voice: - return StageChannel, value - elif value is ChannelType.directory: - return None, value # todo: Add DirectoryChannel when applicable - elif value is ChannelType.forum: - return ForumChannel, value - elif value is ChannelType.media: - return MediaChannel, value - else: - return None, value - - -def _channel_factory(channel_type: int): - cls, value = _guild_channel_factory(channel_type) - if value is ChannelType.private: - return DMChannel, value - elif value is ChannelType.group: - return GroupChannel, value - else: - return cls, value - - -def _threaded_channel_factory(channel_type: int): - cls, value = _channel_factory(channel_type) - if value in ( - ChannelType.private_thread, - ChannelType.public_thread, - ChannelType.news_thread, - ): - return Thread, value - return cls, value - - -def _threaded_guild_channel_factory(channel_type: int): - cls, value = _guild_channel_factory(channel_type) - if value in ( - ChannelType.private_thread, - ChannelType.public_thread, - ChannelType.news_thread, - ): - return Thread, value - return cls, value diff --git a/discord/channel/dm.py b/discord/channel/dm.py new file mode 100644 index 0000000000..626ad439f2 --- /dev/null +++ b/discord/channel/dm.py @@ -0,0 +1,106 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from collections.abc import Collection +from typing import TYPE_CHECKING + +from typing_extensions import override + +from ..abc import Messageable, Snowflake +from ..asset import Asset +from ..permissions import Permissions +from ..types.channel import DMChannel as DMChannelPayload +from ..types.channel import GroupDMChannel as GroupDMChannelPayload +from .base import BaseChannel, P + +if TYPE_CHECKING: + from ..app.state import ConnectionState + from ..message import Message + from ..user import User + + +class DMChannel(BaseChannel[DMChannelPayload], Messageable): + __slots__: tuple[str, ...] = ("last_message", "recipient") + + def __init__(self, id: int, state: "ConnectionState") -> None: + super().__init__(id, state) + self.recipient: User | None = None + self.last_message: Message | None = None + + @override + async def _update(self, data: DMChannelPayload) -> None: + await super()._update(data) + if last_message_id := data.get("last_message_id", None): + self.last_message = await self._state.cache.get_message(int(last_message_id)) + if recipients := data.get("recipients"): + self.recipient = await self._state.cache.store_user(recipients[0]) + + @override + def __repr__(self) -> str: + return f"" + + @property + @override + def jump_url(self) -> str: + """Returns a URL that allows the client to jump to the channel.""" + return f"https://discord.com/channels/@me/{self.id}" + + +class GroupDMChannel(BaseChannel[GroupDMChannelPayload], Messageable): + __slots__: tuple[str, ...] = ("recipients", "icon_hash", "owner", "name") + + def __init__(self, id: int, state: "ConnectionState") -> None: + super().__init__(id, state) + self.recipients: Collection[User] = set() + self.icon_hash: str | None = None + self.owner: User | None = None + + @override + async def _update(self, data: GroupDMChannelPayload) -> None: + await super()._update(data) + self.name: str = data["name"] + if recipients := data.get("recipients"): + self.recipients = {await self._state.cache.store_user(recipient_data) for recipient_data in recipients} + if icon_hash := data.get("icon"): + self.icon_hash = icon_hash + if owner_id := data.get("owner_id"): + self.owner = await self._state.cache.get_user(int(owner_id)) + + @override + def __repr__(self) -> str: + return f"" + + @property + @override + def jump_url(self) -> str: + """Returns a URL that allows the client to jump to the channel.""" + return f"https://discord.com/channels/@me/{self.id}" + + @property + def icon(self) -> Asset | None: + """Returns the channel's icon asset if available.""" + if self.icon_hash is None: + return None + return Asset._from_icon(self._state, self.id, self.icon_hash, path="channel") diff --git a/discord/channel/forum.py b/discord/channel/forum.py new file mode 100644 index 0000000000..39fad7f36d --- /dev/null +++ b/discord/channel/forum.py @@ -0,0 +1,210 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping, overload + +from typing_extensions import Self, override + +from ..enums import ChannelType, SortOrder +from ..flags import ChannelFlags +from ..utils import MISSING, Undefined +from .base import GuildPostableChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..emoji import GuildEmoji + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..types.channel import ForumChannel as ForumChannelPayload + from .category import CategoryChannel + from .channel import ForumTag + +__all__ = ("ForumChannel",) + + +class ForumChannel(GuildPostableChannel["ForumChannelPayload"]): + """Represents a Discord forum channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic/guidelines. ``None`` if it doesn't exist. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + slowmode_delay: :class:`int` + The number of seconds a member must wait between creating posts + in this channel. A value of `0` denotes that it is disabled. + last_message_id: :class:`int` | None + The last message ID sent to this channel. It may not point to an existing or valid message. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for posts created in this channel. + default_thread_slowmode_delay: :class:`int` | None + The initial slowmode delay to set on newly created posts in this channel. + available_tags: list[:class:`ForumTag`] + The set of tags that can be used in this forum channel. + default_sort_order: :class:`SortOrder` | None + The default sort order type used to order posts in this channel. + default_reaction_emoji: :class:`str` | :class:`GuildEmoji` | None + The default forum reaction emoji. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = () + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.forum.value + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + @overload + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + default_auto_archive_duration: int | Undefined = MISSING, + default_thread_slowmode_delay: int | Undefined = MISSING, + default_sort_order: SortOrder | Undefined = MISSING, + default_reaction_emoji: GuildEmoji | int | str | None | Undefined = MISSING, + available_tags: list[ForumTag] | Undefined = MISSING, + require_tag: bool | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + reason: str | None = None, + ) -> Self: ... + + @overload + async def edit(self) -> Self: ... + + async def edit(self, *, reason: str | None = None, **options) -> Self: + """|coro| + + Edits the forum channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic/guidelines. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + Whether the channel should be marked as NSFW. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for users in this channel, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for posts created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for posts created in this channel. + default_sort_order: :class:`SortOrder` + The default sort order type to use to order posts in this channel. + default_reaction_emoji: :class:`GuildEmoji` | :class:`int` | :class:`str` | None + The default reaction emoji for posts. + Can be a unicode emoji or a custom emoji. + available_tags: list[:class:`ForumTag`] + The set of tags that can be used in this channel. Must be less than ``20``. + require_tag: :class:`bool` + Whether a tag should be required to be specified when creating a post in this channel. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.ForumChannel` + The newly edited forum channel. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + if "require_tag" in options: + options["flags"] = ChannelFlags._from_value(self.flags.value) + options["flags"].require_tag = options.pop("require_tag") + + payload = await self._edit(options, reason=reason) + if payload is not None: + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + return self diff --git a/discord/channel/media.py b/discord/channel/media.py new file mode 100644 index 0000000000..64b4ea5620 --- /dev/null +++ b/discord/channel/media.py @@ -0,0 +1,227 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping, overload + +from typing_extensions import Self, override + +from ..enums import ChannelType, SortOrder +from ..flags import ChannelFlags +from ..utils import MISSING, Undefined +from .base import GuildPostableChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..emoji import GuildEmoji + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..types.channel import MediaChannel as MediaChannelPayload + from .category import CategoryChannel + from .channel import ForumTag + +__all__ = ("MediaChannel",) + + +class MediaChannel(GuildPostableChannel["MediaChannelPayload"]): + """Represents a Discord media channel. + + .. versionadded:: 2.7 + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic/guidelines. ``None`` if it doesn't exist. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + slowmode_delay: :class:`int` + The number of seconds a member must wait between creating posts + in this channel. A value of `0` denotes that it is disabled. + last_message_id: :class:`int` | None + The last message ID sent to this channel. It may not point to an existing or valid message. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for posts created in this channel. + default_thread_slowmode_delay: :class:`int` | None + The initial slowmode delay to set on newly created posts in this channel. + available_tags: list[:class:`ForumTag`] + The set of tags that can be used in this media channel. + default_sort_order: :class:`SortOrder` | None + The default sort order type used to order posts in this channel. + default_reaction_emoji: :class:`str` | :class:`GuildEmoji` | None + The default reaction emoji. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = () + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.media.value + + @property + def media_download_options_hidden(self) -> bool: + """Whether media download options are hidden in this media channel. + + .. versionadded:: 2.7 + """ + return self.flags.hide_media_download_options + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + @overload + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + default_auto_archive_duration: int | Undefined = MISSING, + default_thread_slowmode_delay: int | Undefined = MISSING, + default_sort_order: SortOrder | Undefined = MISSING, + default_reaction_emoji: GuildEmoji | int | str | None | Undefined = MISSING, + available_tags: list[ForumTag] | Undefined = MISSING, + require_tag: bool | Undefined = MISSING, + hide_media_download_options: bool | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + reason: str | None = None, + ) -> Self: ... + + @overload + async def edit(self) -> Self: ... + + async def edit(self, *, reason: str | None = None, **options) -> Self: + """|coro| + + Edits the media channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic/guidelines. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + Whether the channel should be marked as NSFW. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for users in this channel, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for posts created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for posts created in this channel. + default_sort_order: :class:`SortOrder` + The default sort order type to use to order posts in this channel. + default_reaction_emoji: :class:`GuildEmoji` | :class:`int` | :class:`str` | None + The default reaction emoji for posts. + Can be a unicode emoji or a custom emoji. + available_tags: list[:class:`ForumTag`] + The set of tags that can be used in this channel. Must be less than ``20``. + require_tag: :class:`bool` + Whether a tag should be required to be specified when creating a post in this channel. + hide_media_download_options: :class:`bool` + Whether to hide the media download options in this media channel. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.MediaChannel` + The newly edited media channel. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + # Handle require_tag flag + if "require_tag" in options or "hide_media_download_options" in options: + options["flags"] = ChannelFlags._from_value(self.flags.value) + if "require_tag" in options: + options["flags"].require_tag = options.pop("require_tag") + if "hide_media_download_options" in options: + options["flags"].hide_media_download_options = options.pop("hide_media_download_options") + + payload = await self._edit(options, reason=reason) + if payload is not None: + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + return self diff --git a/discord/channel/news.py b/discord/channel/news.py new file mode 100644 index 0000000000..79fd98f764 --- /dev/null +++ b/discord/channel/news.py @@ -0,0 +1,282 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +from typing_extensions import Self, override + +from ..enums import ChannelType +from ..utils import MISSING, Undefined +from .base import GuildMessageableChannel, GuildThreadableChannel, GuildTopLevelChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..types.channel import NewsChannel as NewsChannelPayload + from ..types.channel import TextChannel as TextChannelPayload + from .category import CategoryChannel + from .thread import Thread + +__all__ = ("NewsChannel",) + + +class NewsChannel( + GuildTopLevelChannel["NewsChannelPayload"], + GuildMessageableChannel, + GuildThreadableChannel, +): + """Represents a Discord guild news/announcement channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic. ``None`` if it isn't set. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + last_message_id: :class:`int` | None + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for threads created in this channel. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "default_auto_archive_duration", + "default_thread_slowmode_delay", + ) + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.news.value + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + default_auto_archive_duration: int | Undefined = MISSING, + default_thread_slowmode_delay: int | Undefined = MISSING, + type: ChannelType | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + reason: str | None = None, + ) -> Self | TextChannel: + """|coro| + + Edits the channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for threads created in this channel. + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for threads created in this channel. + type: :class:`ChannelType` + Change the type of this news channel. Only conversion between text and news is supported. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.NewsChannel` | :class:`.TextChannel` + The newly edited channel. If type was changed, the appropriate channel type is returned. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + options = {} + if name is not MISSING: + options["name"] = name + if topic is not MISSING: + options["topic"] = topic + if position is not MISSING: + options["position"] = position + if nsfw is not MISSING: + options["nsfw"] = nsfw + if sync_permissions is not MISSING: + options["sync_permissions"] = sync_permissions + if category is not MISSING: + options["category"] = category + if slowmode_delay is not MISSING: + options["slowmode_delay"] = slowmode_delay + if default_auto_archive_duration is not MISSING: + options["default_auto_archive_duration"] = default_auto_archive_duration + if default_thread_slowmode_delay is not MISSING: + options["default_thread_slowmode_delay"] = default_thread_slowmode_delay + if type is not MISSING: + options["type"] = type + if overwrites is not MISSING: + options["overwrites"] = overwrites + + payload = await self._edit(options, reason=reason) + if payload is not None: + if payload.get("type") == ChannelType.text.value: + from .text import TextChannel + + return await TextChannel._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + async def create_thread( + self, + *, + name: str, + message: Snowflake | None = None, + auto_archive_duration: int | Undefined = MISSING, + type: ChannelType | None = None, + slowmode_delay: int | None = None, + invitable: bool | None = None, + reason: str | None = None, + ) -> Thread: + """|coro| + + Creates a thread in this news channel. + + Parameters + ---------- + name: :class:`str` + The name of the thread. + message: :class:`abc.Snowflake` | None + A snowflake representing the message to create the thread with. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically archived for inactivity. + type: :class:`ChannelType` | None + The type of thread to create. + slowmode_delay: :class:`int` | None + Specifies the slowmode rate limit for users in this thread, in seconds. + invitable: :class:`bool` | None + Whether non-moderators can add other non-moderators to this thread. + reason: :class:`str` | None + The reason for creating a new thread. + + Returns + ------- + :class:`Thread` + The created thread + """ + from .thread import Thread + + if type is None: + type = ChannelType.private_thread + + if message is None: + data = await self._state.http.start_thread_without_message( + self.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + type=type.value, + rate_limit_per_user=slowmode_delay or 0, + invitable=invitable, + reason=reason, + ) + else: + data = await self._state.http.start_thread_with_message( + self.id, + message.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + rate_limit_per_user=slowmode_delay or 0, + reason=reason, + ) + + return Thread(guild=self.guild, state=self._state, data=data) diff --git a/discord/channel/partial.py b/discord/channel/partial.py new file mode 100644 index 0000000000..81058cd40e --- /dev/null +++ b/discord/channel/partial.py @@ -0,0 +1,104 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..abc import Messageable +from ..enums import ChannelType +from ..mixins import Hashable +from ..object import Object + +if TYPE_CHECKING: + from ..message import PartialMessage + from ..state import ConnectionState + +__all__ = ("PartialMessageable",) + + +class PartialMessageable(Messageable, Hashable): + """Represents a partial messageable to aid with working messageable channels when + only a channel ID are present. + + The only way to construct this class is through :meth:`Client.get_partial_messageable`. + + Note that this class is trimmed down and has no rich attributes. + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two partial messageables are equal. + + .. describe:: x != y + + Checks if two partial messageables are not equal. + + .. describe:: hash(x) + + Returns the partial messageable's hash. + + Attributes + ---------- + id: :class:`int` + The channel ID associated with this partial messageable. + type: Optional[:class:`ChannelType`] + The channel type associated with this partial messageable, if given. + """ + + def __init__(self, state: ConnectionState, id: int, type: ChannelType | None = None): + self._state: ConnectionState = state + self._channel: Object = Object(id=id) + self.id: int = id + self.type: ChannelType | None = type + + async def _get_channel(self) -> Object: + return self._channel + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message. + """ + from ..message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + + def __repr__(self) -> str: + return f"" diff --git a/discord/channel/stage.py b/discord/channel/stage.py new file mode 100644 index 0000000000..8c96a666b8 --- /dev/null +++ b/discord/channel/stage.py @@ -0,0 +1,345 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +from typing_extensions import Self, override + +from ..abc import Connectable +from ..enums import ChannelType, StagePrivacyLevel, VideoQualityMode, VoiceRegion, try_enum +from ..utils import MISSING, Undefined +from .base import GuildMessageableChannel, GuildTopLevelChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..stage_instance import StageInstance + from ..types.channel import StageChannel as StageChannelPayload + from .category import CategoryChannel + +__all__ = ("StageChannel",) + + +class StageChannel( + GuildTopLevelChannel["StageChannelPayload"], + GuildMessageableChannel, + Connectable, +): + """Represents a Discord guild stage channel. + + .. versionadded:: 1.7 + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic. ``None`` if it isn't set. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a stage channel. + A value of ``0`` indicates no limit. + rtc_region: :class:`VoiceRegion` | None + The region for the stage channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the stage channel's participants. + last_message_id: :class:`int` | None + The ID of the last message sent to this channel. It may not always point to an existing or valid message. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for users in this channel, in seconds. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "bitrate", + "user_limit", + "rtc_region", + "video_quality_mode", + ) + + @override + async def _update(self, data: StageChannelPayload) -> None: + await super()._update(data) + self.bitrate: int = data.get("bitrate", 64000) + self.user_limit: int = data.get("user_limit", 0) + rtc = data.get("rtc_region") + self.rtc_region: VoiceRegion | None = try_enum(VoiceRegion, rtc) if rtc is not None else None + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.stage_voice.value + + @property + def requesting_to_speak(self) -> list[Member]: + """A list of members who are requesting to speak in the stage channel.""" + return [member for member in self.members if member.voice and member.voice.requested_to_speak_at is not None] + + @property + def speakers(self) -> list[Member]: + """A list of members who have been permitted to speak in the stage channel. + + .. versionadded:: 2.0 + """ + return [member for member in self.members if member.voice and not member.voice.suppress] + + @property + def listeners(self) -> list[Member]: + """A list of members who are listening in the stage channel. + + .. versionadded:: 2.0 + """ + return [member for member in self.members if member.voice and member.voice.suppress] + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("topic", self.topic), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + @property + def instance(self) -> StageInstance | None: + """Returns the currently running stage instance if any. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`StageInstance` | None + The stage instance or ``None`` if not active. + """ + return self.guild.get_stage_instance(self.id) + + @property + def moderators(self) -> list[Member]: + """Returns a list of members who have stage moderator permissions. + + .. versionadded:: 2.0 + + Returns + ------- + list[:class:`Member`] + The members with stage moderator permissions. + """ + from ..permissions import Permissions + + required = Permissions.stage_moderator() + return [m for m in self.members if (self.permissions_for(m) & required) == required] + + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + rtc_region: VoiceRegion | None | Undefined = MISSING, + video_quality_mode: VideoQualityMode | Undefined = MISSING, + reason: str | None = None, + ) -> Self: + """|coro| + + Edits the stage channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel's name. + topic: :class:`str` + The new channel's topic. + position: :class:`int` + The new channel's position. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + rtc_region: :class:`VoiceRegion` | None + The new region for the stage channel's voice communication. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the stage channel's participants. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.StageChannel` + The newly edited stage channel. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + options = {} + if name is not MISSING: + options["name"] = name + if topic is not MISSING: + options["topic"] = topic + if position is not MISSING: + options["position"] = position + if sync_permissions is not MISSING: + options["sync_permissions"] = sync_permissions + if category is not MISSING: + options["category"] = category + if overwrites is not MISSING: + options["overwrites"] = overwrites + if rtc_region is not MISSING: + options["rtc_region"] = rtc_region + if video_quality_mode is not MISSING: + options["video_quality_mode"] = video_quality_mode + + payload = await self._edit(options, reason=reason) + if payload is not None: + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + async def create_instance( + self, + *, + topic: str, + privacy_level: StagePrivacyLevel = StagePrivacyLevel.guild_only, + reason: str | None = None, + send_notification: bool = False, + ) -> StageInstance: + """|coro| + + Creates a stage instance. + + You must have :attr:`~Permissions.manage_channels` permission to do this. + + Parameters + ---------- + topic: :class:`str` + The stage instance's topic. + privacy_level: :class:`StagePrivacyLevel` + The stage instance's privacy level. + send_notification: :class:`bool` + Whether to send a notification to everyone in the server that the stage is starting. + reason: :class:`str` | None + The reason for creating the stage instance. Shows up on the audit log. + + Returns + ------- + :class:`StageInstance` + The created stage instance. + + Raises + ------ + Forbidden + You do not have permissions to create a stage instance. + HTTPException + Creating the stage instance failed. + """ + from ..stage_instance import StageInstance + + payload = await self._state.http.create_stage_instance( + self.id, + topic=topic, + privacy_level=int(privacy_level), + send_start_notification=send_notification, + reason=reason, + ) + return StageInstance(guild=self.guild, state=self._state, data=payload) + + async def fetch_instance(self) -> StageInstance | None: + """|coro| + + Fetches the currently running stage instance. + + Returns + ------- + :class:`StageInstance` | None + The stage instance or ``None`` if not active. + + Raises + ------ + NotFound + The stage instance is not active or was deleted. + HTTPException + Fetching the stage instance failed. + """ + from ..stage_instance import StageInstance + + try: + payload = await self._state.http.get_stage_instance(self.id) + return StageInstance(guild=self.guild, state=self._state, data=payload) + except Exception: + return None diff --git a/discord/channel/text.py b/discord/channel/text.py new file mode 100644 index 0000000000..81149d7b52 --- /dev/null +++ b/discord/channel/text.py @@ -0,0 +1,326 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +from typing_extensions import Self, override + +from ..enums import ChannelType +from ..utils import MISSING, Undefined +from .base import GuildMessageableChannel, GuildThreadableChannel, GuildTopLevelChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..types.channel import NewsChannel as NewsChannelPayload + from ..types.channel import TextChannel as TextChannelPayload + from .category import CategoryChannel + from .thread import Thread + +__all__ = ("TextChannel",) + + +class TextChannel( + GuildTopLevelChannel["TextChannelPayload"], + GuildMessageableChannel, + GuildThreadableChannel, +): + """Represents a Discord guild text channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic. ``None`` if it isn't set. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + last_message_id: :class:`int` | None + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for threads created in this channel. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "default_auto_archive_duration", + "default_thread_slowmode_delay", + ) + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.text.value + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + default_auto_archive_duration: int | Undefined = MISSING, + default_thread_slowmode_delay: int | Undefined = MISSING, + type: ChannelType | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + reason: str | None = None, + ) -> Self | NewsChannel: + """|coro| + + Edits the channel. + + You must have :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 1.4 + The ``type`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + .. versionchanged:: 3.0 + The ``default_thread_slowmode_delay`` keyword-only parameter was added. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the + category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for threads created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for threads created in this channel. + type: :class:`ChannelType` + Change the type of this text channel. Currently, only conversion between + :attr:`ChannelType.text` and :attr:`ChannelType.news` is supported. This + is only available to guilds that contain ``NEWS`` in :attr:`Guild.features`. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. Useful for creating secret channels. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.TextChannel` | :class:`.NewsChannel` + The newly edited channel. If the edit was only positional + then ``None`` is returned instead. If the type was changed, + the appropriate channel type is returned. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of channels, or if + the permission overwrite information is not in proper form. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + options = {} + if name is not MISSING: + options["name"] = name + if topic is not MISSING: + options["topic"] = topic + if position is not MISSING: + options["position"] = position + if nsfw is not MISSING: + options["nsfw"] = nsfw + if sync_permissions is not MISSING: + options["sync_permissions"] = sync_permissions + if category is not MISSING: + options["category"] = category + if slowmode_delay is not MISSING: + options["slowmode_delay"] = slowmode_delay + if default_auto_archive_duration is not MISSING: + options["default_auto_archive_duration"] = default_auto_archive_duration + if default_thread_slowmode_delay is not MISSING: + options["default_thread_slowmode_delay"] = default_thread_slowmode_delay + if type is not MISSING: + options["type"] = type + if overwrites is not MISSING: + options["overwrites"] = overwrites + + payload = await self._edit(options, reason=reason) + if payload is not None: + # Check if type was changed to news + if payload.get("type") == ChannelType.news.value: + from .news import NewsChannel + + return await NewsChannel._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + async def create_thread( + self, + *, + name: str, + message: Snowflake | None = None, + auto_archive_duration: int | Undefined = MISSING, + type: ChannelType | None = None, + slowmode_delay: int | None = None, + invitable: bool | None = None, + reason: str | None = None, + ) -> Thread: + """|coro| + + Creates a thread in this text channel. + + To create a public thread, you must have :attr:`~discord.Permissions.create_public_threads`. + For a private thread, :attr:`~discord.Permissions.create_private_threads` is needed instead. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: :class:`str` + The name of the thread. + message: :class:`abc.Snowflake` | None + A snowflake representing the message to create the thread with. + If ``None`` is passed then a private thread is created. + Defaults to ``None``. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically archived for inactivity. + If not provided, the channel's default auto archive duration is used. + type: :class:`ChannelType` | None + The type of thread to create. If a ``message`` is passed then this parameter + is ignored, as a thread created with a message is always a public thread. + By default, this creates a private thread if this is ``None``. + slowmode_delay: :class:`int` | None + Specifies the slowmode rate limit for users in this thread, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + invitable: :class:`bool` | None + Whether non-moderators can add other non-moderators to this thread. + Only available for private threads, where it defaults to True. + reason: :class:`str` | None + The reason for creating a new thread. Shows up on the audit log. + + Returns + ------- + :class:`Thread` + The created thread + + Raises + ------ + Forbidden + You do not have permissions to create a thread. + HTTPException + Starting the thread failed. + """ + from .thread import Thread + + if type is None: + type = ChannelType.private_thread + + if message is None: + data = await self._state.http.start_thread_without_message( + self.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + type=type.value, + rate_limit_per_user=slowmode_delay or 0, + invitable=invitable, + reason=reason, + ) + else: + data = await self._state.http.start_thread_with_message( + self.id, + message.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + rate_limit_per_user=slowmode_delay or 0, + reason=reason, + ) + + return Thread(guild=self.guild, state=self._state, data=data) diff --git a/discord/threads.py b/discord/channel/thread.py similarity index 84% rename from discord/threads.py rename to discord/channel/thread.py index 1f2a7600a4..920e65ab8a 100644 --- a/discord/threads.py +++ b/discord/channel/thread.py @@ -27,19 +27,24 @@ from typing import TYPE_CHECKING, Callable, Iterable +from typing_extensions import override + from discord import utils -from .abc import Messageable, _purge_messages_helper -from .enums import ( +from ..abc import Messageable, _purge_messages_helper +from ..enums import ( ChannelType, try_enum, ) -from .enums import ThreadArchiveDuration as ThreadArchiveDurationEnum -from .errors import ClientException -from .flags import ChannelFlags -from .mixins import Hashable -from .utils import MISSING -from .utils.private import get_as_snowflake, parse_time +from ..enums import ThreadArchiveDuration as ThreadArchiveDurationEnum +from .base import BaseChannel, GuildMessageableChannel +from ..errors import ClientException +from ..flags import ChannelFlags +from ..mixins import Hashable +from ..types.threads import Thread as ThreadPayload +from ..utils import MISSING +from ..utils.private import get_as_snowflake, parse_time +from .base import GuildMessageableChannel __all__ = ( "Thread", @@ -47,21 +52,20 @@ ) if TYPE_CHECKING: - from .abc import Snowflake, SnowflakeTime - from .app.state import ConnectionState - from .channel import CategoryChannel, ForumChannel, ForumTag, TextChannel - from .guild import Guild - from .member import Member - from .message import Message, PartialMessage - from .permissions import Permissions - from .role import Role - from .types.snowflake import SnowflakeList - from .types.threads import Thread as ThreadPayload - from .types.threads import ThreadArchiveDuration, ThreadMetadata - from .types.threads import ThreadMember as ThreadMemberPayload - - -class Thread(Messageable, Hashable): + from ..abc import Snowflake, SnowflakeTime + from ..app.state import ConnectionState + from ..guild import Guild + from ..member import Member + from ..message import Message, PartialMessage + from ..permissions import Permissions + from ..role import Role + from ..types.snowflake import SnowflakeList + from ..types.threads import ThreadArchiveDuration, ThreadMetadata + from ..types.threads import ThreadMember as ThreadMemberPayload + from . import CategoryChannel, ForumChannel, ForumTag, TextChannel + + +class Thread(BaseChannel[ThreadPayload], GuildMessageableChannel): """Represents a Discord thread. .. container:: operations @@ -86,55 +90,55 @@ class Thread(Messageable, Hashable): Attributes ---------- - name: :class:`str` + name: str The thread name. - guild: :class:`Guild` + guild: Guild The guild the thread belongs to. - id: :class:`int` + id: int The thread ID. .. note:: This ID is the same as the thread starting message ID. - parent_id: :class:`int` + parent_id: int The parent :class:`TextChannel` ID this thread belongs to. - owner_id: :class:`int` + owner_id: int The user's ID that created this thread. - last_message_id: Optional[:class:`int`] + last_message_id: int | None The last message ID of the message sent to this thread. It may *not* point to an existing or valid message. - slowmode_delay: :class:`int` + slowmode_delay: int The number of seconds a member must wait between sending messages in this thread. A value of `0` denotes that it is disabled. Bots and users with :attr:`~Permissions.manage_channels` or :attr:`~Permissions.manage_messages` bypass slowmode. - message_count: :class:`int` + message_count: int An approximate number of messages in this thread. This caps at 50. - member_count: :class:`int` + member_count: int An approximate number of members in this thread. This caps at 50. - me: Optional[:class:`ThreadMember`] + me: ThreadMember | None A thread member representing yourself, if you've joined the thread. This could not be available. - archived: :class:`bool` + archived: bool Whether the thread is archived. - locked: :class:`bool` + locked: bool Whether the thread is locked. - invitable: :class:`bool` + invitable: bool Whether non-moderators can add other non-moderators to this thread. This is always ``True`` for public threads. - auto_archive_duration: :class:`int` + auto_archive_duration: int The duration in minutes until the thread is automatically archived due to inactivity. Usually a value of 60, 1440, 4320 and 10080. - archive_timestamp: :class:`datetime.datetime` + archive_timestamp: datetime.datetime An aware timestamp of when the thread's archived status was last updated in UTC. - created_at: Optional[:class:`datetime.datetime`] + created_at: datetime.datetime | None An aware timestamp of when the thread was created. Only available for threads created after 2022-01-09. - flags: :class:`ChannelFlags` + flags: ChannelFlags Extra features of the thread. .. versionadded:: 2.0 - total_message_sent: :class:`int` + total_message_sent: int Number of messages ever sent in a thread. It's similar to message_count on message creation, but will not decrement the number when a message is deleted. @@ -142,20 +146,14 @@ class Thread(Messageable, Hashable): .. versionadded:: 2.3 """ - __slots__ = ( - "name", - "id", + __slots__: tuple[str, ...] = ( "guild", - "_type", - "_state", "_members", "_applied_tags", "owner_id", "parent_id", - "last_message_id", "message_count", "member_count", - "slowmode_delay", "me", "locked", "archived", @@ -163,86 +161,83 @@ class Thread(Messageable, Hashable): "auto_archive_duration", "archive_timestamp", "created_at", - "flags", "total_message_sent", ) - def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): - self._state: ConnectionState = state - self.guild = guild + @override + def __init__(self, *, id: int, guild: Guild, state: ConnectionState): + super().__init__(id, state) + self.guild: Guild = guild self._members: dict[int, ThreadMember] = {} - self._from_data(data) + + @classmethod + @override + async def _from_data( + cls, + *, + data: ThreadPayload, + state: ConnectionState, + guild: Guild, + ) -> Thread: + """Create thread instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self + + @override + async def _update(self, data: ThreadPayload) -> None: + """Update mutable attributes from API payload.""" + await super()._update(data) + + # Thread-specific attributes + self.parent_id: int = int(data.get("parent_id", self.parent_id if hasattr(self, "parent_id") else 0)) + self.owner_id: int | None = int(data["owner_id"]) if data.get("owner_id") is not None else None + self.message_count: int | None = data.get("message_count") + self.member_count: int | None = data.get("member_count") + self.total_message_sent: int | None = data.get("total_message_sent") + self._applied_tags: list[int] = [int(tag_id) for tag_id in data.get("applied_tags", [])] + + # Handle thread metadata + if "thread_metadata" in data: + metadata = data["thread_metadata"] + self.archived: bool = metadata["archived"] + self.auto_archive_duration: int = metadata["auto_archive_duration"] + self.archive_timestamp = parse_time(metadata["archive_timestamp"]) + self.locked: bool = metadata["locked"] + self.invitable: bool = metadata.get("invitable", True) + self.created_at = parse_time(metadata.get("create_timestamp")) + + # Handle thread member data + if "member" in data: + self.me: ThreadMember | None = ThreadMember(self, data["member"]) + elif not hasattr(self, "me"): + self.me = None async def _get_channel(self): return self + @override def __repr__(self) -> str: return ( f"" ) - def __str__(self) -> str: - return self.name - - def _from_data(self, data: ThreadPayload): - # This data will always exist - self.id = int(data["id"]) - self.parent_id = int(data["parent_id"]) - self.name = data["name"] - self._type = try_enum(ChannelType, data["type"]) - - # This data may be missing depending on how this object is being created - self.owner_id = int(data.get("owner_id")) if data.get("owner_id", None) is not None else None - self.last_message_id = get_as_snowflake(data, "last_message_id") - self.slowmode_delay = data.get("rate_limit_per_user", 0) - self.message_count = data.get("message_count", None) - self.member_count = data.get("member_count", None) - self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) - self.total_message_sent = data.get("total_message_sent", None) - self._applied_tags: list[int] = [int(tag_id) for tag_id in data.get("applied_tags", [])] - - # Here, we try to fill in potentially missing data - if thread := self.guild.get_thread(self.id) and data.pop("_invoke_flag", False): - self.owner_id = thread.owner_id if self.owner_id is None else self.owner_id - self.last_message_id = thread.last_message_id if self.last_message_id is None else self.last_message_id - self.message_count = thread.message_count if self.message_count is None else self.message_count - self.total_message_sent = ( - thread.total_message_sent if self.total_message_sent is None else self.total_message_sent - ) - self.member_count = thread.member_count if self.member_count is None else self.member_count - - self._unroll_metadata(data["thread_metadata"]) - - try: - member = data["member"] - except KeyError: - self.me = None - else: - self.me = ThreadMember(self, member) - - def _unroll_metadata(self, data: ThreadMetadata): - self.archived = data["archived"] - self.auto_archive_duration = data["auto_archive_duration"] - self.archive_timestamp = parse_time(data["archive_timestamp"]) - self.locked = data["locked"] - self.invitable = data.get("invitable", True) - self.created_at = parse_time(data.get("create_timestamp", None)) - - async def _update(self, data): - try: - self.name = data["name"] - except KeyError: - pass - - self._applied_tags: list[int] = [int(tag_id) for tag_id in data.get("applied_tags", [])] - self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) - self.slowmode_delay = data.get("rate_limit_per_user", 0) + @property + def topic(self) -> None: + """Threads don't have topics. Always returns None.""" + return None - try: - self._unroll_metadata(data["thread_metadata"]) - except KeyError: - pass + @property + @override + def nsfw(self) -> bool: + """Whether the thread is NSFW. Inherited from parent channel.""" + parent = self.parent + return parent.nsfw if parent else False @property def type(self) -> ChannelType: @@ -654,7 +649,7 @@ async def edit( data = await self._state.http.edit_channel(self.id, **payload, reason=reason) # The data payload will always be a Thread payload - return Thread(data=data, state=self._state, guild=self.guild) # type: ignore + return await Thread._from_data(data=data, state=self._state, guild=self.guild) # type: ignore async def archive(self, locked: bool | utils.Undefined = MISSING) -> Thread: """|coro| diff --git a/discord/channel/voice.py b/discord/channel/voice.py new file mode 100644 index 0000000000..2841ac5290 --- /dev/null +++ b/discord/channel/voice.py @@ -0,0 +1,328 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +from typing_extensions import Self, override + +from ..abc import Connectable +from ..enums import ChannelType, InviteTarget, VideoQualityMode, VoiceRegion, try_enum +from ..utils import MISSING, Undefined +from .base import GuildMessageableChannel, GuildTopLevelChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..enums import EmbeddedActivity + from ..invite import Invite + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..soundboard import PartialSoundboardSound + from ..types.channel import VoiceChannel as VoiceChannelPayload + from .category import CategoryChannel + +__all__ = ("VoiceChannel",) + + +class VoiceChannel( + GuildTopLevelChannel["VoiceChannelPayload"], + GuildMessageableChannel, + Connectable, +): + """Represents a Discord guild voice channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a voice channel. + A value of ``0`` indicates no limit. + rtc_region: :class:`VoiceRegion` | None + The region for the voice channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the voice channel's participants. + last_message_id: :class:`int` | None + The ID of the last message sent to this channel. It may not always point to an existing or valid message. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + status: :class:`str` | None + The channel's status, if set. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "bitrate", + "user_limit", + "rtc_region", + "video_quality_mode", + "status", + ) + + @override + async def _update(self, data: VoiceChannelPayload) -> None: + await super()._update(data) + self.bitrate: int = data.get("bitrate", 64000) + self.user_limit: int = data.get("user_limit", 0) + rtc = data.get("rtc_region") + self.rtc_region: VoiceRegion | None = try_enum(VoiceRegion, rtc) if rtc is not None else None + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) + self.status: str | None = data.get("status") + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.voice.value + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("status", self.status), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + async def edit( + self, + *, + name: str | Undefined = MISSING, + bitrate: int | Undefined = MISSING, + user_limit: int | Undefined = MISSING, + position: int | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + rtc_region: VoiceRegion | None | Undefined = MISSING, + video_quality_mode: VideoQualityMode | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + reason: str | None = None, + ) -> Self: + """|coro| + + Edits the voice channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel's name. + bitrate: :class:`int` + The new channel's bitrate. + user_limit: :class:`int` + The new channel's user limit. + position: :class:`int` + The new channel's position. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + rtc_region: :class:`VoiceRegion` | None + The new region for the voice channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the voice channel's participants. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.VoiceChannel` + The newly edited voice channel. If the edit was only positional then ``None`` is returned. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + options = {} + if name is not MISSING: + options["name"] = name + if bitrate is not MISSING: + options["bitrate"] = bitrate + if user_limit is not MISSING: + options["user_limit"] = user_limit + if position is not MISSING: + options["position"] = position + if sync_permissions is not MISSING: + options["sync_permissions"] = sync_permissions + if category is not MISSING: + options["category"] = category + if overwrites is not MISSING: + options["overwrites"] = overwrites + if rtc_region is not MISSING: + options["rtc_region"] = rtc_region + if video_quality_mode is not MISSING: + options["video_quality_mode"] = video_quality_mode + if slowmode_delay is not MISSING: + options["slowmode_delay"] = slowmode_delay + if nsfw is not MISSING: + options["nsfw"] = nsfw + + payload = await self._edit(options, reason=reason) + if payload is not None: + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + async def create_activity_invite(self, activity: EmbeddedActivity | int, **kwargs) -> Invite: + """|coro| + + A shortcut method that creates an instant activity invite. + + You must have :attr:`~discord.Permissions.start_embedded_activities` permission to do this. + + Parameters + ---------- + activity: :class:`EmbeddedActivity` | :class:`int` + The embedded activity to create an invite for. Can be an :class:`EmbeddedActivity` enum member + or the application ID as an integer. + max_age: :class:`int` + How long the invite should last in seconds. If it's 0 then the invite doesn't expire. + max_uses: :class:`int` + How many uses the invite could be used for. If it's 0 then there are unlimited uses. + temporary: :class:`bool` + Denotes that the invite grants temporary membership. + unique: :class:`bool` + Indicates if a unique invite URL should be created. + reason: :class:`str` | None + The reason for creating this invite. Shows up on the audit log. + + Returns + ------- + :class:`~discord.Invite` + The invite that was created. + + Raises + ------ + HTTPException + Invite creation failed. + """ + from ..enums import EmbeddedActivity # noqa: PLC0415 + + if isinstance(activity, EmbeddedActivity): + activity = activity.value + + return await self.create_invite( + target_type=InviteTarget.embedded_application, + target_application_id=activity, + **kwargs, + ) + + async def set_status(self, status: str | None, *, reason: str | None = None) -> None: + """|coro| + + Sets the voice channel status. + + You must have :attr:`~discord.Permissions.manage_channels` and + :attr:`~discord.Permissions.connect` permissions to do this. + + Parameters + ---------- + status: :class:`str` | None + The new voice channel status. Set to ``None`` to remove the status. + reason: :class:`str` | None + The reason for setting the voice channel status. Shows up on the audit log. + + Raises + ------ + Forbidden + You do not have permissions to set the voice channel status. + HTTPException + Setting the voice channel status failed. + """ + await self._state.http.edit_voice_channel_status(self.id, status, reason=reason) + + async def send_soundboard_sound(self, sound: PartialSoundboardSound) -> None: + """|coro| + + Sends a soundboard sound to the voice channel. + + Parameters + ---------- + sound: :class:`PartialSoundboardSound` + The soundboard sound to send. + + Raises + ------ + Forbidden + You do not have proper permissions to send the soundboard sound. + HTTPException + Sending the soundboard sound failed. + """ + await self._state.http.send_soundboard_sound(self.id, sound) diff --git a/discord/client.py b/discord/client.py index 71bd01751e..13a6e42803 100644 --- a/discord/client.py +++ b/discord/client.py @@ -30,6 +30,7 @@ import signal import sys import traceback +from collections.abc import Awaitable from types import TracebackType from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Generator, Sequence, TypeVar @@ -40,6 +41,7 @@ from . import utils from .activity import ActivityTypes, BaseActivity, create_activity from .app.cache import Cache, MemoryCache +from .app.event_emitter import Event from .app.state import ConnectionState from .appinfo import AppInfo, PartialAppInfo from .application_role_connection import ApplicationRoleConnectionMetadata @@ -50,6 +52,7 @@ from .errors import * from .flags import ApplicationFlags, Intents from .gateway import * +from .gears import Gear from .guild import Guild from .http import HTTPClient from .invite import Invite @@ -61,13 +64,14 @@ from .stage_instance import StageInstance from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .template import Template -from .threads import Thread +from .channel.thread import Thread from .ui.view import View from .user import ClientUser, User from .utils import MISSING from .utils.private import ( SequenceProxy, bytes_to_base64_data, + copy_doc, resolve_invite, resolve_template, ) @@ -76,8 +80,8 @@ from .widget import Widget if TYPE_CHECKING: - from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime - from .channel import DMChannel + from .abc import PrivateChannel, Snowflake, SnowflakeTime + from .channel import DMChannel, GuildChannel from .interactions import Interaction from .member import Member from .message import Message @@ -243,7 +247,6 @@ def __init__( # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop - self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = {} self.shard_id: int | None = options.get("shard_id") self.shard_count: int | None = options.get("shard_count") @@ -264,7 +267,14 @@ def __init__( self._hooks: dict[str, Callable] = {"before_identify": self._call_before_identify_hook} self._enable_debug_events: bool = options.pop("enable_debug_events", False) - self._connection: ConnectionState = self._get_state(**options) + self._connection: ConnectionState = ConnectionState( + handlers=self._handlers, + hooks=self._hooks, + http=self.http, + loop=self.loop, + cache=MemoryCache(), + **options, + ) self._connection.shard_count = self.shard_count self._closed: bool = False self._ready: asyncio.Event = asyncio.Event() @@ -272,6 +282,10 @@ def __init__( self._connection._get_client = lambda: self self._event_handlers: dict[str, list[Coro]] = {} + self._main_gear: Gear = Gear() + + self._connection.emitter.add_receiver(self._handle_event) + if VoiceClient.warn_nacl: VoiceClient.warn_nacl = False _log.warning("PyNaCl is not installed, voice will NOT be supported") @@ -279,6 +293,9 @@ def __init__( # Used to hard-reference tasks so they don't get garbage collected (discarded with done_callbacks) self._tasks = set() + async def _handle_event(self, event: Event) -> None: + await asyncio.gather(*self._main_gear._handle_event(event)) + async def __aenter__(self) -> Client: loop = asyncio.get_running_loop() self.loop = loop @@ -298,22 +315,44 @@ async def __aexit__( if not self.is_closed(): await self.close() + # Gear methods + + @copy_doc(Gear.attach_gear) + def attach_gear(self, gear: Gear) -> None: + return self._main_gear.attach_gear(gear) + + @copy_doc(Gear.detach_gear) + def detach_gear(self, gear: Gear) -> None: + return self._main_gear.detach_gear(gear) + + @copy_doc(Gear.add_listener) + def add_listener( + self, + callback: Callable[[Event], Awaitable[None]], + *, + event: type[Event] | Undefined = MISSING, + is_instance_function: bool = False, + once: bool = False, + ) -> None: + return self._main_gear.add_listener(callback, event=event, is_instance_function=is_instance_function, once=once) + + @copy_doc(Gear.remove_listener) + def remove_listener( + self, callback: Callable[[Event], Awaitable[None]], event: type[Event] | Undefined = MISSING, is_instance_function: bool = False + ) -> None: + return self._main_gear.remove_listener(callback, event=event, is_instance_function=is_instance_function) + + @copy_doc(Gear.listen) + def listen( + self, event: type[Event] | Undefined = MISSING, once: bool = False + ) -> Callable[[Callable[[Event], Awaitable[None]]], Callable[[Event], Awaitable[None]]]: + return self._main_gear.listen(event=event, once=once) + # internals def _get_websocket(self, guild_id: int | None = None, *, shard_id: int | None = None) -> DiscordWebSocket: return self.ws - def _get_state(self, **options: Any) -> ConnectionState: - return ConnectionState( - dispatch=self.dispatch, - handlers=self._handlers, - hooks=self._hooks, - http=self.http, - loop=self.loop, - cache=MemoryCache(), - **options, - ) - def _handle_ready(self) -> None: self._ready.set() @@ -465,71 +504,6 @@ def _schedule_event( task.add_done_callback(self._tasks.discard) return task - def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - _log.debug("Dispatching event %s", event) - method = f"on_{event}" - - listeners = self._listeners.get(event) - if listeners: - removed = [] - for i, (future, condition) in enumerate(listeners): - if future.cancelled(): - removed.append(i) - continue - - try: - result = condition(*args) - except Exception as exc: - future.set_exception(exc) - removed.append(i) - else: - if result: - if len(args) == 0: - future.set_result(None) - elif len(args) == 1: - future.set_result(args[0]) - else: - future.set_result(args) - removed.append(i) - - if len(removed) == len(listeners): - self._listeners.pop(event) - else: - for idx in reversed(removed): - del listeners[idx] - - # Schedule the main handler registered with @event - try: - coro = getattr(self, method) - except AttributeError: - pass - else: - self._schedule_event(coro, method, *args, **kwargs) - - # collect the once listeners as removing them from the list - # while iterating over it causes issues - once_listeners = [] - - # Schedule additional handlers registered with @listen - for coro in self._event_handlers.get(method, []): - self._schedule_event(coro, method, *args, **kwargs) - - try: - if coro._once: # added using @listen() - once_listeners.append(coro) - - except AttributeError: # added using @Cog.add_listener() - # https://github.com/Pycord-Development/pycord/pull/1989 - # Although methods are similar to functions, attributes can't be added to them. - # This means that we can't add the `_once` attribute in the `add_listener` method - # and can only be added using the `@listen` decorator. - - continue - - # remove the once listeners - for coro in once_listeners: - self._event_handlers[method].remove(coro) - async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: """|coro| @@ -691,7 +665,7 @@ async def connect(self, *, reconnect: bool = True) -> None: await self.ws.poll_event() except ReconnectWebSocket as e: _log.info("Got a request to %s the websocket.", e.op) - self.dispatch("disconnect") + # self.dispatch("disconnect") # TODO: dispatch event ws_params.update( sequence=self.ws.sequence, resume=e.resume, @@ -1151,8 +1125,6 @@ async def get_all_members(self) -> AsyncGenerator[Member]: for member in guild.members: yield member - # listeners/waiters - async def wait_until_ready(self) -> None: """|coro| @@ -1160,275 +1132,6 @@ async def wait_until_ready(self) -> None: """ await self._ready.wait() - def wait_for( - self, - event: str, - *, - check: Callable[..., bool] | None = None, - timeout: float | None = None, - ) -> Any: - """|coro| - - Waits for a WebSocket event to be dispatched. - - This could be used to wait for a user to reply to a message, - or to react to a message, or to edit a message in a self-contained - way. - - The ``timeout`` parameter is passed onto :func:`asyncio.wait_for`. By default, - it does not timeout. Note that this does propagate the - :exc:`asyncio.TimeoutError` for you in case of timeout and is provided for - ease of use. - - In case the event returns multiple arguments, a :class:`tuple` containing those - arguments is returned instead. Please check the - :ref:`documentation ` for a list of events and their - parameters. - - This function returns the **first event that meets the requirements**. - - Parameters - ---------- - event: :class:`str` - The event name, similar to the :ref:`event reference `, - but without the ``on_`` prefix, to wait for. - check: Optional[Callable[..., :class:`bool`]] - A predicate to check what to wait for. The arguments must meet the - parameters of the event being waited for. - timeout: Optional[:class:`float`] - The number of seconds to wait before timing out and raising - :exc:`asyncio.TimeoutError`. - - Returns - ------- - Any - Returns no arguments, a single argument, or a :class:`tuple` of multiple - arguments that mirrors the parameters passed in the - :ref:`event reference `. - - Raises - ------ - asyncio.TimeoutError - Raised if a timeout is provided and reached. - - Examples - -------- - - Waiting for a user reply: :: - - @client.event - async def on_message(message): - if message.content.startswith("$greet"): - channel = message.channel - await channel.send("Say hello!") - - def check(m): - return m.content == "hello" and m.channel == channel - - msg = await client.wait_for("message", check=check) - await channel.send(f"Hello {msg.author}!") - - Waiting for a thumbs up reaction from the message author: :: - - @client.event - async def on_message(message): - if message.content.startswith("$thumb"): - channel = message.channel - await channel.send("Send me that \N{THUMBS UP SIGN} reaction, mate") - - def check(reaction, user): - return user == message.author and str(reaction.emoji) == "\N{THUMBS UP SIGN}" - - try: - reaction, user = await client.wait_for("reaction_add", timeout=60.0, check=check) - except asyncio.TimeoutError: - await channel.send("\N{THUMBS DOWN SIGN}") - else: - await channel.send("\N{THUMBS UP SIGN}") - """ - - future = self.loop.create_future() - if check is None: - - def _check(*args): - return True - - check = _check - - ev = event.lower() - try: - listeners = self._listeners[ev] - except KeyError: - listeners = [] - self._listeners[ev] = listeners - - listeners.append((future, check)) - return asyncio.wait_for(future, timeout) - - # event registration - def add_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> None: - """The non decorator alternative to :meth:`.listen`. - - Parameters - ---------- - func: :ref:`coroutine ` - The function to call. - name: :class:`str` - The name of the event to listen for. Defaults to ``func.__name__``. - - Raises - ------ - TypeError - The ``func`` parameter is not a coroutine function. - ValueError - The ``name`` (event name) does not start with ``on_``. - - Example - ------- - - .. code-block:: python3 - - async def on_ready(): - pass - - - async def my_message(message): - pass - - - client.add_listener(on_ready) - client.add_listener(my_message, "on_message") - """ - name = func.__name__ if name is MISSING else name - - if not name.startswith("on_"): - raise ValueError("The 'name' parameter must start with 'on_'") - - if not asyncio.iscoroutinefunction(func): - raise TypeError("Listeners must be coroutines") - - if name in self._event_handlers: - self._event_handlers[name].append(func) - else: - self._event_handlers[name] = [func] - - _log.debug( - "%s has successfully been registered as a handler for event %s", - func.__name__, - name, - ) - - def remove_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> None: - """Removes a listener from the pool of listeners. - - Parameters - ---------- - func - The function that was used as a listener to remove. - name: :class:`str` - The name of the event we want to remove. Defaults to - ``func.__name__``. - """ - - name = func.__name__ if name is MISSING else name - - if name in self._event_handlers: - try: - self._event_handlers[name].remove(func) - except ValueError: - pass - - def listen(self, name: str | utils.Undefined = MISSING, once: bool = False) -> Callable[[Coro], Coro]: - """A decorator that registers another function as an external - event listener. Basically this allows you to listen to multiple - events from different places e.g. such as :func:`.on_ready` - - The functions being listened to must be a :ref:`coroutine `. - - Raises - ------ - TypeError - The function being listened to is not a coroutine. - ValueError - The ``name`` (event name) does not start with ``on_``. - - Example - ------- - - .. code-block:: python3 - - @client.listen() - async def on_message(message): - print("one") - - - # in some other file... - - - @client.listen("on_message") - async def my_message(message): - print("two") - - - # listen to the first event only - @client.listen("on_ready", once=True) - async def on_ready(): - print("ready!") - - Would print one and two in an unspecified order. - """ - - def decorator(func: Coro) -> Coro: - # Special case, where default should be overwritten - if name == "on_application_command_error": - return self.event(func) - - func._once = once - self.add_listener(func, name) - return func - - if asyncio.iscoroutinefunction(name): - coro = name - name = coro.__name__ - return decorator(coro) - - return decorator - - def event(self, coro: Coro) -> Coro: - """A decorator that registers an event to listen to. - - You can find more info about the events on the :ref:`documentation below `. - - The events must be a :ref:`coroutine `, if not, :exc:`TypeError` is raised. - - .. note:: - - This replaces any default handlers. - Developers are encouraged to use :py:meth:`~discord.Client.listen` for adding additional handlers - instead of :py:meth:`~discord.Client.event` unless default method replacement is intended. - - Raises - ------ - TypeError - The coroutine passed is not actually a coroutine. - - Example - ------- - - .. code-block:: python3 - - @client.event - async def on_ready(): - print("Ready!") - """ - - if not asyncio.iscoroutinefunction(coro): - raise TypeError("event registered must be a coroutine function") - - setattr(self, coro.__name__, coro) - _log.debug("%s has successfully been registered as an event", coro.__name__) - return coro - async def change_presence( self, *, diff --git a/discord/commands/core.py b/discord/commands/core.py index 76a90e6d9b..6e249a8c5f 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -66,7 +66,7 @@ from ..message import Attachment, Message from ..object import Object from ..role import Role -from ..threads import Thread +from ..channel.thread import Thread from ..user import User from ..utils import MISSING, find, utcnow from ..utils.private import async_all, maybe_awaitable, warn_deprecated @@ -464,7 +464,9 @@ async def dispatch_error(self, ctx: ApplicationContext, error: Exception) -> Non wrapped = wrap_callback(local) await wrapped(ctx, error) finally: - ctx.bot.dispatch("application_command_error", ctx, error) + ctx.bot.dispatch( + "application_command_error", ctx, error + ) # TODO: Remove this when migrating away from ApplicationContext def _get_signature_parameters(self): return OrderedDict(inspect.signature(self.callback).parameters) diff --git a/discord/commands/options.py b/discord/commands/options.py index a055022830..8137083553 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -33,20 +33,26 @@ from typing import ( TYPE_CHECKING, Any, + Generic, Literal, Optional, + Sequence, Type, - TypeVar, Union, get_args, + overload, ) +from typing_extensions import TypeVar + +from discord.channel.base import BaseChannel, GuildChannel + if sys.version_info >= (3, 12): from typing import TypeAliasType else: from typing_extensions import TypeAliasType -from ..abc import GuildChannel, Mentionable +from ..abc import Mentionable from ..channel import ( CategoryChannel, DMChannel, @@ -71,36 +77,26 @@ from ..user import User InputType = ( - Type[str] - | Type[bool] - | Type[int] - | Type[float] - | Type[GuildChannel] - | Type[Thread] - | Type[Member] - | Type[User] - | Type[Attachment] - | Type[Role] - | Type[Mentionable] + type[ + str | bool | int | float | GuildChannel | Thread | Member | User | Attachment | Role | Mentionable + # | Converter + ] | SlashCommandOptionType - | Converter - | Type[Converter] - | Type[Enum] - | Type[DiscordEnum] + # | Converter ) AutocompleteReturnType = Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] - T = TypeVar("T", bound=AutocompleteReturnType) - MaybeAwaitable = T | Awaitable[T] + AR_T = TypeVar("AR_T =", bound=AutocompleteReturnType) + MaybeAwaitable = AR_T | Awaitable[AR_T] AutocompleteFunction = ( Callable[[AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]] | Callable[[Cog, AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]] | Callable[ - [AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] + [AutocompleteContext, Any], MaybeAwaitable[AutocompleteReturnType], ] | Callable[ - [Cog, AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] + [Cog, AutocompleteContext, Any], MaybeAwaitable[AutocompleteReturnType], ] ) @@ -110,7 +106,6 @@ "ThreadOption", "Option", "OptionChoice", - "option", ) CHANNEL_TYPE_MAP = { @@ -147,7 +142,10 @@ def __init__(self, thread_type: Literal["public", "private", "news"]): self._type = type_map[thread_type] -class Option: +T = TypeVar("T", bound="str | int | float", default="str") + + +class Option(Generic[T]): """Represents a selectable option for a slash command. Attributes @@ -211,77 +209,78 @@ async def hello( .. versionadded:: 2.0 """ - input_type: SlashCommandOptionType - converter: Converter | type[Converter] | None = None - - def __init__(self, input_type: InputType = str, /, description: str | None = None, **kwargs) -> None: - self.name: str | None = kwargs.pop("name", None) - if self.name is not None: - self.name = str(self.name) - self._parameter_name = self.name # default - input_type = self._parse_type_alias(input_type) - input_type = self._strip_none_type(input_type) - self._raw_type: InputType | tuple = input_type - - enum_choices = [] - input_type_is_class = isinstance(input_type, type) - if input_type_is_class and issubclass(input_type, (Enum, DiscordEnum)): - if description is None and input_type.__doc__ is not None: - description = inspect.cleandoc(input_type.__doc__) - if description and len(description) > 100: - description = description[:97] + "..." - _log.warning( - "Option %s's description was truncated due to Enum %s's docstring exceeding 100 characters.", - self.name, - input_type, - ) - enum_choices = [OptionChoice(e.name, e.value) for e in input_type] - value_class = enum_choices[0].value.__class__ - if value_class in SlashCommandOptionType.__members__ and all( - isinstance(elem.value, value_class) for elem in enum_choices - ): - input_type = SlashCommandOptionType.from_datatype(enum_choices[0].value.__class__) - else: - enum_choices = [OptionChoice(e.name, str(e.value)) for e in input_type] - input_type = SlashCommandOptionType.string - - self.description = description or "No description provided" - self.channel_types: list[ChannelType] = kwargs.pop("channel_types", []) + @overload + def __init__( + self, + name: str, + input_type: type[T] = str, + *, + choices: OptionChoice[T], + description: str | None = None, + channel_types: None = None, + ) -> None: ... + + @overload + def __init__( + self, + name: str, + input_type: Literal[SlashCommandOptionType.channel] = SlashCommandOptionType.channel, + *, + choices: None = None, + description: str | None = None, + channel_types: Sequence[ChannelType] | None = None, + ) -> None: ... + + def __init__( + self, + name: str, + input_type: InputType | type[T] = str, + *, + description: str | None = None, + choices: Sequence[OptionChoice[T]] | None = None, + channel_types: Sequence[ChannelType] | None = None, + ) -> None: + self.name: str = name + + self.description: str | None = description - if self.channel_types: - self.input_type = SlashCommandOptionType.channel - elif isinstance(input_type, SlashCommandOptionType): + self.choices: list[OptionChoice[T]] | None = choices + if self.choices is not None: + if len(self.choices) > 25: + raise InvalidArgument("Option choices cannot exceed 25 items.") + if not issubclass(input_type, (str, int, float)): + raise InvalidArgument("Option choices can only be used with str, int, or float input types.") + + self.channel_types: list[ChannelType] | None = list(channel_types) if channel_types is not None else None + + self.input_type: SlashCommandOptionType + + if isinstance(input_type, SlashCommandOptionType): self.input_type = input_type - else: - from ..ext.commands import Converter # noqa: PLC0415 - - if isinstance(input_type, tuple) and any(issubclass(op, ApplicationContext) for op in input_type): - input_type = next(op for op in input_type if issubclass(op, ApplicationContext)) - - if isinstance(input_type, Converter) or input_type_is_class and issubclass(input_type, Converter): - self.converter = input_type - self._raw_type = str - self.input_type = SlashCommandOptionType.string - else: - try: - self.input_type = SlashCommandOptionType.from_datatype(input_type) - except TypeError as exc: - from ..ext.commands.converter import CONVERTER_MAPPING # noqa: PLC0415 - - if input_type not in CONVERTER_MAPPING: - raise exc - self.converter = CONVERTER_MAPPING[input_type] - self._raw_type = str - self.input_type = SlashCommandOptionType.string - else: - if self.input_type == SlashCommandOptionType.channel: - if not isinstance(self._raw_type, tuple): - if hasattr(input_type, "__args__"): - self._raw_type = input_type.__args__ # type: ignore # Union.__args__ - else: - self._raw_type = (input_type,) - if not self.channel_types: - self.channel_types = [CHANNEL_TYPE_MAP[t] for t in self._raw_type if t is not GuildChannel] + elif issubclass(input_type, str): + self.input_type = SlashCommandOptionType.string + elif issubclass(input_type, bool): + self.input_type = SlashCommandOptionType.boolean + elif issubclass(input_type, int): + self.input_type = SlashCommandOptionType.integer + elif issubclass(input_type, float): + self.input_type = SlashCommandOptionType.number + elif issubclass(input_type, Attachment): + self.input_type = SlashCommandOptionType.attachment + elif issubclass(input_type, User): + self.input_type = SlashCommandOptionType.user + elif issubclass(input_type, Mentionable): + self.input_type = SlashCommandOptionType.mentionable + elif issubclass(input_type, Role): + self.input_type = SlashCommandOptionType.role + elif issubclass(input_type, BaseChannel): + self.input_type = SlashCommandOptionType.channel + + if self.channel_types is not None: + self.input_type = SlashCommandOptionType.channel + if len(self.channel_types) == 0: + raise InvalidArgument("channel_types must contain at least one ChannelType.") + self.required: bool = kwargs.pop("required", True) if "default" not in kwargs else False self.default = kwargs.pop("default", None) @@ -456,7 +455,7 @@ def autocomplete(self, value: AutocompleteFunction | None) -> None: ) -class OptionChoice: +class OptionChoice(Generic[T]): """ Represents a name:value pairing for a selected :class:`.Option`. @@ -466,9 +465,9 @@ class OptionChoice: ---------- name: :class:`str` The name of the choice. Shown in the UI when selecting an option. - value: Optional[Union[:class:`str`, :class:`int`, :class:`float`]] + value: :class:`str` | :class:`int` | :class:`float` The value of the choice. If not provided, will use the value of ``name``. - name_localizations: Dict[:class:`str`, :class:`str`] + name_localizations: dict[:class:`str`, :class:`str`] The name localizations for this choice. The values of this should be ``"locale": "name"``. See `here `_ for a list of valid locales. """ @@ -476,37 +475,16 @@ class OptionChoice: def __init__( self, name: str, - value: str | int | float | None = None, - name_localizations: dict[str, str] | Undefined = MISSING, + value: T | None = None, + name_localizations: dict[str, str] | None = None, ): - self.name = str(name) - self.value = value if value is not None else name - self.name_localizations = name_localizations + self.name: str = str(name) + self.value: T = value if value is not None else name # pyright: ignore [reportAttributeAccessIssue] + self.name_localizations: dict[str, str] | None = name_localizations - def to_dict(self) -> dict[str, str | int | float]: - as_dict = {"name": self.name, "value": self.value} - if self.name_localizations is not MISSING: + def to_dict(self) -> dict[str, Any]: + as_dict: dict[str, Any] = {"name": self.name, "value": self.value} + if self.name_localizations is not None: as_dict["name_localizations"] = self.name_localizations return as_dict - - -def option(name, input_type=None, **kwargs): - """A decorator that can be used instead of typehinting :class:`.Option`. - - .. versionadded:: 2.0 - - Attributes - ---------- - parameter_name: :class:`str` - The name of the target function parameter this option is mapped to. - This allows you to have a separate UI ``name`` and parameter name. - """ - - def decorator(func): - resolved_name = kwargs.pop("parameter_name", None) or name - itype = kwargs.pop("type", None) or input_type or func.__annotations__.get(resolved_name, str) - func.__annotations__[resolved_name] = Option(itype, name=name, **kwargs) - return func - - return decorator diff --git a/discord/enums.py b/discord/enums.py index be7efc48d8..4e3f678b5a 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -737,69 +737,6 @@ class SlashCommandOptionType(Enum): number = 10 attachment = 11 - @classmethod - def from_datatype(cls, datatype): - if isinstance(datatype, tuple): # typing.Union has been used - datatypes = [cls.from_datatype(op) for op in datatype] - if all(x == cls.channel for x in datatypes): - return cls.channel - elif set(datatypes) <= {cls.role, cls.user}: - return cls.mentionable - else: - raise TypeError("Invalid usage of typing.Union") - - py_3_10_union_type = hasattr(types, "UnionType") and isinstance(datatype, types.UnionType) - - if py_3_10_union_type or getattr(datatype, "__origin__", None) is Union: - # Python 3.10+ "|" operator or typing.Union has been used. The __args__ attribute is a tuple of the types. - # Type checking fails for this case, so ignore it. - return cls.from_datatype(datatype.__args__) # type: ignore - - if isinstance(datatype, str): - datatype_name = datatype - else: - datatype_name = datatype.__name__ - if datatype_name in ["Member", "User"]: - return cls.user - if datatype_name in [ - "GuildChannel", - "TextChannel", - "VoiceChannel", - "StageChannel", - "CategoryChannel", - "ThreadOption", - "Thread", - "ForumChannel", - "MediaChannel", - "DMChannel", - ]: - return cls.channel - if datatype_name == "Role": - return cls.role - if datatype_name == "Attachment": - return cls.attachment - if datatype_name == "Mentionable": - return cls.mentionable - - if isinstance(datatype, str) or issubclass(datatype, str): - return cls.string - if issubclass(datatype, bool): - return cls.boolean - if issubclass(datatype, int): - return cls.integer - if issubclass(datatype, float): - return cls.number - - from .commands.context import ApplicationContext # noqa: PLC0415 - from .ext.bridge import BridgeContext # noqa: PLC0415 - - if not issubclass( - datatype, (ApplicationContext, BridgeContext) - ): # TODO: prevent ctx being passed here in cog commands - raise TypeError( - f"Invalid class {datatype} used as an input type for an Option" - ) # TODO: Improve the error message - class EmbeddedActivity(Enum): """Embedded activity""" diff --git a/discord/errors.py b/discord/errors.py index d694180bc6..6be0198281 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -40,6 +40,7 @@ from .interactions import Interaction __all__ = ( + "AnnotationMismatch", "DiscordException", "ClientException", "NoMoreItems", @@ -97,6 +98,10 @@ class ValidationError(DiscordException): """An Exception that is raised when there is a Validation Error.""" +class AnnotationMismatch(SyntaxError, ValidationError): + """An Exception that is raised when an annotation does not match the type of the value.""" + + def _flatten_error_dict(d: dict[str, Any], key: str = "") -> dict[str, str]: items: list[tuple[str, str]] = [] for k, v in d.items(): diff --git a/discord/events/__init__.py b/discord/events/__init__.py new file mode 100644 index 0000000000..f0d924360a --- /dev/null +++ b/discord/events/__init__.py @@ -0,0 +1,320 @@ +from ..app.event_emitter import Event +from .audit_log import GuildAuditLogEntryCreate +from .automod import ( + AutoModActionExecution, + AutoModRuleCreate, + AutoModRuleDelete, + AutoModRuleUpdate, +) +from .channel import ( + ChannelCreate, + ChannelDelete, + ChannelPinsUpdate, + ChannelUpdate, + GuildChannelUpdate, + PrivateChannelUpdate, +) +from .entitlement import EntitlementCreate, EntitlementDelete, EntitlementUpdate +from .gateway import ( + ApplicationCommandPermissionsUpdate, + PresenceUpdate, + Ready, + Resumed, + UserUpdate, + _CacheAppEmojis, +) +from .gateway import GuildAvailable as GatewayGuildAvailable +from .gateway import GuildCreate as GatewayGuildCreate +from .gateway import GuildJoin as GatewayGuildJoin +from .guild import ( + GuildAvailable, + GuildBanAdd, + GuildBanRemove, + GuildCreate, + GuildDelete, + GuildEmojisUpdate, + GuildJoin, + GuildMemberJoin, + GuildMemberRemove, + GuildMembersChunk, + GuildMemberUpdate, + GuildRoleCreate, + GuildRoleDelete, + GuildRoleUpdate, + GuildStickersUpdate, + GuildUnavailable, + GuildUpdate, +) +from .integration import ( + GuildIntegrationsUpdate, + IntegrationCreate, + IntegrationDelete, + IntegrationUpdate, +) +from .interaction import InteractionCreate +from .invite import InviteCreate, InviteDelete +from .message import ( + MessageCreate, + MessageDelete, + MessageDeleteBulk, + MessageUpdate, + PollVoteAdd, + PollVoteRemove, + ReactionAdd, + ReactionClear, + ReactionRemove, + ReactionRemoveEmoji, +) +from .scheduled_event import ( + GuildScheduledEventCreate, + GuildScheduledEventDelete, + GuildScheduledEventUpdate, + GuildScheduledEventUserAdd, + GuildScheduledEventUserRemove, +) +from .soundboard import ( + GuildSoundboardSoundCreate, + GuildSoundboardSoundDelete, + GuildSoundboardSoundsUpdate, + GuildSoundboardSoundUpdate, + SoundboardSounds, +) +from .stage_instance import StageInstanceCreate, StageInstanceDelete, StageInstanceUpdate +from .subscription import SubscriptionCreate, SubscriptionDelete, SubscriptionUpdate +from .thread import ( + BulkThreadMemberUpdate, + ThreadCreate, + ThreadDelete, + ThreadJoin, + ThreadListSync, + ThreadMemberJoin, + ThreadMemberRemove, + ThreadMemberUpdate, + ThreadRemove, + ThreadUpdate, +) +from .typing import TypingStart +from .voice import VoiceChannelEffectSend, VoiceChannelStatusUpdate, VoiceServerUpdate, VoiceStateUpdate +from .webhook import WebhooksUpdate + +__all__ = ( + "ALL_EVENTS", + "Event", + # Audit Log + "GuildAuditLogEntryCreate", + # AutoMod + "AutoModActionExecution", + "AutoModRuleCreate", + "AutoModRuleDelete", + "AutoModRuleUpdate", + # Channel + "ChannelCreate", + "ChannelDelete", + "ChannelPinsUpdate", + "ChannelUpdate", + "GuildChannelUpdate", + "PrivateChannelUpdate", + # Entitlement + "EntitlementCreate", + "EntitlementDelete", + "EntitlementUpdate", + # Gateway + "ApplicationCommandPermissionsUpdate", + "GatewayGuildAvailable", + "GatewayGuildCreate", + "GatewayGuildJoin", + "PresenceUpdate", + "Ready", + "Resumed", + "UserUpdate", + "_CacheAppEmojis", + # Guild + "GuildAvailable", + "GuildBanAdd", + "GuildBanRemove", + "GuildCreate", + "GuildDelete", + "GuildEmojisUpdate", + "GuildJoin", + "GuildMemberJoin", + "GuildMemberRemove", + "GuildMembersChunk", + "GuildMemberUpdate", + "GuildRoleCreate", + "GuildRoleDelete", + "GuildRoleUpdate", + "GuildStickersUpdate", + "GuildUnavailable", + "GuildUpdate", + # Integration + "GuildIntegrationsUpdate", + "IntegrationCreate", + "IntegrationDelete", + "IntegrationUpdate", + # Interaction + "InteractionCreate", + # Invite + "InviteCreate", + "InviteDelete", + # Message + "MessageCreate", + "MessageDelete", + "MessageDeleteBulk", + "MessageUpdate", + "PollVoteAdd", + "PollVoteRemove", + "ReactionAdd", + "ReactionClear", + "ReactionRemove", + "ReactionRemoveEmoji", + # Scheduled Event + "GuildScheduledEventCreate", + "GuildScheduledEventDelete", + "GuildScheduledEventUpdate", + "GuildScheduledEventUserAdd", + "GuildScheduledEventUserRemove", + # Soundboard + "GuildSoundboardSoundCreate", + "GuildSoundboardSoundDelete", + "GuildSoundboardSoundsUpdate", + "GuildSoundboardSoundUpdate", + "SoundboardSounds", + # Stage Instance + "StageInstanceCreate", + "StageInstanceDelete", + "StageInstanceUpdate", + # Subscription + "SubscriptionCreate", + "SubscriptionDelete", + "SubscriptionUpdate", + # Thread + "BulkThreadMemberUpdate", + "ThreadCreate", + "ThreadDelete", + "ThreadJoin", + "ThreadListSync", + "ThreadMemberJoin", + "ThreadMemberRemove", + "ThreadMemberUpdate", + "ThreadRemove", + "ThreadUpdate", + # Typing + "TypingStart", + # Voice + "VoiceChannelEffectSend", + "VoiceChannelStatusUpdate", + "VoiceServerUpdate", + "VoiceStateUpdate", + # Webhook + "WebhooksUpdate", +) + +ALL_EVENTS: list[type[Event]] = [ + # Audit Log + GuildAuditLogEntryCreate, + # AutoMod + AutoModActionExecution, + AutoModRuleCreate, + AutoModRuleDelete, + AutoModRuleUpdate, + # Channel + ChannelCreate, + ChannelDelete, + ChannelPinsUpdate, + ChannelUpdate, + GuildChannelUpdate, + PrivateChannelUpdate, + # Entitlement + EntitlementCreate, + EntitlementDelete, + EntitlementUpdate, + # Gateway + ApplicationCommandPermissionsUpdate, + GatewayGuildAvailable, + GatewayGuildCreate, + GatewayGuildJoin, + PresenceUpdate, + Ready, + Resumed, + UserUpdate, + _CacheAppEmojis, + # Guild + GuildAvailable, + GuildBanAdd, + GuildBanRemove, + GuildCreate, + GuildDelete, + GuildEmojisUpdate, + GuildJoin, + GuildMemberJoin, + GuildMemberRemove, + GuildMembersChunk, + GuildMemberUpdate, + GuildRoleCreate, + GuildRoleDelete, + GuildRoleUpdate, + GuildStickersUpdate, + GuildUnavailable, + GuildUpdate, + # Integration + GuildIntegrationsUpdate, + IntegrationCreate, + IntegrationDelete, + IntegrationUpdate, + # Interaction + InteractionCreate, + # Invite + InviteCreate, + InviteDelete, + # Message + MessageCreate, + MessageDelete, + MessageDeleteBulk, + MessageUpdate, + PollVoteAdd, + PollVoteRemove, + ReactionAdd, + ReactionClear, + ReactionRemove, + ReactionRemoveEmoji, + # Scheduled Event + GuildScheduledEventCreate, + GuildScheduledEventDelete, + GuildScheduledEventUpdate, + GuildScheduledEventUserAdd, + GuildScheduledEventUserRemove, + # Soundboard + GuildSoundboardSoundCreate, + GuildSoundboardSoundDelete, + GuildSoundboardSoundsUpdate, + GuildSoundboardSoundUpdate, + SoundboardSounds, + # Stage Instance + StageInstanceCreate, + StageInstanceDelete, + StageInstanceUpdate, + # Subscription + SubscriptionCreate, + SubscriptionDelete, + SubscriptionUpdate, + # Thread + BulkThreadMemberUpdate, + ThreadCreate, + ThreadDelete, + ThreadJoin, + ThreadListSync, + ThreadMemberJoin, + ThreadMemberRemove, + ThreadMemberUpdate, + ThreadRemove, + ThreadUpdate, + # Typing + TypingStart, + # Voice + VoiceChannelEffectSend, + VoiceChannelStatusUpdate, + VoiceServerUpdate, + VoiceStateUpdate, + # Webhook + WebhooksUpdate, +] diff --git a/discord/events/audit_log.py b/discord/events/audit_log.py index 27b555543b..a5a0a77816 100644 --- a/discord/events/audit_log.py +++ b/discord/events/audit_log.py @@ -23,7 +23,9 @@ """ import logging -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.app.event_emitter import Event from discord.app.state import ConnectionState @@ -34,13 +36,27 @@ class GuildAuditLogEntryCreate(Event, AuditLogEntry): - __event_name__ = "GUILD_AUDIT_LOG_ENTRY_CREATE" + """Called when an audit log entry is created. + + The bot must have :attr:`~Permissions.view_audit_log` to receive this, and + :attr:`Intents.moderation` must be enabled. + + This event inherits from :class:`AuditLogEntry`. + + Attributes + ---------- + raw: :class:`RawAuditLogEntryEvent` + The raw event payload data. + """ + + __event_name__: str = "GUILD_AUDIT_LOG_ENTRY_CREATE" raw: RawAuditLogEntryEvent def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: diff --git a/discord/events/automod.py b/discord/events/automod.py index 248f0705e0..428d594ac7 100644 --- a/discord/events/automod.py +++ b/discord/events/automod.py @@ -22,7 +22,9 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.app.state import ConnectionState from discord.automod import AutoModRule @@ -32,12 +34,24 @@ class AutoModRuleCreate(Event): - __event_name__ = "AUTO_MODERATION_RULE_CREATE" + """Called when an auto moderation rule is created. + + The bot must have :attr:`~Permissions.manage_guild` to receive this, and + :attr:`Intents.auto_moderation_configuration` must be enabled. + + Attributes + ---------- + rule: :class:`AutoModRule` + The newly created rule. + """ + + __event_name__: str = "AUTO_MODERATION_RULE_CREATE" __slots__ = ("rule",) rule: AutoModRule @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.rule = AutoModRule(state=state, data=data) @@ -45,12 +59,24 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class AutoModRuleUpdate(Event): - __event_name__ = "AUTO_MODERATION_RULE_UPDATE" + """Called when an auto moderation rule is updated. + + The bot must have :attr:`~Permissions.manage_guild` to receive this, and + :attr:`Intents.auto_moderation_configuration` must be enabled. + + Attributes + ---------- + rule: :class:`AutoModRule` + The updated rule. + """ + + __event_name__: str = "AUTO_MODERATION_RULE_UPDATE" __slots__ = ("rule",) rule: AutoModRule @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.rule = AutoModRule(state=state, data=data) @@ -58,12 +84,24 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class AutoModRuleDelete(Event): - __event_name__ = "AUTO_MODERATION_RULE_DELETE" + """Called when an auto moderation rule is deleted. + + The bot must have :attr:`~Permissions.manage_guild` to receive this, and + :attr:`Intents.auto_moderation_configuration` must be enabled. + + Attributes + ---------- + rule: :class:`AutoModRule` + The deleted rule. + """ + + __event_name__: str = "AUTO_MODERATION_RULE_DELETE" __slots__ = ("rule",) rule: AutoModRule @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.rule = AutoModRule(state=state, data=data) @@ -71,11 +109,18 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class AutoModActionExecution(Event, AutoModActionExecutionEvent): - """Represents the `AUTO_MODERATION_ACTION_EXECUTION` event""" + """Called when an auto moderation action is executed. + + The bot must have :attr:`~Permissions.manage_guild` to receive this, and + :attr:`Intents.auto_moderation_execution` must be enabled. + + This event inherits from :class:`AutoModActionExecutionEvent`. + """ - __event_name__ = "AUTO_MODERATION_ACTION_EXECUTION" + __event_name__: str = "AUTO_MODERATION_ACTION_EXECUTION" @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() event = await AutoModActionExecutionEvent.from_data(state, data) diff --git a/discord/events/channel.py b/discord/events/channel.py index c88579c499..53b9e221d2 100644 --- a/discord/events/channel.py +++ b/discord/events/channel.py @@ -24,25 +24,71 @@ from copy import copy from datetime import datetime -from typing import Any, Self, TypeVar, cast +from functools import lru_cache +from typing import Any, TypeVar, cast -from discord.abc import GuildChannel, PrivateChannel +from typing_extensions import Self, override + +from discord.abc import PrivateChannel from discord.app.event_emitter import Event from discord.app.state import ConnectionState -from discord.channel import GroupChannel, _channel_factory +from discord.channel import GroupChannel, GuildChannel, _channel_factory from discord.enums import ChannelType, try_enum -from discord.threads import Thread +from discord.channel.thread import Thread from discord.utils.private import get_as_snowflake, parse_time T = TypeVar("T") +@lru_cache(maxsize=128) +def _create_event_channel_class(event_cls: type[Event], channel_cls: type[GuildChannel]) -> type[GuildChannel]: + """ + Dynamically create a class that inherits from both an Event and a Channel type. + + This allows the event to have the correct channel type while also being an Event. + Results are cached to avoid recreating the same class multiple times. + + Parameters + ---------- + event_cls: type[Event] + The event class (e.g., ChannelCreate) + channel_cls: type[GuildChannel] + The channel class (e.g., TextChannel, VoiceChannel) + + Returns + ------- + type[GuildChannel] + A new class that inherits from both the event and channel + """ + class EventChannel(event_cls, channel_cls): # type: ignore + __slots__ = () + + EventChannel.__name__ = f"{event_cls.__name__}_{channel_cls.__name__}" + EventChannel.__qualname__ = f"{event_cls.__qualname__}_{channel_cls.__name__}" + + return EventChannel # type: ignore + + class ChannelCreate(Event, GuildChannel): - __event_name__ = "CHANNEL_CREATE" + """Called when a guild channel is created. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from the actual channel type that was created + (e.g., :class:`TextChannel`, :class:`VoiceChannel`, :class:`ForumChannel`, etc.). + You can access all channel attributes directly on the event object. + + .. note:: + While this class shows :class:`GuildChannel` in the signature, at runtime + the event will be an instance of the specific channel type that was created. + """ + + __event_name__: str = "CHANNEL_CREATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | None: factory, _ = _channel_factory(data["type"]) if factory is None: @@ -50,53 +96,102 @@ async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | guild_id = get_as_snowflake(data, "guild_id") guild = await state._get_guild(guild_id) - if guild is not None: - # the factory can't be a DMChannel or GroupChannel here - channel = factory(guild=guild, state=self, data=data) # type: ignore # noqa: F821 # self is unbound - guild._add_channel(channel) # type: ignore - self = cls() - self.__dict__.update(channel.__dict__) - return self - else: + if guild is None: return + # the factory can't be a DMChannel or GroupChannel here + # Create the real channel object to be stored in the guild + channel = await factory._from_data(guild=guild, state=state, data=data) # type: ignore + guild._add_channel(channel) # type: ignore + + # Create a dynamic event class that combines this event type with the specific channel type + event_channel_cls = _create_event_channel_class(cls, factory) # type: ignore + # Instantiate it using the event's stub __init__ (no arguments) + self = event_channel_cls() # type: ignore + # Populate the event instance with data from the real channel + self._populate_from_slots(channel) + return self # type: ignore class PrivateChannelUpdate(Event, PrivateChannel): - __event_name__ = "PRIVATE_CHANNEL_UPDATE" + """Called whenever a private group DM is updated (e.g., changed name or topic). + + This requires :attr:`Intents.messages` to be enabled. + + This event inherits from :class:`GroupChannel`. + + Attributes + ---------- + old: :class:`GroupChannel` | None + The channel's old info before the update, or None if not in cache. + """ + + __event_name__: str = "PRIVATE_CHANNEL_UPDATE" old: PrivateChannel | None def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: tuple[PrivateChannel | None, PrivateChannel], _: ConnectionState) -> Self | None: + @override + async def __load__(cls, data: tuple[PrivateChannel | None, PrivateChannel], state: ConnectionState) -> Self | None: self = cls() self.old = data[0] - self.__dict__.update(data[1].__dict__) + self._populate_from_slots(data[1]) return self class GuildChannelUpdate(Event, PrivateChannel): - __event_name__ = "GUILD_CHANNEL_UPDATE" + """Called whenever a guild channel is updated (e.g., changed name, topic, permissions). + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from the actual channel type that was updated + (e.g., :class:`TextChannel`, :class:`VoiceChannel`, :class:`ForumChannel`, etc.). + + .. note:: + While this class shows :class:`GuildChannel` in the signature, at runtime + the event will be an instance of the specific channel type that was updated. + + Attributes + ---------- + old: :class:`TextChannel` | :class:`VoiceChannel` | :class:`CategoryChannel` | :class:`StageChannel` | :class:`ForumChannel` | None + The channel's old info before the update, or None if not in cache. + This will be the same type as the event itself. + """ + + __event_name__: str = "GUILD_CHANNEL_UPDATE" old: GuildChannel | None def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: tuple[GuildChannel | None, GuildChannel], _: ConnectionState) -> Self | None: - self = cls() + @override + async def __load__(cls, data: tuple[GuildChannel | None, GuildChannel], state: ConnectionState) -> Self | None: + channel = data[1] + # Create a dynamic event class that combines this event type with the specific channel type + event_channel_cls = _create_event_channel_class(cls, type(channel)) # type: ignore + # Instantiate it using the event's stub __init__ (no arguments) + self = event_channel_cls() # type: ignore + # Set the old channel and populate from the new channel self.old = data[0] - self.__dict__.update(data[1].__dict__) - return self + self._populate_from_slots(channel) + return self # type: ignore class ChannelUpdate(Event, GuildChannel): - __event_name__ = "CHANNEL_UPDATE" + """Internal event that dispatches to either :class:`PrivateChannelUpdate` or :class:`GuildChannelUpdate`. + + This event is not directly received by user code. It automatically routes to the appropriate + specific channel update event based on the channel type. + """ + + __event_name__: str = "CHANNEL_UPDATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | None: channel_type = try_enum(ChannelType, data.get("type")) channel_id = int(data["id"]) @@ -113,17 +208,31 @@ async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | if guild is not None: channel = guild.get_channel(channel_id) if channel is not None: - old_channel = copy.copy(channel) + old_channel = copy(channel) await channel._update(data) # type: ignore await state.emitter.emit("GUILD_CHANNEL_UPDATE", (old_channel, channel)) class ChannelDelete(Event, GuildChannel): - __event_name__ = "CHANNEL_DELETE" + """Called when a guild channel is deleted. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from the actual channel type that was deleted + (e.g., :class:`TextChannel`, :class:`VoiceChannel`, :class:`ForumChannel`, etc.). + You can access all channel attributes directly on the event object. + + .. note:: + While this class shows :class:`GuildChannel` in the signature, at runtime + the event will be an instance of the specific channel type that was deleted. + """ + + __event_name__: str = "CHANNEL_DELETE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | None: guild = await state._get_guild(get_as_snowflake(data, "guild_id")) channel_id = int(data["id"]) @@ -131,16 +240,32 @@ async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | channel = guild.get_channel(channel_id) if channel is not None: guild._remove_channel(channel) - self = cls() - self.__dict__.update(channel.__dict__) - return self + # Create a dynamic event class that combines this event type with the specific channel type + event_channel_cls = _create_event_channel_class(cls, type(channel)) # type: ignore + # Instantiate it using the event's stub __init__ (no arguments) + self = event_channel_cls() # type: ignore + # Populate the event instance with data from the real channel + self._populate_from_slots(channel) + return self # type: ignore class ChannelPinsUpdate(Event): + """Called whenever a message is pinned or unpinned from a channel. + + Attributes + ---------- + channel: :class:`abc.PrivateChannel` | :class:`TextChannel` | :class:`VoiceChannel` | :class:`StageChannel` | :class:`ForumChannel` | :class:`Thread` + The channel that had its pins updated. Can be any messageable channel type. + last_pin: :class:`datetime.datetime` | None + The latest message that was pinned as an aware datetime in UTC, or None if no pins exist. + """ + + __event_name__: str = "CHANNEL_PINS_UPDATE" channel: PrivateChannel | GuildChannel | Thread last_pin: datetime | None @classmethod + @override async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | None: channel_id = int(data["channel_id"]) try: diff --git a/discord/events/entitlement.py b/discord/events/entitlement.py index d26bb34857..4f58b8d9bd 100644 --- a/discord/events/entitlement.py +++ b/discord/events/entitlement.py @@ -22,7 +22,9 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.types.monetization import Entitlement as EntitlementPayload @@ -32,12 +34,18 @@ class EntitlementCreate(Event, Entitlement): - __event_name__ = "ENTITLEMENT_CREATE" + """Called when a user subscribes to an SKU. + + This event inherits from :class:`Entitlement`. + """ + + __event_name__: str = "ENTITLEMENT_CREATE" def __init__(self) -> None: pass @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.__dict__.update(Entitlement(data=data, state=state).__dict__) @@ -45,12 +53,24 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class EntitlementUpdate(Event, Entitlement): - __event_name__ = "ENTITLEMENT_UPDATE" + """Called when a user's subscription to an Entitlement is cancelled. + + .. note:: + Before October 1, 2024, this event was called when a user's subscription was renewed. + + Entitlements that no longer follow this behavior will have a type of :attr:`EntitlementType.purchase`. + Those that follow the old behavior will have a type of :attr:`EntitlementType.application_subscription`. + + This event inherits from :class:`Entitlement`. + """ + + __event_name__: str = "ENTITLEMENT_UPDATE" def __init__(self) -> None: pass @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.__dict__.update(Entitlement(data=data, state=state).__dict__) @@ -58,12 +78,24 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class EntitlementDelete(Event, Entitlement): - __event_name__ = "ENTITLEMENT_DELETE" + """Called when a user's entitlement is deleted. + + Entitlements are usually only deleted when Discord issues a refund for a subscription, + or manually removes an entitlement from a user. + + .. note:: + This is not called when a user's subscription is cancelled. + + This event inherits from :class:`Entitlement`. + """ + + __event_name__: str = "ENTITLEMENT_DELETE" def __init__(self) -> None: pass @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.__dict__.update(Entitlement(data=data, state=state).__dict__) diff --git a/discord/events/gateway.py b/discord/events/gateway.py index d653a7dc51..974de62c69 100644 --- a/discord/events/gateway.py +++ b/discord/events/gateway.py @@ -22,11 +22,13 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Self, cast +from typing import Any, cast + +from typing_extensions import Self, override from discord.emoji import Emoji from discord.flags import ApplicationFlags -from discord.guild import Guild, GuildChannel +from discord.guild import Guild from discord.member import Member from discord.role import Role from discord.sticker import Sticker @@ -47,7 +49,9 @@ class Resumed(Event): - __event_name__ = "RESUMED" + """Called when the client has resumed a session.""" + + __event_name__: str = "RESUMED" @classmethod async def __load__(cls, _data: Any, _state: ConnectionState) -> Self | None: @@ -55,18 +59,38 @@ async def __load__(cls, _data: Any, _state: ConnectionState) -> Self | None: class Ready(Event): - __event_name__ = "READY" + """Called when the client is done preparing the data received from Discord. + + Usually after login is successful and the client's guilds and cache are filled up. + + .. warning:: + This event is not guaranteed to be the first event called. + Likewise, this event is **not** guaranteed to only be called once. + This library implements reconnection logic and thus will end up calling + this event whenever a RESUME request fails. + + Attributes + ---------- + user: :class:`ClientUser` + An instance representing the connected application user. + application_id: :class:`int` + A snowflake of the application's ID. + application_flags: :class:`ApplicationFlags` + An instance representing the application flags. + guilds: list[:class:`Guild`] + A list of guilds received in this event. Note it may have incomplete data + as ``GUILD_CREATE`` fills up other parts of guild data. + """ + + __event_name__: str = "READY" user: ClientUser - """An instance of :class:`.user.ClientUser` representing the application""" application_id: int - """A snowflake of the application's id""" application_flags: ApplicationFlags - """An instance of :class:`.flags.ApplicationFlags` representing the application flags""" guilds: list[Guild] - """A list of guilds received in this event. Note it may have incomplete data as `GUILD_CREATE` fills up other parts of guild data.""" @classmethod + @override async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self: self = cls() self.user = ClientUser(state=state, data=data["user"]) @@ -98,9 +122,10 @@ async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self: class _CacheAppEmojis(Event): - __event_name__ = "CACHE_APP_EMOJIS" + __event_name__: str = "CACHE_APP_EMOJIS" @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: if state.cache_app_emojis and state.application_id: data = await state.http.get_all_application_emojis(state.application_id) @@ -109,9 +134,18 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildCreate(Event, Guild): - """An event which represents a guild becoming available via the gateway. Trickles down to the more distinct :class:`.GuildJoin` and :class:`.GuildAvailable` events.""" + """Internal event representing a guild becoming available via the gateway. + + This event trickles down to the more distinct :class:`GuildJoin` and :class:`GuildAvailable` events. + Users should typically listen to those events instead. - __event_name__ = "GUILD_CREATE" + Attributes + ---------- + guild: :class:`Guild` + The guild that became available. + """ + + __event_name__: str = "GUILD_CREATE" guild: Guild @@ -119,6 +153,7 @@ def __init__(self) -> None: pass @classmethod + @override async def __load__(cls, data: GuildPayload, state: ConnectionState) -> Self: self = cls() guild = await state._get_guild(int(data["id"])) @@ -126,7 +161,7 @@ async def __load__(cls, data: GuildPayload, state: ConnectionState) -> Self: guild = await Guild._from_data(data, state) await state._add_guild(guild) self.guild = guild - self.__dict__.update(self.guild.__dict__) + # self.__dict__.update(self.guild.__dict__) # TODO: Find another way to do this if state._guild_needs_chunking(guild): await state.chunk_guild(guild) if guild.unavailable: @@ -137,9 +172,17 @@ async def __load__(cls, data: GuildPayload, state: ConnectionState) -> Self: class GuildJoin(Event, Guild): - """An event which represents joining a new guild.""" + """Called when the client joins a new guild or when a guild is created. + + This requires :attr:`Intents.guilds` to be enabled. - __event_name__ = "GUILD_JOIN" + Attributes + ---------- + guild: :class:`Guild` + The guild that was joined. + """ + + __event_name__: str = "GUILD_JOIN" guild: Guild @@ -147,17 +190,27 @@ def __init__(self) -> None: pass @classmethod - async def __load__(cls, data: Guild, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: Guild, state: ConnectionState) -> Self: self = cls() self.guild = data - self.__dict__.update(self.guild.__dict__) + # self.__dict__.update(self.guild.__dict__) # TODO: Find another way to do this return self class GuildAvailable(Event, Guild): - """An event which represents a guild previously joined becoming available.""" + """Called when a guild becomes available. + + The guild must have existed in the client's cache. + This requires :attr:`Intents.guilds` to be enabled. + + Attributes + ---------- + guild: :class:`Guild` + The guild that became available. + """ - __event_name__ = "GUILD_AVAILABLE" + __event_name__: str = "GUILD_AVAILABLE" guild: Guild @@ -165,10 +218,11 @@ def __init__(self) -> None: pass @classmethod - async def __load__(cls, data: Guild, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: Guild, state: ConnectionState) -> Self: self = cls() self.guild = data - self.__dict__.update(self.guild.__dict__) + # self.__dict__.update(self.guild.__dict__) # TODO: Find another way to do this return self @@ -183,20 +237,31 @@ def __init__(self, data: ApplicationCommandPermissionsPayload) -> None: class ApplicationCommandPermissionsUpdate(Event): - """Represents an Application Command having permissions updated in a guild""" + """Called when application command permissions are updated for a guild. + + This requires :attr:`Intents.guilds` to be enabled. + + Attributes + ---------- + id: :class:`int` + The ID of the command or application. + application_id: :class:`int` + The application ID. + guild_id: :class:`int` + The ID of the guild where permissions were updated. + permissions: list[:class:`ApplicationCommandPermission`] + The updated permissions for this application command. + """ - __event_name__ = "APPLICATION_COMMAND_PERMISSIONS_UPDATE" + __event_name__: str = "APPLICATION_COMMAND_PERMISSIONS_UPDATE" id: int - """A snowflake of the application command's id""" application_id: int - """A snowflake of the application's id""" guild_id: int - """A snowflake of the guild's id where the permissions have been updated""" permissions: list[ApplicationCommandPermission] - """A list of permissions this Application Command has""" @classmethod + @override async def __load__(cls, data: GuildApplicationCommandPermissions, state: ConnectionState) -> Self: self = cls() self.id = int(data["id"]) @@ -207,12 +272,29 @@ async def __load__(cls, data: GuildApplicationCommandPermissions, state: Connect class PresenceUpdate(Event): - __event_name__ = "PRESENCE_UPDATE" + """Called when a member updates their presence. + + This is called when one or more of the following things change: + - status + - activity + + This requires :attr:`Intents.presences` and :attr:`Intents.members` to be enabled. + + Attributes + ---------- + old: :class:`Member` + The member's old presence info. + new: :class:`Member` + The member's updated presence info. + """ + + __event_name__: str = "PRESENCE_UPDATE" old: Member new: Member @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: self = cls() guild_id = get_as_snowflake(data, "guild_id") @@ -234,13 +316,32 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class UserUpdate(Event, User): - __event_name__ = "USER_UPDATE" + """Called when a user updates their profile. + + This is called when one or more of the following things change: + - avatar + - username + - discriminator + - global_name + + This requires :attr:`Intents.members` to be enabled. + + This event inherits from :class:`User`. + + Attributes + ---------- + old: :class:`User` + The user's old info before the update. + """ + + __event_name__: str = "USER_UPDATE" old: User def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: tuple[User, User] | Any, state: ConnectionState) -> Self | None: self = cls() if isinstance(data, tuple): @@ -252,4 +353,4 @@ async def __load__(cls, data: tuple[User, User] | Any, state: ConnectionState) - await user._update(data) # type: ignore ref = await state.cache.get_user(user.id) if ref is not None: - await ref._update(data) + ref._update(data) diff --git a/discord/events/guild.py b/discord/events/guild.py index 3c4d369faf..195ee36233 100644 --- a/discord/events/guild.py +++ b/discord/events/guild.py @@ -27,16 +27,16 @@ import logging from typing import TYPE_CHECKING, Any -from typing_extensions import Self +from typing_extensions import Self, override -from discord import Role -from discord.app.event_emitter import Event -from discord.app.state import ConnectionState -from discord.emoji import Emoji -from discord.guild import Guild -from discord.member import Member -from discord.raw_models import RawMemberRemoveEvent -from discord.sticker import Sticker +from ..app.event_emitter import Event +from ..app.state import ConnectionState +from ..emoji import Emoji +from ..guild import Guild +from ..member import Member +from ..raw_models import RawMemberRemoveEvent +from ..role import Role +from ..sticker import Sticker if TYPE_CHECKING: from ..types.member import MemberWithUser @@ -45,11 +45,19 @@ class GuildMemberJoin(Event, Member): - __event_name__ = "GUILD_MEMBER_JOIN" + """Called when a member joins a guild. + + This requires :attr:`Intents.members` to be enabled. + + This event inherits from :class:`Member`. + """ + + __event_name__: str = "GUILD_MEMBER_JOIN" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -59,7 +67,7 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: ) return - member = Member(guild=guild, data=data, state=state) + member = await Member._from_data(guild=guild, data=data, state=state) if state.member_cache_flags.joined: await guild._add_member(member) @@ -67,16 +75,24 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild._member_count += 1 self = cls() - self.__dict__.update(member.__dict__) + self._populate_from_slots(member) return self class GuildMemberRemove(Event, Member): - __event_name__ = "GUILD_MEMBER_REMOVE" + """Called when a member leaves a guild. + + This requires :attr:`Intents.members` to be enabled. + + This event inherits from :class:`Member`. + """ + + __event_name__: str = "GUILD_MEMBER_REMOVE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: user = await state.store_user(data["user"]) raw = RawMemberRemoveEvent(data, user) @@ -91,7 +107,7 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: raw.user = member guild._remove_member(member) # type: ignore self = cls() - self.__dict__.update(member.__dict__) + self._populate_from_slots(member) return self else: _log.debug( @@ -101,13 +117,33 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildMemberUpdate(Event, Member): - __event_name__ = "GUILD_MEMBER_UPDATE" + """Called when a member updates their profile. + + This is called when one or more of the following things change: + - nickname + - roles + - pending + - communication_disabled_until + - timed_out + + This requires :attr:`Intents.members` to be enabled. + + This event inherits from :class:`Member`. + + Attributes + ---------- + old: :class:`Member` + The member's old info before the update. + """ + + __event_name__: str = "GUILD_MEMBER_UPDATE" old: Member def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) user = data["user"] @@ -128,12 +164,12 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: await state.emitter.emit("USER_UPDATE", user_update) self = cls() - self.__dict__.update(member.__dict__) + self._populate_from_slots(member) self.old = old_member return self else: if state.member_cache_flags.joined: - member = Member(data=data, guild=guild, state=state) + member = await Member._from_data(data=data, guild=guild, state=state) # Force an update on the inner user if necessary user_update = member._update_inner_user(user) @@ -147,13 +183,89 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: ) +class GuildMembersChunk(Event): + """Called when a chunk of guild members is received. + + This is sent when you request offline members via :meth:`Guild.chunk`. + This requires :attr:`Intents.members` to be enabled. + + Attributes + ---------- + guild: :class:`Guild` + The guild the members belong to. + members: list[:class:`Member`] + The members in this chunk. + chunk_index: :class:`int` + The chunk index in the expected chunks for this response (0 <= chunk_index < chunk_count). + chunk_count: :class:`int` + The total number of expected chunks for this response. + not_found: list[:class:`int`] + List of user IDs that were not found. + presences: list[Any] + List of presence data. + nonce: :class:`str` + The nonce used in the request, if any. + """ + + __event_name__: str = "GUILD_MEMBERS_CHUNK" + guild: Guild + members: list[Member] + chunk_index: int + chunk_count: int + not_found: list[int] + presences: list[Any] + nonce: str + + @classmethod + @override + async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: + guild_id = int(data["guild_id"]) + guild = state._get_guild(guild_id) + presences = data.get("presences", []) + + # the guild won't be None here + member_data_list = data.get("members", []) + members = await asyncio.gather( + *[Member._from_data(guild=guild, data=member, state=state) for member in member_data_list] + ) # type: ignore + _log.debug("Processed a chunk for %s members in guild ID %s.", len(members), guild_id) + + if presences: + member_dict = {str(member.id): member for member in members} + for presence in presences: + user = presence["user"] + member_id = user["id"] + member = member_dict.get(member_id) + if member is not None: + member._presence_update(presence, user) + + complete = data.get("chunk_index", 0) + 1 == data.get("chunk_count") + state.process_chunk_requests(guild_id, data.get("nonce"), members, complete) + return None + + class GuildEmojisUpdate(Event): - __event_name__ = "GUILD_EMOJIS_UPDATE" + """Called when a guild adds or removes emojis. + + This requires :attr:`Intents.emojis_and_stickers` to be enabled. + + Attributes + ---------- + guild: :class:`Guild` + The guild who got their emojis updated. + emojis: list[:class:`Emoji`] + The list of emojis after the update. + old_emojis: list[:class:`Emoji`] + The list of emojis before the update. + """ + + __event_name__: str = "GUILD_EMOJIS_UPDATE" guild: Guild emojis: list[Emoji] old_emojis: list[Emoji] @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -178,13 +290,28 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildStickersUpdate(Event): - __event_name__ = "GUILD_STICKERS_UPDATE" + """Called when a guild adds or removes stickers. + + This requires :attr:`Intents.emojis_and_stickers` to be enabled. + + Attributes + ---------- + guild: :class:`Guild` + The guild who got their stickers updated. + stickers: list[:class:`GuildSticker`] + The list of stickers after the update. + old_stickers: list[:class:`GuildSticker`] + The list of stickers before the update. + """ + + __event_name__: str = "GUILD_STICKERS_UPDATE" guild: Guild stickers: list[Sticker] old_stickers: list[Sticker] @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -209,47 +336,82 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildAvailable(Event, Guild): - __event_name__ = "GUILD_AVAILABLE" + """Called when a guild becomes available. + + The guild must have existed in the client's cache. + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Guild`. + """ + + __event_name__: str = "GUILD_AVAILABLE" def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: Guild, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: Guild, state: ConnectionState) -> Self: self = cls() - self.__dict__.update(data.__dict__) + # self.__dict__.update(data.__dict__) # TODO: Find another way to do this return self class GuildUnavailable(Event, Guild): - __event_name__ = "GUILD_UNAVAILABLE" + """Called when a guild becomes unavailable. + + The guild must have existed in the client's cache. + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Guild`. + """ + + __event_name__: str = "GUILD_UNAVAILABLE" def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: Guild, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: Guild, state: ConnectionState) -> Self: self = cls() self.__dict__.update(data.__dict__) return self class GuildJoin(Event, Guild): - __event_name__ = "GUILD_JOIN" + """Called when the client joins a new guild or when a guild is created. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Guild`. + """ + + __event_name__: str = "GUILD_JOIN" def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: Guild, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: Guild, state: ConnectionState) -> Self: self = cls() - self.__dict__.update(data.__dict__) + # self.__dict__.update(data.__dict__) # TODO: Find another way to do this return self class GuildCreate(Event, Guild): - __event_name__ = "GUILD_CREATE" + """Internal event representing a guild becoming available via the gateway. + + This event trickles down to the more distinct :class:`GuildJoin` and :class:`GuildAvailable` events. + Users should typically listen to those events instead. + + This event inherits from :class:`Guild`. + """ + + __event_name__: str = "GUILD_CREATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: unavailable = data.get("unavailable") if unavailable is True: @@ -279,25 +441,44 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: await state.emitter.emit("GUILD_JOIN", guild) self = cls() - self.__dict__.update(data.__dict__) + # self.__dict__.update(data.__dict__) # TODO: Find another way to do this return self class GuildUpdate(Event, Guild): - __event_name__ = "GUILD_UPDATE" + """Called when a guild is updated. + + Examples of when this is called: + - Changed name + - Changed AFK channel + - Changed AFK timeout + - etc. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Guild`. + + Attributes + ---------- + old: :class:`Guild` + The guild prior to being updated. + """ + + __event_name__: str = "GUILD_UPDATE" old: Guild def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["id"])) if guild is not None: old_guild = copy.copy(guild) guild = await guild._from_data(data, state) self = cls() - self.__dict__.update(guild.__dict__) + self._populate_from_slots(guild) self.old = old_guild return self else: @@ -308,13 +489,35 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildDelete(Event, Guild): - __event_name__ = "GUILD_DELETE" + """Called when a guild is removed from the client. + + This happens through, but not limited to, these circumstances: + - The client got banned. + - The client got kicked. + - The client left the guild. + - The client or the guild owner deleted the guild. + + In order for this event to be invoked then the client must have been part of the guild + to begin with (i.e., it is part of :attr:`Client.guilds`). + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Guild`. + + Attributes + ---------- + old: :class:`Guild` + The guild that was removed. + """ + + __event_name__: str = "GUILD_DELETE" old: Guild def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["id"])) if guild is None: @@ -337,16 +540,24 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: await state._remove_guild(guild) self = cls() - self.__dict__.update(guild.__dict__) + self._populate_from_slots(guild) return self class GuildBanAdd(Event, Member): - __event_name__ = "GUILD_BAN_ADD" + """Called when a user gets banned from a guild. + + This requires :attr:`Intents.moderation` to be enabled. + + This event inherits from :class:`Member`. + """ + + __event_name__: str = "GUILD_BAN_ADD" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -365,19 +576,27 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: "deaf": False, "mute": False, } - member = Member(guild=guild, data=fake_data, state=state) + member = await Member._from_data(guild=guild, data=fake_data, state=state) self = cls() - self.__dict__.update(member.__dict__) + self._populate_from_slots(member) return self class GuildBanRemove(Event, Member): - __event_name__ = "GUILD_BAN_REMOVE" + """Called when a user gets unbanned from a guild. + + This requires :attr:`Intents.moderation` to be enabled. + + This event inherits from :class:`Member`. + """ + + __event_name__: str = "GUILD_BAN_REMOVE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -396,19 +615,28 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: "deaf": False, "mute": False, } - member = Member(guild=guild, data=fake_data, state=state) + member = await Member._from_data(guild=guild, data=fake_data, state=state) self = cls() - self.__dict__.update(member.__dict__) + self._populate_from_slots(member) return self class GuildRoleCreate(Event, Role): - __event_name__ = "GUILD_ROLE_CREATE" + """Called when a guild creates a role. + + To get the guild it belongs to, use :attr:`Role.guild`. + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Role`. + """ + + __event_name__: str = "GUILD_ROLE_CREATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -422,18 +650,31 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild._add_role(role) self = cls() - self.__dict__.update(role.__dict__) + self._populate_from_slots(role) return self class GuildRoleUpdate(Event, Role): - __event_name__ = "GUILD_ROLE_UPDATE" + """Called when a role is changed guild-wide. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Role`. + + Attributes + ---------- + old: :class:`Role` + The updated role's old info. + """ + + __event_name__: str = "GUILD_ROLE_UPDATE" old: Role def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -441,7 +682,7 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: "GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.", data["guild_id"], ) - return + return None role_id: int = int(data["role"]["id"]) role = guild.get_role(role_id) @@ -450,23 +691,32 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: "GUILD_ROLE_UPDATE referencing an unknown role ID: %s. Discarding.", data["role"]["id"], ) - return + return None old_role = copy.copy(role) - await role._update(data["role"]) + role._update(data["role"]) self = cls() - self.__dict__.update(role.__dict__) + self._populate_from_slots(role) self.old = old_role return self class GuildRoleDelete(Event, Role): - __event_name__ = "GUILD_ROLE_DELETE" + """Called when a guild deletes a role. + + To get the guild it belongs to, use :attr:`Role.guild`. + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Role`. + """ + + __event_name__: str = "GUILD_ROLE_DELETE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -488,5 +738,5 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild._remove_role(role_id) self = cls() - self.__dict__.update(role.__dict__) + self._populate_from_slots(role) return self diff --git a/discord/events/integration.py b/discord/events/integration.py index c6dcfe77eb..be04fa8d9b 100644 --- a/discord/events/integration.py +++ b/discord/events/integration.py @@ -23,7 +23,9 @@ """ import logging -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.app.event_emitter import Event from discord.app.state import ConnectionState @@ -35,11 +37,22 @@ class GuildIntegrationsUpdate(Event): - __event_name__ = "GUILD_INTEGRATIONS_UPDATE" + """Called whenever an integration is created, modified, or removed from a guild. + + This requires :attr:`Intents.integrations` to be enabled. + + Attributes + ---------- + guild: :class:`Guild` + The guild that had its integrations updated. + """ + + __event_name__: str = "GUILD_INTEGRATIONS_UPDATE" guild: Guild @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -55,11 +68,19 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class IntegrationCreate(Event, Integration): - __event_name__ = "INTEGRATION_CREATE" + """Called when an integration is created. + + This requires :attr:`Intents.integrations` to be enabled. + + This event inherits from :class:`Integration`. + """ + + __event_name__: str = "INTEGRATION_CREATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: data_copy = data.copy() guild_id = int(data_copy.pop("guild_id")) @@ -80,11 +101,19 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class IntegrationUpdate(Event, Integration): - __event_name__ = "INTEGRATION_UPDATE" + """Called when an integration is updated. + + This requires :attr:`Intents.integrations` to be enabled. + + This event inherits from :class:`Integration`. + """ + + __event_name__: str = "INTEGRATION_UPDATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: data_copy = data.copy() guild_id = int(data_copy.pop("guild_id")) @@ -105,11 +134,22 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class IntegrationDelete(Event): - __event_name__ = "INTEGRATION_DELETE" + """Called when an integration is deleted. + + This requires :attr:`Intents.integrations` to be enabled. + + Attributes + ---------- + raw: :class:`RawIntegrationDeleteEvent` + The raw event payload data. + """ + + __event_name__: str = "INTEGRATION_DELETE" raw: RawIntegrationDeleteEvent @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild_id = int(data["guild_id"]) guild = await state._get_guild(guild_id) diff --git a/discord/events/interaction.py b/discord/events/interaction.py index 2d4f035b03..5903174d3c 100644 --- a/discord/events/interaction.py +++ b/discord/events/interaction.py @@ -22,7 +22,9 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.enums import InteractionType from discord.types.interactions import Interaction as InteractionPayload @@ -33,14 +35,28 @@ class InteractionCreate(Event, Interaction): - __event_name__ = "INTERACTION_CREATE" + """Called when an interaction is created. + + This currently happens due to application command invocations or components being used. + + .. warning:: + This is a low level event that is not generally meant to be used. + If you are working with components, consider using the callbacks associated + with the :class:`~discord.ui.View` instead as it provides a nicer user experience. + + This event inherits from :class:`Interaction`. + """ + + __event_name__: str = "INTERACTION_CREATE" def __init__(self) -> None: pass @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: interaction = Interaction(data=data, state=state) + await interaction.load_data() if data["type"] == 3: custom_id = interaction.data["custom_id"] # type: ignore component_type = interaction.data["component_type"] # type: ignore @@ -72,5 +88,5 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: except Exception as e: return await modal.on_error(e, interaction) self = cls() - self.__dict__.update(interaction.__dict__) + self._populate_from_slots(interaction) return self diff --git a/discord/events/invite.py b/discord/events/invite.py index 713ea66688..142fcdd987 100644 --- a/discord/events/invite.py +++ b/discord/events/invite.py @@ -22,11 +22,13 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override -from discord.abc import GuildChannel from discord.app.event_emitter import Event from discord.app.state import ConnectionState +from discord.channel.base import GuildChannel from discord.guild import Guild from discord.invite import Invite, PartialInviteChannel, PartialInviteGuild from discord.types.invite import GatewayInvite, VanityInvite @@ -34,11 +36,25 @@ class InviteCreate(Event, Invite): - __event_name__ = "INVITE_CREATE" + """Called when an invite is created. + + You must have :attr:`~Permissions.manage_channels` permission to receive this. + + .. note:: + There is a rare possibility that the :attr:`Invite.guild` and :attr:`Invite.channel` + attributes will be of :class:`Object` rather than the respective models. + + This requires :attr:`Intents.invites` to be enabled. + + This event inherits from :class:`Invite`. + """ + + __event_name__: str = "INVITE_CREATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: GatewayInvite, state: ConnectionState) -> Self | None: invite = await Invite.from_gateway(state=state, data=data) self = cls() @@ -46,11 +62,28 @@ async def __load__(cls, data: GatewayInvite, state: ConnectionState) -> Self | N class InviteDelete(Event, Invite): - __event_name__ = "INVITE_DELETE" + """Called when an invite is deleted. + + You must have :attr:`~Permissions.manage_channels` permission to receive this. + + .. note:: + There is a rare possibility that the :attr:`Invite.guild` and :attr:`Invite.channel` + attributes will be of :class:`Object` rather than the respective models. + + Outside of those two attributes, the only other attribute guaranteed to be + filled by the Discord gateway for this event is :attr:`Invite.code`. + + This requires :attr:`Intents.invites` to be enabled. + + This event inherits from :class:`Invite`. + """ + + __event_name__: str = "INVITE_DELETE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: GatewayInvite, state: ConnectionState) -> Self | None: invite = await Invite.from_gateway(state=state, data=data) self = cls() diff --git a/discord/events/message.py b/discord/events/message.py index fcde0c6b1c..94315b2a62 100644 --- a/discord/events/message.py +++ b/discord/events/message.py @@ -22,7 +22,9 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.app.state import ConnectionState from discord.channel import StageChannel, TextChannel, VoiceChannel @@ -40,7 +42,7 @@ RawReactionClearEvent, ) from discord.reaction import Reaction -from discord.threads import Thread +from discord.channel.thread import Thread from discord.types.message import Reaction as ReactionPayload from discord.types.raw_models import ReactionActionEvent, ReactionClearEvent from discord.user import User @@ -52,16 +54,33 @@ class MessageCreate(Event, Message): - __event_name__ = "MESSAGE_CREATE" + """Called when a message is created and sent. + + This requires :attr:`Intents.messages` to be enabled. + + .. warning:: + Your bot's own messages and private messages are sent through this event. + This can lead to cases of 'recursion' depending on how your bot was programmed. + If you want the bot to not reply to itself, consider checking if :attr:`author` + equals the bot user. + + This event inherits from :class:`Message`. + """ + + __event_name__: str = "MESSAGE_CREATE" + + def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: channel, _ = await state._get_guild_channel(data) message = await Message._from_data(channel=channel, data=data, state=state) self = cls() - self.__dict__.update(message.__dict__) + self._populate_from_slots(message) + + await state.cache.store_built_message(message) - await state.cache.store_message(data, channel) # we ensure that the channel is either a TextChannel, VoiceChannel, StageChannel, or Thread if channel and channel.__class__ in ( TextChannel, @@ -75,12 +94,27 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class MessageDelete(Event, Message): - __event_name__ = "MESSAGE_DELETE" + """Called when a message is deleted. + + This requires :attr:`Intents.messages` to be enabled. + + This event inherits from :class:`Message`. + + Attributes + ---------- + raw: :class:`RawMessageDeleteEvent` + The raw event payload data. + is_cached: :class:`bool` + Whether the message was found in the internal cache. + """ + + __event_name__: str = "MESSAGE_DELETE" raw: RawMessageDeleteEvent is_cached: bool @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: self = cls() raw = RawMessageDeleteEvent(data) @@ -99,12 +133,25 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class MessageDeleteBulk(Event): - __event_name__ = "MESSAGE_DELETE_BULK" + """Called when messages are bulk deleted. + + This requires :attr:`Intents.messages` to be enabled. + + Attributes + ---------- + raw: :class:`RawBulkMessageDeleteEvent` + The raw event payload data. + messages: list[:class:`Message`] + The messages that have been deleted (only includes cached messages). + """ + + __event_name__: str = "MESSAGE_DELETE_BULK" raw: RawBulkMessageDeleteEvent messages: list[Message] @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() raw = RawBulkMessageDeleteEvent(data) @@ -118,12 +165,35 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class MessageUpdate(Event, Message): - __event_name__ = "MESSAGE_UPDATE" + """Called when a message receives an update event. + + This requires :attr:`Intents.messages` to be enabled. + + The following non-exhaustive cases trigger this event: + - A message has been pinned or unpinned. + - The message content has been changed. + - The message has received an embed. + - The message's embeds were suppressed or unsuppressed. + - A call message has received an update to its participants or ending time. + - A poll has ended and the results have been finalized. + + This event inherits from :class:`Message`. + + Attributes + ---------- + raw: :class:`RawMessageUpdateEvent` + The raw event payload data. + old: :class:`Message` | :class:`Undefined` + The previous version of the message (if it was cached). + """ + + __event_name__: str = "MESSAGE_UPDATE" raw: RawMessageUpdateEvent old: Message | Undefined @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() raw = RawMessageUpdateEvent(data) @@ -148,13 +218,31 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class ReactionAdd(Event): - __event_name__ = "MESSAGE_REACTION_ADD" + """Called when a message has a reaction added to it. + + This requires :attr:`Intents.reactions` to be enabled. + + .. note:: + To get the :class:`Message` being reacted to, access it via :attr:`reaction.message`. + + Attributes + ---------- + raw: :class:`RawReactionActionEvent` + The raw event payload data. + user: :class:`Member` | :class:`User` | :class:`Undefined` + The user who added the reaction. + reaction: :class:`Reaction` + The current state of the reaction. + """ + + __event_name__: str = "MESSAGE_REACTION_ADD" raw: RawReactionActionEvent user: Member | User | Undefined reaction: Reaction @classmethod + @override async def __load__(cls, data: ReactionActionEvent, state: ConnectionState) -> Self: self = cls() emoji = data["emoji"] @@ -188,13 +276,28 @@ async def __load__(cls, data: ReactionActionEvent, state: ConnectionState) -> Se class ReactionClear(Event): - __event_name__ = "MESSAGE_REACTION_REMOVE_ALL" + """Called when a message has all its reactions removed from it. + + This requires :attr:`Intents.reactions` to be enabled. + + Attributes + ---------- + raw: :class:`RawReactionClearEvent` + The raw event payload data. + message: :class:`Message` | :class:`Undefined` + The message that had its reactions cleared. + old_reactions: list[:class:`Reaction`] | :class:`Undefined` + The reactions that were removed. + """ + + __event_name__: str = "MESSAGE_REACTION_REMOVE_ALL" raw: RawReactionClearEvent message: Message | Undefined old_reactions: list[Reaction] | Undefined @classmethod + @override async def __load__(cls, data: ReactionClearEvent, state: ConnectionState) -> Self | None: self = cls() self.raw = RawReactionClearEvent(data) @@ -211,13 +314,31 @@ async def __load__(cls, data: ReactionClearEvent, state: ConnectionState) -> Sel class ReactionRemove(Event): - __event_name__ = "MESSAGE_REACTION_REMOVE" + """Called when a message has a reaction removed from it. + + This requires :attr:`Intents.reactions` to be enabled. + + .. note:: + To get the :class:`Message` being reacted to, access it via :attr:`reaction.message`. + + Attributes + ---------- + raw: :class:`RawReactionActionEvent` + The raw event payload data. + user: :class:`Member` | :class:`User` | :class:`Undefined` + The user who removed the reaction. + reaction: :class:`Reaction` + The current state of the reaction. + """ + + __event_name__: str = "MESSAGE_REACTION_REMOVE" raw: RawReactionActionEvent user: Member | User | Undefined reaction: Reaction @classmethod + @override async def __load__(cls, data: ReactionActionEvent, state: ConnectionState) -> Self: self = cls() emoji = data["emoji"] @@ -254,12 +375,20 @@ async def __load__(cls, data: ReactionActionEvent, state: ConnectionState) -> Se class ReactionRemoveEmoji(Event, Reaction): - __event_name__ = "MESSAGE_REACTION_REMOVE_EMOJI" + """Called when a message has a specific reaction removed from it. + + This requires :attr:`Intents.reactions` to be enabled. + + This event inherits from :class:`Reaction`. + """ + + __event_name__: str = "MESSAGE_REACTION_REMOVE_EMOJI" def __init__(self): pass @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: emoji = data["emoji"] emoji_id = utils.get_as_snowflake(emoji, "id") @@ -281,7 +410,25 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class PollVoteAdd(Event): - __event_name__ = "MESSAGE_POLL_VOTE_ADD" + """Called when a vote is cast on a poll. + + This requires :attr:`Intents.polls` to be enabled. + + Attributes + ---------- + raw: :class:`RawMessagePollVoteEvent` + The raw event payload data. + guild: :class:`Guild` | :class:`Undefined` + The guild where the poll vote occurred, if in a guild. + user: :class:`User` | :class:`Member` | None + The user who added the vote. + poll: :class:`Poll` + The current state of the poll. + answer: :class:`PollAnswer` + The answer that was voted for. + """ + + __event_name__: str = "MESSAGE_POLL_VOTE_ADD" raw: RawMessagePollVoteEvent guild: Guild | Undefined @@ -290,6 +437,7 @@ class PollVoteAdd(Event): answer: PollAnswer @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: self = cls() raw = RawMessagePollVoteEvent(data, False) @@ -317,7 +465,25 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class PollVoteRemove(Event): - __event_name__ = "MESSAGE_POLL_VOTE_REMOVE" + """Called when a vote is removed from a poll. + + This requires :attr:`Intents.polls` to be enabled. + + Attributes + ---------- + raw: :class:`RawMessagePollVoteEvent` + The raw event payload data. + guild: :class:`Guild` | :class:`Undefined` + The guild where the poll vote occurred, if in a guild. + user: :class:`User` | :class:`Member` | None + The user who removed the vote. + poll: :class:`Poll` + The current state of the poll. + answer: :class:`PollAnswer` + The answer that had its vote removed. + """ + + __event_name__: str = "MESSAGE_POLL_VOTE_REMOVE" raw: RawMessagePollVoteEvent guild: Guild | Undefined @@ -326,6 +492,7 @@ class PollVoteRemove(Event): answer: PollAnswer @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: self = cls() raw = RawMessagePollVoteEvent(data, False) diff --git a/discord/events/scheduled_event.py b/discord/events/scheduled_event.py index 4129b06450..6bf1a886c8 100644 --- a/discord/events/scheduled_event.py +++ b/discord/events/scheduled_event.py @@ -23,7 +23,9 @@ """ import logging -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.app.event_emitter import Event from discord.app.state import ConnectionState @@ -36,11 +38,19 @@ class GuildScheduledEventCreate(Event, ScheduledEvent): - __event_name__ = "GUILD_SCHEDULED_EVENT_CREATE" + """Called when a scheduled event is created. + + This requires :attr:`Intents.scheduled_events` to be enabled. + + This event inherits from :class:`ScheduledEvent`. + """ + + __event_name__: str = "GUILD_SCHEDULED_EVENT_CREATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -60,13 +70,26 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildScheduledEventUpdate(Event, ScheduledEvent): - __event_name__ = "GUILD_SCHEDULED_EVENT_UPDATE" + """Called when a scheduled event is updated. + + This requires :attr:`Intents.scheduled_events` to be enabled. + + This event inherits from :class:`ScheduledEvent`. + + Attributes + ---------- + old: :class:`ScheduledEvent` + The old scheduled event before the update. + """ + + __event_name__: str = "GUILD_SCHEDULED_EVENT_UPDATE" old: ScheduledEvent | None def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -88,11 +111,19 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildScheduledEventDelete(Event, ScheduledEvent): - __event_name__ = "GUILD_SCHEDULED_EVENT_DELETE" + """Called when a scheduled event is deleted. + + This requires :attr:`Intents.scheduled_events` to be enabled. + + This event inherits from :class:`ScheduledEvent`. + """ + + __event_name__: str = "GUILD_SCHEDULED_EVENT_DELETE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -113,13 +144,28 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildScheduledEventUserAdd(Event): - __event_name__ = "GUILD_SCHEDULED_EVENT_USER_ADD" + """Called when a user subscribes to a scheduled event. + + This requires :attr:`Intents.scheduled_events` to be enabled. + + Attributes + ---------- + event: :class:`ScheduledEvent` + The scheduled event subscribed to. + member: :class:`Member` + The member who subscribed. + raw: :class:`RawScheduledEventSubscription` + The raw event payload data. + """ + + __event_name__: str = "GUILD_SCHEDULED_EVENT_USER_ADD" raw: RawScheduledEventSubscription event: ScheduledEvent member: Member @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -146,13 +192,28 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class GuildScheduledEventUserRemove(Event): - __event_name__ = "GUILD_SCHEDULED_EVENT_USER_REMOVE" + """Called when a user unsubscribes from a scheduled event. + + This requires :attr:`Intents.scheduled_events` to be enabled. + + Attributes + ---------- + event: :class:`ScheduledEvent` + The scheduled event unsubscribed from. + member: :class:`Member` + The member who unsubscribed. + raw: :class:`RawScheduledEventSubscription` + The raw event payload data. + """ + + __event_name__: str = "GUILD_SCHEDULED_EVENT_USER_REMOVE" raw: RawScheduledEventSubscription event: ScheduledEvent member: Member @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: diff --git a/discord/events/soundboard.py b/discord/events/soundboard.py new file mode 100644 index 0000000000..6107a0bedd --- /dev/null +++ b/discord/events/soundboard.py @@ -0,0 +1,176 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from typing import TYPE_CHECKING, Any + +from typing_extensions import Self, override + +from ..app.event_emitter import Event +from ..raw_models import RawSoundboardSoundDeleteEvent +from ..soundboard import SoundboardSound + +if TYPE_CHECKING: + from ..app.state import ConnectionState + +__all__ = ( + "SoundboardSounds", + "GuildSoundboardSoundsUpdate", + "GuildSoundboardSoundUpdate", + "GuildSoundboardSoundCreate", + "GuildSoundboardSoundDelete", +) + + +class SoundboardSounds(Event): + __event_name__: str = "SOUNDBOARD_SOUNDS" + + def __init__(self, guild_id: int, sounds: list[SoundboardSound]) -> None: + self.guild_id: int = guild_id + self.sounds: list[SoundboardSound] = sounds + + @classmethod + @override + async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: + guild_id = int(data["guild_id"]) + sounds: list[SoundboardSound] = [] + for sound_data in data["soundboard_sounds"]: + sound = SoundboardSound(state=state, http=state.http, data=sound_data) + await state.cache.store_sound(sound) + sounds.append(sound) + return cls(guild_id, sounds) + + +class GuildSoundboardSoundsUpdate(Event): + """Called when multiple guild soundboard sounds are updated at once. + + This is called, for example, when a guild loses a boost level and some sounds become unavailable. + + Attributes + ---------- + old_sounds: list[:class:`SoundboardSound`] | None + The soundboard sounds prior to being updated (only if all were cached). + new_sounds: list[:class:`SoundboardSound`] + The soundboard sounds after being updated. + """ + + __event_name__: str = "GUILD_SOUNDBOARD_SOUNDS_UPDATE" + + def __init__( + self, + before_sounds: list[SoundboardSound], + after_sounds: list[SoundboardSound], + ) -> None: + self.before: list[SoundboardSound] = before_sounds + self.after: list[SoundboardSound] = after_sounds + + @classmethod + @override + async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: + before_sounds: list[SoundboardSound] = [] + after_sounds: list[SoundboardSound] = [] + for sound_data in data["soundboard_sounds"]: + after = SoundboardSound(state=state, http=state.http, data=sound_data) + if before := await state.cache.get_sound(after.id): + before_sounds.append(before) + await state.cache.store_sound(after) + after_sounds.append(after) + + if len(before_sounds) == len(after_sounds): + return cls(before_sounds, after_sounds) + return None + + +class GuildSoundboardSoundUpdate(Event): + """Called when a soundboard sound is updated. + + Attributes + ---------- + old: :class:`SoundboardSound` | None + The soundboard sound prior to being updated (if it was cached). + new: :class:`SoundboardSound` + The soundboard sound after being updated. + """ + + __event_name__: str = "GUILD_SOUNDBOARD_SOUND_UPDATE" + + def __init__(self, before: SoundboardSound, after: SoundboardSound) -> None: + self.before: SoundboardSound = before + self.after: SoundboardSound = after + + @classmethod + @override + async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: + after = SoundboardSound(state=state, http=state.http, data=data) + before = await state.cache.get_sound(after.id) + await state.cache.store_sound(after) + if before: + return cls(before, after) + return None + + +class GuildSoundboardSoundCreate(Event): + """Called when a soundboard sound is created. + + This event inherits from :class:`SoundboardSound`. + """ + + __event_name__: str = "GUILD_SOUNDBOARD_SOUND_CREATE" + + def __init__(self, sound: SoundboardSound) -> None: + self.sound: SoundboardSound = sound + + @classmethod + @override + async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: + sound = SoundboardSound(state=state, http=state.http, data=data) + await state.cache.store_sound(sound) + return cls(sound) + + +class GuildSoundboardSoundDelete(Event): + """Called when a soundboard sound is deleted. + + Attributes + ---------- + raw: :class:`RawSoundboardSoundDeleteEvent` + The raw event payload data. + sound: :class:`SoundboardSound` | None + The deleted sound (if it was cached). + """ + + __event_name__: str = "GUILD_SOUNDBOARD_SOUND_DELETE" + + def __init__(self, sound: SoundboardSound | None, raw: RawSoundboardSoundDeleteEvent) -> None: + self.sound: SoundboardSound | None = sound + self.raw: RawSoundboardSoundDeleteEvent = raw + + @classmethod + @override + async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: + sound_id = int(data["sound_id"]) + sound = await state.cache.get_sound(sound_id) + if sound is not None: + await state.cache.delete_sound(sound_id) + raw = RawSoundboardSoundDeleteEvent(data) + return cls(sound, raw) diff --git a/discord/events/stage_instance.py b/discord/events/stage_instance.py index a0740f0dce..1d76a863d2 100644 --- a/discord/events/stage_instance.py +++ b/discord/events/stage_instance.py @@ -24,7 +24,9 @@ import copy import logging -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.app.event_emitter import Event from discord.app.state import ConnectionState @@ -34,11 +36,17 @@ class StageInstanceCreate(Event, StageInstance): - __event_name__ = "STAGE_INSTANCE_CREATE" + """Called when a stage instance is created for a stage channel. + + This event inherits from :class:`StageInstance`. + """ + + __event_name__: str = "STAGE_INSTANCE_CREATE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -57,13 +65,28 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class StageInstanceUpdate(Event, StageInstance): - __event_name__ = "STAGE_INSTANCE_UPDATE" + """Called when a stage instance is updated. + + The following, but not limited to, examples illustrate when this event is called: + - The topic is changed. + - The privacy level is changed. + + This event inherits from :class:`StageInstance`. + + Attributes + ---------- + old: :class:`StageInstance` + The stage instance before the update. + """ + + __event_name__: str = "STAGE_INSTANCE_UPDATE" old: StageInstance def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: @@ -91,11 +114,17 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class StageInstanceDelete(Event, StageInstance): - __event_name__ = "STAGE_INSTANCE_DELETE" + """Called when a stage instance is deleted for a stage channel. + + This event inherits from :class:`StageInstance`. + """ + + __event_name__: str = "STAGE_INSTANCE_DELETE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: diff --git a/discord/events/subscription.py b/discord/events/subscription.py index a91b11b799..adb235eff6 100644 --- a/discord/events/subscription.py +++ b/discord/events/subscription.py @@ -22,7 +22,9 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Self +from typing import Any + +from typing_extensions import Self, override from discord.types.monetization import Entitlement as EntitlementPayload @@ -32,12 +34,18 @@ class SubscriptionCreate(Event, Subscription): - __event_name__ = "SUBSCRIPTION_CREATE" + """Called when a subscription is created for the application. + + This event inherits from :class:`Subscription`. + """ + + __event_name__: str = "SUBSCRIPTION_CREATE" def __init__(self) -> None: pass @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.__dict__.update(Subscription(data=data, state=state).__dict__) @@ -45,12 +53,20 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class SubscriptionUpdate(Event, Subscription): - __event_name__ = "SUBSCRIPTION_UPDATE" + """Called when a subscription has been updated. + + This could be a renewal, cancellation, or other payment related update. + + This event inherits from :class:`Subscription`. + """ + + __event_name__: str = "SUBSCRIPTION_UPDATE" def __init__(self) -> None: pass @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.__dict__.update(Subscription(data=data, state=state).__dict__) @@ -58,12 +74,18 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self: class SubscriptionDelete(Event, Subscription): - __event_name__ = "SUBSCRIPTION_DELETE" + """Called when a subscription has been deleted. + + This event inherits from :class:`Subscription`. + """ + + __event_name__: str = "SUBSCRIPTION_DELETE" def __init__(self) -> None: pass @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self: self = cls() self.__dict__.update(Subscription(data=data, state=state).__dict__) diff --git a/discord/events/thread.py b/discord/events/thread.py index beab97c38f..beaa24788f 100644 --- a/discord/events/thread.py +++ b/discord/events/thread.py @@ -23,14 +23,16 @@ """ import logging -from typing import Any, Self, cast +from typing import Any, cast + +from typing_extensions import Self, override from discord import utils from discord.abc import Snowflake from discord.app.event_emitter import Event from discord.app.state import ConnectionState from discord.raw_models import RawThreadDeleteEvent, RawThreadMembersUpdateEvent, RawThreadUpdateEvent -from discord.threads import Thread, ThreadMember +from discord.channel.thread import Thread, ThreadMember from discord.types.raw_models import ThreadDeleteEvent, ThreadUpdateEvent from discord.types.threads import ThreadMember as ThreadMemberPayload @@ -38,61 +40,125 @@ class ThreadMemberJoin(Event, ThreadMember): - __event_name__ = "THREAD_MEMBER_JOIN" + """Called when a thread member joins a thread. + + You can get the thread a member belongs in by accessing :attr:`ThreadMember.thread`. + + This requires :attr:`Intents.members` to be enabled. + + This event inherits from :class:`ThreadMember`. + """ + + __event_name__: str = "THREAD_MEMBER_JOIN" def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: ThreadMember, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: ThreadMember, state: ConnectionState) -> Self: self = cls() self.__dict__.update(data.__dict__) return self class ThreadJoin(Event, Thread): - __event_name__ = "THREAD_JOIN" + """Called whenever the bot joins a thread. + + Note that you can get the guild from :attr:`Thread.guild`. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Thread`. + """ + + __event_name__: str = "THREAD_JOIN" def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: Thread, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: Thread, state: ConnectionState) -> Self: self = cls() self.__dict__.update(data.__dict__) return self class ThreadMemberRemove(Event, ThreadMember): - __event_name__ = "THREAD_MEMBER_REMOVE" + """Called when a thread member leaves a thread. + + You can get the thread a member belongs in by accessing :attr:`ThreadMember.thread`. + + This requires :attr:`Intents.members` to be enabled. + + This event inherits from :class:`ThreadMember`. + """ + + __event_name__: str = "THREAD_MEMBER_REMOVE" def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: ThreadMember, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: ThreadMember, state: ConnectionState) -> Self: self = cls() self.__dict__.update(data.__dict__) return self class ThreadRemove(Event, Thread): - __event_name__ = "THREAD_REMOVE" + """Called whenever a thread is removed. + + This is different from a thread being deleted. + + Note that you can get the guild from :attr:`Thread.guild`. + + This requires :attr:`Intents.guilds` to be enabled. + + .. warning:: + Due to technical limitations, this event might not be called + as soon as one expects. Since the library tracks thread membership + locally, the API only sends updated thread membership status upon being + synced by joining a thread. + + This event inherits from :class:`Thread`. + """ + + __event_name__: str = "THREAD_REMOVE" def __init__(self) -> None: ... @classmethod - async def __load__(cls, data: Thread, _: ConnectionState) -> Self: + @override + async def __load__(cls, data: Thread, state: ConnectionState) -> Self: self = cls() self.__dict__.update(data.__dict__) return self class ThreadCreate(Event, Thread): - __event_name__ = "THREAD_CREATE" + """Called whenever a thread is created. + + Note that you can get the guild from :attr:`Thread.guild`. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Thread`. + + Attributes + ---------- + just_joined: :class:`bool` + Whether the bot just joined the thread. + """ + + __event_name__: str = "THREAD_CREATE" def __init__(self) -> None: ... just_joined: bool + __slots__: tuple[str, ...] = ("just_joined",) @classmethod + @override async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | None: guild_id = int(data["guild_id"]) guild = await state._get_guild(guild_id) @@ -102,7 +168,7 @@ async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | cached_thread = guild.get_thread(int(data["id"])) self = cls() if not cached_thread: - thread = Thread(guild=guild, state=guild._state, data=data) # type: ignore + thread = await Thread._from_data(guild=guild, state=guild._state, data=data) # type: ignore guild._add_thread(thread) if data.get("newly_created"): thread._add_member( @@ -117,9 +183,11 @@ async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | ) ) self.just_joined = False - self.__dict__.update(thread.__dict__) + else: + self.just_joined = True + self._populate_from_slots(thread) else: - self.__dict__.update(cached_thread.__dict__) + self._populate_from_slots(cached_thread) self.just_joined = True if self.just_joined: @@ -129,13 +197,26 @@ async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self | class ThreadUpdate(Event, Thread): - __event_name__ = "THREAD_UPDATE" + """Called whenever a thread is updated. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Thread`. + + Attributes + ---------- + old: :class:`Thread` + The thread's old info before the update. + """ + + __event_name__: str = "THREAD_UPDATE" def __init__(self) -> None: ... old: Thread @classmethod + @override async def __load__(cls, data: ThreadUpdateEvent, state: ConnectionState) -> Self | None: guild_id = int(data["guild_id"]) guild = await state._get_guild(guild_id) @@ -161,11 +242,21 @@ async def __load__(cls, data: ThreadUpdateEvent, state: ConnectionState) -> Self class ThreadDelete(Event, Thread): - __event_name__ = "THREAD_DELETE" + """Called whenever a thread is deleted. + + Note that you can get the guild from :attr:`Thread.guild`. + + This requires :attr:`Intents.guilds` to be enabled. + + This event inherits from :class:`Thread`. + """ + + __event_name__: str = "THREAD_DELETE" def __init__(self) -> None: ... @classmethod + @override async def __load__(cls, data: ThreadDeleteEvent, state: ConnectionState) -> Self | None: raw = RawThreadDeleteEvent(data) guild = await state._get_guild(raw.guild_id) @@ -176,7 +267,7 @@ async def __load__(cls, data: ThreadDeleteEvent, state: ConnectionState) -> Self thread = guild.get_thread(raw.thread_id) if thread: - guild._remove_thread(cast(Snowflake, thread.id)) + guild._remove_thread(thread) if (msg := await thread.get_starting_message()) is not None: msg.thread = None # type: ignore @@ -184,9 +275,10 @@ async def __load__(cls, data: ThreadDeleteEvent, state: ConnectionState) -> Self class ThreadListSync(Event): - __event_name__ = "THREAD_LIST_SYNC" + __event_name__: str = "THREAD_LIST_SYNC" @classmethod + @override async def __load__(cls, data: dict[str, Any], state) -> Self | None: guild_id = int(data["guild_id"]) guild = await state._get_guild(guild_id) @@ -228,11 +320,12 @@ async def __load__(cls, data: dict[str, Any], state) -> Self | None: class ThreadMemberUpdate(Event, ThreadMember): - __event_name__ = "THREAD_MEMBER_UPDATE" + __event_name__: str = "THREAD_MEMBER_UPDATE" def __init__(self): ... @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild_id = int(data["guild_id"]) guild = await state._get_guild(guild_id) @@ -262,7 +355,10 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class BulkThreadMemberUpdate(Event): + __event_name__: str = "BULK_THREAD_MEMBER_UPDATE" + @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild_id = int(data["guild_id"]) guild = await state._get_guild(guild_id) diff --git a/discord/events/typing.py b/discord/events/typing.py index 198a29f8d8..2f10035764 100644 --- a/discord/events/typing.py +++ b/discord/events/typing.py @@ -23,7 +23,9 @@ """ from datetime import datetime -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any + +from typing_extensions import Self, override from discord import utils from discord.app.event_emitter import Event @@ -31,7 +33,7 @@ from discord.channel import DMChannel, GroupChannel, TextChannel from discord.member import Member from discord.raw_models import RawTypingEvent -from discord.threads import Thread +from discord.channel.thread import Thread from discord.user import User if TYPE_CHECKING: @@ -39,7 +41,29 @@ class TypingStart(Event): - __event_name__ = "TYPING_START" + """Called when someone begins typing a message. + + The :attr:`channel` can be a :class:`abc.Messageable` instance, + which could be :class:`TextChannel`, :class:`GroupChannel`, or :class:`DMChannel`. + + If the :attr:`channel` is a :class:`TextChannel` then the :attr:`user` is a :class:`Member`, + otherwise it is a :class:`User`. + + This requires :attr:`Intents.typing` to be enabled. + + Attributes + ---------- + raw: :class:`RawTypingEvent` + The raw event payload data. + channel: :class:`abc.Messageable` + The location where the typing originated from. + user: :class:`User` | :class:`Member` + The user that started typing. + when: :class:`datetime.datetime` + When the typing started as an aware datetime in UTC. + """ + + __event_name__: str = "TYPING_START" raw: RawTypingEvent channel: "MessageableChannel" @@ -47,6 +71,7 @@ class TypingStart(Event): when: datetime @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: raw = RawTypingEvent(data) diff --git a/discord/events/voice.py b/discord/events/voice.py index 99d37f6e41..942db9ba71 100644 --- a/discord/events/voice.py +++ b/discord/events/voice.py @@ -24,16 +24,23 @@ import asyncio import logging -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any + +from typing_extensions import Self, override from discord.app.event_emitter import Event from discord.app.state import ConnectionState +from discord.enums import VoiceChannelEffectAnimationType, try_enum from discord.member import Member, VoiceState from discord.raw_models import RawVoiceChannelStatusUpdateEvent from discord.utils.private import get_as_snowflake if TYPE_CHECKING: from discord.abc import VocalGuildChannel + from discord.emoji import PartialEmoji + from discord.guild import Guild + from discord.soundboard import PartialSoundboardSound, SoundboardSound + from discord.types.channel import VoiceChannelEffectSend as VoiceChannelEffectSendPayload _log = logging.getLogger(__name__) @@ -47,13 +54,34 @@ async def logging_coroutine(coroutine, *, info: str) -> None: class VoiceStateUpdate(Event): - __event_name__ = "VOICE_STATE_UPDATE" + """Called when a member changes their voice state. + + The following, but not limited to, examples illustrate when this event is called: + - A member joins a voice or stage channel. + - A member leaves a voice or stage channel. + - A member is muted or deafened by their own accord. + - A member is muted or deafened by a guild administrator. + + This requires :attr:`Intents.voice_states` to be enabled. + + Attributes + ---------- + member: :class:`Member` + The member whose voice states changed. + before: :class:`VoiceState` + The voice state prior to the changes. + after: :class:`VoiceState` + The voice state after the changes. + """ + + __event_name__: str = "VOICE_STATE_UPDATE" member: Member before: VoiceState after: VoiceState @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(get_as_snowflake(data, "guild_id")) channel_id = get_as_snowflake(data, "channel_id") @@ -94,9 +122,17 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class VoiceServerUpdate(Event): - __event_name__ = "VOICE_SERVER_UPDATE" + """Called when the voice server is updated. + + .. note:: + This is an internal event used by the voice protocol. + It is not dispatched to user code. + """ + + __event_name__: str = "VOICE_SERVER_UPDATE" @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: try: key_id = int(data["guild_id"]) @@ -113,7 +149,21 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: class VoiceChannelStatusUpdate(Event): - __event_name__ = "VOICE_CHANNEL_STATUS_UPDATE" + """Called when someone updates a voice channel status. + + Attributes + ---------- + raw: :class:`RawVoiceChannelStatusUpdateEvent` + The raw voice channel status update payload. + channel: :class:`VoiceChannel` | :class:`StageChannel` + The channel where the voice channel status update originated from. + old_status: :class:`str` | None + The old voice channel status. + new_status: :class:`str` | None + The new voice channel status. + """ + + __event_name__: str = "VOICE_CHANNEL_STATUS_UPDATE" raw: RawVoiceChannelStatusUpdateEvent channel: "VocalGuildChannel" @@ -121,6 +171,7 @@ class VoiceChannelStatusUpdate(Event): new_status: str | None @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: raw = RawVoiceChannelStatusUpdateEvent(data) guild = await state._get_guild(int(data["guild_id"])) @@ -150,3 +201,104 @@ async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: self.old_status = old_status self.new_status = channel.status return self + + +class VoiceChannelEffectSend(Event): + """Called when a voice channel effect is sent. + + Attributes + ---------- + animation_type: :class:`VoiceChannelEffectAnimationType` + The type of animation that is being sent. + animation_id: :class:`int` + The ID of the animation that is being sent. + sound: :class:`SoundboardSound` | :class:`PartialSoundboardSound` | None + The sound that is being sent, could be ``None`` if the effect is not a sound effect. + guild: :class:`Guild` + The guild in which the sound is being sent. + user: :class:`Member` + The member that sent the sound. + channel: :class:`VoiceChannel` | :class:`StageChannel` + The voice channel in which the sound is being sent. + emoji: :class:`PartialEmoji` | None + The emoji associated with the effect, if any. + """ + + __event_name__: str = "VOICE_CHANNEL_EFFECT_SEND" + + def __init__( + self, + *, + animation_type: VoiceChannelEffectAnimationType, + animation_id: int, + sound: "SoundboardSound | PartialSoundboardSound | None", + guild: "Guild", + user: Member, + channel: "VocalGuildChannel", + emoji: "PartialEmoji | None", + ) -> None: + self.animation_type = animation_type + self.animation_id = animation_id + self.sound = sound + self.guild = guild + self.user = user + self.channel = channel + self.emoji = emoji + + @classmethod + @override + async def __load__(cls, data: "VoiceChannelEffectSendPayload", state: ConnectionState) -> Self | None: + from discord.emoji import PartialEmoji + from discord.soundboard import PartialSoundboardSound + + channel_id = int(data["channel_id"]) + user_id = int(data["user_id"]) + guild_id = int(data["guild_id"]) + + guild = await state._get_guild(guild_id) + if guild is None: + _log.debug( + "VOICE_CHANNEL_EFFECT_SEND referencing unknown guild ID: %s. Discarding.", + guild_id, + ) + return + + channel = guild.get_channel(channel_id) + if channel is None: + _log.debug( + "VOICE_CHANNEL_EFFECT_SEND referencing an unknown channel ID: %s. Discarding.", + channel_id, + ) + return + + user = guild.get_member(user_id) + if user is None: + _log.debug( + "VOICE_CHANNEL_EFFECT_SEND referencing an unknown user ID: %s. Discarding.", + user_id, + ) + return + + # Create sound if present + sound = None + if data.get("sound_id"): + sound = PartialSoundboardSound(data, state, state.http) + + # Create emoji if present + emoji = None + if raw_emoji := data.get("emoji"): + emoji = PartialEmoji( + name=raw_emoji.get("name"), + animated=raw_emoji.get("animated", False), + id=int(raw_emoji["id"]) if raw_emoji.get("id") else None, + ) + + return cls( + animation_type=try_enum(VoiceChannelEffectAnimationType, data["animation_type"]), + animation_id=int(data["animation_id"]), + sound=sound, + guild=guild, + user=user, + channel=channel, # type: ignore + emoji=emoji, + ) diff --git a/discord/events/webhook.py b/discord/events/webhook.py index 74cf8bbbb8..4f882a1e41 100644 --- a/discord/events/webhook.py +++ b/discord/events/webhook.py @@ -23,23 +23,36 @@ """ import logging -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any + +from typing_extensions import Self, override from discord.app.event_emitter import Event from discord.app.state import ConnectionState if TYPE_CHECKING: - from discord.abc import GuildChannel + from discord.channel.base import GuildChannel _log = logging.getLogger(__name__) class WebhooksUpdate(Event): - __event_name__ = "WEBHOOKS_UPDATE" + """Called whenever a webhook is created, modified, or removed from a guild channel. + + This requires :attr:`Intents.webhooks` to be enabled. + + Attributes + ---------- + channel: :class:`TextChannel` | :class:`VoiceChannel` | :class:`ForumChannel` | :class:`StageChannel` + The channel that had its webhooks updated. + """ + + __event_name__: str = "WEBHOOKS_UPDATE" channel: "GuildChannel" @classmethod + @override async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: guild = await state._get_guild(int(data["guild_id"])) if guild is None: diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 621f744e43..fb838d43bf 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -92,7 +92,7 @@ async def _get_from_guilds(bot, getter, argument): T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) -CT = TypeVar("CT", bound=discord.abc.GuildChannel) +CT = TypeVar("CT", bound=discord.channel.GuildChannel) TT = TypeVar("TT", bound=discord.Thread) @@ -411,7 +411,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Message: raise ChannelNotReadable(channel) from e -class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): +class GuildChannelConverter(IDConverter[discord.channel.GuildChannel]): """Converts to a :class:`~discord.abc.GuildChannel`. All lookups are via the local guild. If in a DM context, then the lookup @@ -426,8 +426,8 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): .. versionadded:: 2.0 """ - async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: - return await self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel) + async def convert(self, ctx: Context, argument: str) -> discord.channel.GuildChannel: + return await self._resolve_channel(ctx, argument, "channels", discord.channel.base.GuildChannel) @staticmethod async def _resolve_channel(ctx: Context, argument: str, attribute: str, type: type[CT]) -> CT: @@ -1078,7 +1078,7 @@ def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool: discord.CategoryChannel: CategoryChannelConverter, discord.ForumChannel: ForumChannelConverter, discord.Thread: ThreadConverter, - discord.abc.GuildChannel: GuildChannelConverter, + discord.channel.GuildChannel: GuildChannelConverter, discord.GuildSticker: GuildStickerConverter, } diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 94aff45c77..ab4361a64e 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -31,6 +31,7 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, TypeVar import discord.abc +import discord.channel.base from discord.enums import Enum from ...abc import PrivateChannel @@ -72,7 +73,7 @@ def get_key(self, msg: Message) -> Any: elif self is BucketType.category: return ( msg.channel.category.id - if isinstance(msg.channel, discord.abc.GuildChannel) and msg.channel.category + if isinstance(msg.channel, discord.channel.base.GuildChannel) and msg.channel.category else msg.channel.id ) elif self is BucketType.role: diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index faf71ef8b2..d3ef1c7f9a 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -32,10 +32,10 @@ if TYPE_CHECKING: from inspect import Parameter - from discord.abc import GuildChannel - from discord.threads import Thread + from discord.channel.thread import Thread from discord.types.snowflake import Snowflake, SnowflakeList + from ...channel.base import GuildChannel from .context import Context from .converter import Converter from .cooldowns import BucketType, Cooldown diff --git a/discord/flags.py b/discord/flags.py index 7c9140aa1d..037ea00d2a 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -27,6 +27,8 @@ from typing import Any, Callable, ClassVar, Iterator, TypeVar, overload +from typing_extensions import Self + from .enums import UserFlags __all__ = ( @@ -106,7 +108,7 @@ def __init__(self, **kwargs: bool): setattr(self, key, value) @classmethod - def _from_value(cls, value): + def _from_value(cls, value: int) -> Self: self = cls.__new__(cls) self.value = value return self diff --git a/discord/gateway.py b/discord/gateway.py index 6ee5f766c5..4ed4892718 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -367,31 +367,6 @@ async def from_client( await ws.resume() return ws - def wait_for(self, event, predicate, result=None): - """Waits for a DISPATCH'd event that meets the predicate. - - Parameters - ---------- - event: :class:`str` - The event name in all upper case to wait for. - predicate - A function that takes a data parameter to check for event - properties. The data parameter is the 'd' key in the JSON message. - result - A function that takes the same data parameter and executes to send - the result to the future. If ``None``, returns the data. - - Returns - ------- - asyncio.Future - A future to wait for. - """ - - future = self.loop.create_future() - entry = EventListener(event=event, predicate=predicate, result=result, future=future) - self._dispatch_listeners.append(entry) - return future - async def identify(self): """Sends the IDENTIFY packet.""" payload = { @@ -538,31 +513,6 @@ async def received_message(self, msg, /): await self._emitter.emit(event, data) - # remove the dispatched listeners - removed = [] - for index, entry in enumerate(self._dispatch_listeners): - if entry.event != event: - continue - - future = entry.future - if future.cancelled(): - removed.append(index) - continue - - try: - valid = entry.predicate(data) - except Exception as exc: - future.set_exception(exc) - removed.append(index) - else: - if valid: - ret = data if entry.result is None else entry.result(data) - future.set_result(ret) - removed.append(index) - - for index in reversed(removed): - del self._dispatch_listeners[index] - @property def latency(self) -> float: """Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. If no heartbeat diff --git a/discord/gears/__init__.py b/discord/gears/__init__.py new file mode 100644 index 0000000000..62edf16b42 --- /dev/null +++ b/discord/gears/__init__.py @@ -0,0 +1,3 @@ +from .gear import Gear + +__all__ = ("Gear",) diff --git a/discord/gears/gear.py b/discord/gears/gear.py new file mode 100644 index 0000000000..589fa270a8 --- /dev/null +++ b/discord/gears/gear.py @@ -0,0 +1,288 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import inspect +from collections import defaultdict +from collections.abc import Awaitable, Callable, Collection, Sequence +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + Protocol, + TypeAlias, + TypeVar, + cast, + runtime_checkable, +) + +from ..app.event_emitter import Event +from ..utils import MISSING, Undefined +from ..utils.annotations import get_annotations +from ..utils.private import hybridmethod + +_T = TypeVar("_T", bound="Gear") +E = TypeVar("E", bound="Event", covariant=True) +E_contra = TypeVar("E_contra", bound="Event", contravariant=True) + + +@runtime_checkable +class AttributedEventCallback(Protocol): + __event__: type[Event] + __once__: bool + + +@runtime_checkable +class StaticAttributedEventCallback(AttributedEventCallback, Protocol): + __staticmethod__: bool + + +EventCallback: TypeAlias = Callable[[E], Awaitable[None]] + + +class Gear: + """A gear is a modular component that can listen to and handle events. + + You can subclass this class to create your own gears and attach them to your bot or other gears. + + Example + ------- + .. code-block:: python3 + class MyGear(Gear): + @Gear.listen() + async def listen(self, event: Ready) -> None: + print(f"Received event on instance: {event.__class__.__name__}") + + + my_gear = MyGear() + + + @my_gear.listen() + async def on_event(event: Ready) -> None: + print(f"Received event on bare: {event.__class__.__name__}") + + + bot.add_gear(my_gear) + """ + + def __init__(self) -> None: + self._listeners: dict[type[Event], set[EventCallback[Event]]] = defaultdict(set) + self._once_listeners: set[EventCallback[Event]] = set() + self._init_called: bool = True + + self._gears: set[Gear] = set() + + for name in dir(type(self)): + attr = getattr(type(self), name, None) + if not callable(attr): + continue + if isinstance(attr, StaticAttributedEventCallback): + callback = attr + event = attr.__event__ + once = attr.__once__ + elif isinstance(attr, AttributedEventCallback): + callback = partial(attr, self) + event = attr.__event__ + once = attr.__once__ + else: + continue + self.add_listener(cast("EventCallback[Event]", callback), event=event, once=once) + setattr(self, name, callback) + + def _handle_event(self, event: Event) -> Collection[Awaitable[Any]]: + tasks: list[Awaitable[None]] = [] + + for listener in self._listeners[type(event)]: + if listener in self._once_listeners: + self._once_listeners.remove(listener) + tasks.append(listener(event)) + + for gear in self._gears: + tasks.extend(gear._handle_event(event)) + + return tasks + + def attach_gear(self, gear: "Gear") -> None: + """Attaches a gear to this gear. + + This will propagate all events from the attached gear to this gear. + + Parameters + ---------- + gear: + The gear to attach. + """ + if not getattr(gear, "_init_called", False): + raise RuntimeError( + "Cannot attach gear before __init__ has been called. Maybe you forgot to call super().__init__()?" + ) + self._gears.add(gear) + + def detach_gear(self, gear: "Gear") -> None: + """Detaches a gear from this gear. + + Parameters + ---------- + gear: + The gear to detach. + + Raises + ------ + KeyError + If the gear is not attached. + """ + self._gears.remove(gear) + + @staticmethod + def _parse_listener_signature( + callback: Callable[[E], Awaitable[None]], is_instance_function: bool = False + ) -> type[E]: + params = get_annotations( + callback, + expected_types={0: type(Event)}, + custom_error="""Type annotation mismatch for parameter "{parameter}": expected , got {got}.""", + ) + if is_instance_function: + event = list(params.values())[1] + else: + event = next(iter(params.values())) + return cast(type[E], event) + + def add_listener( + self, + callback: Callable[[E], Awaitable[None]], + *, + event: type[E] | Undefined = MISSING, + is_instance_function: bool = False, + once: bool = False, + ) -> None: + """ + Adds an event listener to the gear. + + Parameters + ---------- + callback: + The callback function to be called when the event is emitted. + event: + The type of event to listen for. If not provided, it will be inferred from the callback signature. + once: + Whether the listener should be removed after being called once. + is_instance_function: + Whether the callback is an instance method (i.e., it takes the gear instance as the first argument). + + Raises + ------ + TypeError + If the event type cannot be inferred from the callback signature. + """ + if event is MISSING: + event = self._parse_listener_signature(callback, is_instance_function) + self._listeners[event].add(cast("EventCallback[Event]", callback)) + + def remove_listener( + self, callback: EventCallback[E], event: type[E] | Undefined = MISSING, is_instance_function: bool = False + ) -> None: + """ + Removes an event listener from the gear. + + Parameters + ---------- + callback: + The callback function to be removed. + event: + The type of event the listener was registered for. If not provided, it will be inferred from the callback signature. + is_instance_function: + Whether the callback is an instance method (i.e., it takes the gear instance as the first argument). + + Raises + ------ + TypeError + If the event type cannot be inferred from the callback signature. + KeyError + If the listener is not found. + """ + if event is MISSING: + event = self._parse_listener_signature(callback) + self._listeners[event].remove(cast("EventCallback[Event]", callback)) + + if TYPE_CHECKING: + + @classmethod + def listen( + cls: type[_T], + event: type[E] | Undefined = MISSING, # pyright: ignore[reportUnusedParameter] + once: bool = False, + ) -> Callable[ + [Callable[[E], Awaitable[None]] | Callable[[Any, E], Awaitable[None]]], + EventCallback[E], + ]: + """ + A decorator that registers an event listener. + + Parameters + ---------- + event: + The type of event to listen for. If not provided, it will be inferred from the callback signature. + once: + Whether the listener should be removed after being called once. + + Returns + ------- + A decorator that registers the decorated function as an event listener. + + Raises + ------ + TypeError + If the event type cannot be inferred from the callback signature. + """ + ... + else: + # Instance function events (but not bound to an instance, this is why we have to manually pass self with partial above) + @hybridmethod + def listen( + cls: type[_T], # noqa: N805 # Ruff complains of our shenanigans here + event: type[E] | Undefined = MISSING, + once: bool = False, + ) -> Callable[[Callable[[Any, E], Awaitable[None]]], Callable[[Any, E], Awaitable[None]]]: + def decorator(func: Callable[[Any, E], Awaitable[None]]) -> Callable[[Any, E], Awaitable[None]]: + if isinstance(func, staticmethod): + func.__func__.__event__ = event + func.__func__.__once__ = once + func.__func__.__staticmethod__ = True + else: + func.__event__ = event + func.__once__ = once + return func + + return decorator + + # Bare events (everything else) + @listen.instancemethod + def listen( + self, event: type[E] | Undefined = MISSING, once: bool = False + ) -> Callable[[Callable[[E], Awaitable[None]]], EventCallback[E]]: + def decorator(func: Callable[[E], Awaitable[None]]) -> EventCallback[E]: + self.add_listener(func, event=event, is_instance_function=False, once=once) + return cast(EventCallback[E], func) + + return decorator diff --git a/discord/guild.py b/discord/guild.py index 9b390a3586..1dfc4491d4 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -50,6 +50,7 @@ from .automod import AutoModAction, AutoModRule, AutoModTriggerMetadata from .channel import * from .channel import _guild_channel_factory, _threaded_guild_channel_factory +from .channel.thread import Thread, ThreadMember from .colour import Colour from .emoji import GuildEmoji, PartialEmoji, _EmojiTag from .enums import ( @@ -92,7 +93,6 @@ from .soundboard import SoundboardSound from .stage_instance import StageInstance from .sticker import GuildSticker -from .threads import Thread, ThreadMember from .user import User from .utils.private import bytes_to_base64_data, get_as_snowflake from .welcome_screen import WelcomeScreen, WelcomeScreenChannel @@ -324,7 +324,7 @@ def _voice_state_for(self, user_id: int, /) -> VoiceState | None: return self._voice_states.get(user_id) async def _add_member(self, member: Member, /) -> None: - await cast(ConnectionState, self._state).cache.store_member(member) + await cast("ConnectionState", self._state).cache.store_member(member) async def _get_and_update_member(self, payload: MemberPayload, user_id: int, cache_flag: bool, /) -> Member: members = await cast(ConnectionState, self._state).cache.get_guild_members(self.id) @@ -528,7 +528,7 @@ async def _from_data(cls, guild: GuildPayload, state: ConnectionState) -> Self: events.append(ScheduledEvent(state=self._state, guild=self, creator=creator, data=event)) self._scheduled_events_from_list(events) - self._sync(guild) + await self._sync(guild) self._large: bool | None = None if self._member_count is None else self._member_count >= 250 self.owner_id: int | None = get_as_snowflake(guild, "owner_id") @@ -539,7 +539,7 @@ async def _from_data(cls, guild: GuildPayload, state: ConnectionState) -> Self: for sound in guild.get("soundboard_sounds", []): sound = SoundboardSound(state=state, http=state.http, data=sound) - self._add_sound(sound) + await self._add_sound(sound) incidents_payload = guild.get("incidents_data") self.incidents_data: IncidentsData | None = ( @@ -547,9 +547,9 @@ async def _from_data(cls, guild: GuildPayload, state: ConnectionState) -> Self: ) return self - def _add_sound(self, sound: SoundboardSound) -> None: + async def _add_sound(self, sound: SoundboardSound) -> None: self._sounds[sound.id] = sound - self._state._add_sound(sound) + await self._state._add_sound(sound) def _remove_sound(self, sound_id: int) -> None: self._sounds.pop(sound_id, None) @@ -669,7 +669,7 @@ async def create_sound( ) # TODO: refactor/remove? - def _sync(self, data: GuildPayload) -> None: + async def _sync(self, data: GuildPayload) -> None: try: self._large = data["large"] except KeyError: @@ -687,12 +687,12 @@ def _sync(self, data: GuildPayload) -> None: for c in channels: factory, _ch_type = _guild_channel_factory(c["type"]) if factory: - self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore + self._add_channel(await factory._from_data(guild=self, data=c, state=self._state)) # type: ignore if "threads" in data: threads = data["threads"] for thread in threads: - self._add_thread(Thread(guild=self, state=self._state, data=thread)) + self._add_thread(await Thread._from_data(guild=self, state=self._state, data=thread)) @property def channels(self) -> list[GuildChannel]: @@ -990,7 +990,7 @@ async def get_member(self, user_id: int, /) -> Member | None: Optional[:class:`Member`] The member or ``None`` if not found. """ - return await cast(ConnectionState, self._state).cache.get_member(self.id, user_id) + return await cast("ConnectionState", self._state).cache.get_member(self.id, user_id) @property def premium_subscribers(self) -> list[Member]: @@ -2177,7 +2177,7 @@ async def fetch_channels(self) -> Sequence[GuildChannel]: Returns ------- - Sequence[:class:`abc.GuildChannel`] + Sequence[:class:`discord.channel.base.GuildChannel`] All channels in the guild. Raises diff --git a/discord/interactions.py b/discord/interactions.py index 3ab6cbaff5..a7839d500d 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -83,7 +83,7 @@ from .embeds import Embed from .mentions import AllowedMentions from .poll import Poll - from .threads import Thread + from .channel.thread import Thread from .types.interactions import Interaction as InteractionPayload from .types.interactions import InteractionCallback as InteractionCallbackPayload from .types.interactions import InteractionCallbackResponse, InteractionData @@ -265,12 +265,12 @@ async def load_data(self): self._guild: Guild | None = None self._guild_data = data.get("guild") - if self.guild is None and self._guild_data: + if self._guild is None and self._guild_data: self._guild = await Guild._from_data(data=self._guild_data, state=self._state) # TODO: there's a potential data loss here if self.guild_id: - guild = self.guild or await self._state._get_guild(self.guild_id) or Object(id=self.guild_id) + guild = self._guild or await self._state._get_guild(self.guild_id) or Object(id=self.guild_id) try: member = data["member"] # type: ignore except KeyError: @@ -294,7 +294,7 @@ async def load_data(self): if data_ch_type is not None: factory, ch_type = _threaded_channel_factory(data_ch_type) if ch_type in (ChannelType.group, ChannelType.private): - self.channel = factory(me=self.user, data=channel, state=self._state) + self.channel = await factory._from_data(data=channel, state=self._state) if self.channel is None and self.guild: self.channel = self.guild._resolve_channel(self.channel_id) @@ -898,7 +898,9 @@ async def _process_callback_response(self, callback_response: InteractionCallbac "Channel for message could not be resolved. Please open a issue on GitHub if you encounter this error." ) state = _InteractionMessageState(self._parent, self._parent._state) - message = InteractionMessage(state=state, channel=channel, data=callback_response["resource"]["message"]) # type: ignore + message = await InteractionMessage._from_data( + state=state, channel=channel, data=callback_response["resource"]["message"] + ) # type: ignore self._parent._original_response = message self._parent.callback = InteractionCallback(callback_response["interaction"]) diff --git a/discord/invite.py b/discord/invite.py index 83f53fa16e..1611ef124f 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -42,8 +42,8 @@ ) if TYPE_CHECKING: - from .abc import GuildChannel from .app.state import ConnectionState + from .channel.base import GuildChannel from .guild import Guild from .scheduled_events import ScheduledEvent from .types.channel import PartialChannel as InviteChannelPayload diff --git a/discord/iterators.py b/discord/iterators.py index d61cbd2695..0934dae44c 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -64,7 +64,7 @@ from .message import Message, MessagePin from .monetization import Entitlement, Subscription from .scheduled_events import ScheduledEvent - from .threads import Thread + from .channel.thread import Thread from .types.audit_log import AuditLog as AuditLogPayload from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload @@ -848,7 +848,7 @@ async def fill_queue(self) -> None: self.before = self.update_before(threads[-1]) def create_thread(self, data: ThreadPayload) -> Thread: - from .threads import Thread # noqa: PLC0415 + from .channel.thread import Thread # noqa: PLC0415 return Thread(guild=self.guild, state=self.guild._state, data=data) diff --git a/discord/member.py b/discord/member.py index 7e411884b7..04f69246af 100644 --- a/discord/member.py +++ b/discord/member.py @@ -140,7 +140,7 @@ def __init__( self.session_id: str = data.get("session_id") self._update(data, channel) - async def _update( + def _update( self, data: VoiceStatePayload | GuildVoiceStatePayload, channel: VocalGuildChannel | None, @@ -422,7 +422,7 @@ async def _get_channel(self): ch = await self.create_dm() return ch - async def _update(self, data: MemberPayload) -> None: + def _update(self, data: MemberPayload) -> None: # the nickname change is optional, # if it isn't in the payload then it didn't change try: diff --git a/discord/message.py b/discord/message.py index e6aa2e15ad..e1318143cf 100644 --- a/discord/message.py +++ b/discord/message.py @@ -61,19 +61,19 @@ from .poll import Poll from .reaction import Reaction from .sticker import StickerItem -from .threads import Thread +from .channel.thread import Thread from .utils import MISSING, escape_mentions from .utils.private import cached_slot_property, delay_task, get_as_snowflake, parse_time, warn_deprecated if TYPE_CHECKING: from .abc import ( - GuildChannel, MessageableChannel, PartialMessageableChannel, Snowflake, ) from .app.state import ConnectionState from .channel import TextChannel + from .channel.base import GuildChannel from .components import Component from .interactions import MessageInteraction from .mentions import AllowedMentions @@ -1085,14 +1085,14 @@ async def _from_data( found = await self.guild.get_member(self.author.id) if found is not None: self.author = found - - try: - # Update member reference - self.author._update_from_message(member) # type: ignore # noqa: F821 # TODO: member is unbound - except AttributeError: - # It's a user here - # TODO: consider adding to cache here - self.author = Member._from_message(message=self, data=data["member"]) + if data.get("member"): + try: + # Update member reference + self.author._update_from_message(member) # type: ignore # noqa: F821 # TODO: member is unbound + except AttributeError: + # It's a user here + # TODO: consider adding to cache here + self.author = Member._from_message(message=self, data=data["member"]) self.mentions = r = [] if not isinstance(self.guild, Guild): diff --git a/discord/onboarding.py b/discord/onboarding.py index 83c52466c5..d79fa6b07e 100644 --- a/discord/onboarding.py +++ b/discord/onboarding.py @@ -243,7 +243,7 @@ def __init__(self, data: OnboardingPayload, guild: Guild): def __repr__(self): return f"" - async def _update(self, data: OnboardingPayload): + def _update(self, data: OnboardingPayload): self.guild_id: Snowflake = data["guild_id"] self.prompts: list[OnboardingPrompt] = [ OnboardingPrompt._from_dict(prompt, self.guild) for prompt in data.get("prompts", []) diff --git a/discord/raw_models.py b/discord/raw_models.py index 9a91c3ce35..284a40ff20 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -47,7 +47,7 @@ from .message import Message from .partial_emoji import PartialEmoji from .soundboard import PartialSoundboardSound, SoundboardSound - from .threads import Thread + from .channel.thread import Thread from .types.channel import VoiceChannelEffectSendEvent as VoiceChannelEffectSend from .types.raw_models import ( AuditLogEntryEvent, diff --git a/discord/role.py b/discord/role.py index 9a1d69d3eb..7944a96d75 100644 --- a/discord/role.py +++ b/discord/role.py @@ -376,7 +376,7 @@ def __ge__(self: R, other: R) -> bool: return NotImplemented return not r - async def _update(self, data: RolePayload): + def _update(self, data: RolePayload): self.name: str = data["name"] self._permissions: int = int(data.get("permissions", 0)) self.position: int = data.get("position", 0) diff --git a/discord/stage_instance.py b/discord/stage_instance.py index e5ea31f0e0..ce5b7c4da2 100644 --- a/discord/stage_instance.py +++ b/discord/stage_instance.py @@ -95,7 +95,7 @@ def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstanceP self.guild = guild self._update(data) - async def _update(self, data: StageInstancePayload): + def _update(self, data: StageInstancePayload): self.id: int = int(data["id"]) self.channel_id: int = int(data["channel_id"]) self.topic: str = data["topic"] diff --git a/discord/types/channel.py b/discord/types/channel.py index 72096e7a73..d618a9331a 100644 --- a/discord/types/channel.py +++ b/discord/types/channel.py @@ -51,10 +51,14 @@ class PermissionOverwrite(TypedDict): class _BaseChannel(TypedDict): id: Snowflake + type: int + + +class _BaseNamedChannel(_BaseChannel): name: str -class _BaseGuildChannel(_BaseChannel): +class _BaseGuildChannel(_BaseNamedChannel): guild_id: Snowflake position: int permission_overwrites: list[PermissionOverwrite] @@ -62,7 +66,7 @@ class _BaseGuildChannel(_BaseChannel): parent_id: Snowflake | None -class PartialChannel(_BaseChannel): +class PartialChannel(_BaseNamedChannel): type: ChannelType @@ -128,7 +132,7 @@ class StageChannel(_BaseGuildChannel): user_limit: int -class ThreadChannel(_BaseChannel): +class ThreadChannel(_BaseNamedChannel): member: NotRequired[ThreadMember] owner_id: NotRequired[Snowflake] rate_limit_per_user: NotRequired[int] @@ -149,17 +153,17 @@ class ThreadChannel(_BaseChannel): GuildChannel = TextChannel | NewsChannel | VoiceChannel | CategoryChannel | StageChannel | ThreadChannel | ForumChannel -class DMChannel(TypedDict): - id: Snowflake +class DMChannel(_BaseChannel): type: Literal[1] last_message_id: Snowflake | None recipients: list[User] -class GroupDMChannel(_BaseChannel): +class GroupDMChannel(_BaseNamedChannel, DMChannel): type: Literal[3] icon: str | None owner_id: Snowflake + name: str Channel = GuildChannel | DMChannel | GroupDMChannel diff --git a/discord/ui/select.py b/discord/ui/select.py index 6ee54a0514..5142074155 100644 --- a/discord/ui/select.py +++ b/discord/ui/select.py @@ -40,7 +40,7 @@ from ..member import Member from ..partial_emoji import PartialEmoji from ..role import Role -from ..threads import Thread +from ..channel.thread import Thread from ..user import User from ..utils import MISSING from .item import Item, ItemCallbackType @@ -58,7 +58,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from ..abc import GuildChannel + from ..channel.base import GuildChannel from ..types.components import SelectMenu as SelectMenuPayload from ..types.interactions import ComponentInteractionData from .view import View diff --git a/discord/user.py b/discord/user.py index e945154243..82d9353d01 100644 --- a/discord/user.py +++ b/discord/user.py @@ -132,7 +132,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return self.id >> 22 - async def _update(self, data: UserPayload) -> None: + def _update(self, data: UserPayload) -> None: self.name = data["username"] self.id = int(data["id"]) self.discriminator = data["discriminator"] @@ -424,7 +424,7 @@ def __repr__(self) -> str: f" bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>" ) - async def _update(self, data: UserPayload) -> None: + def _update(self, data: UserPayload) -> None: super()._update(data) # There's actually an Optional[str] phone field as well, but I won't use it self.verified = data.get("verified", False) diff --git a/discord/utils/annotations.py b/discord/utils/annotations.py new file mode 100644 index 0000000000..049e01ed42 --- /dev/null +++ b/discord/utils/annotations.py @@ -0,0 +1,207 @@ +import ast +import functools +import inspect +import textwrap +from typing import Any, overload + +from ..errors import AnnotationMismatch + + +def _param_spans(obj: Any) -> dict[str, tuple[int, int, int, int, str]]: + """ + Get the source code spans for each parameter's annotation in a function. + Returns a mapping of parameter name to a tuple of + (start_line, start_col_1b, end_line, end_col_1b, line_text). + 1b = 1-based column offset. + + Parameters + ---------- + obj: + The function or method to analyze. + + Returns + ------- + dict[str, tuple[int, int, int, int, str]] + Mapping of parameter names to their annotation spans. + """ + src, start_line = inspect.getsourcelines(obj) # original (indented) lines + filename = inspect.getsourcefile(obj) or "" + + # Compute common indent that dedent will remove + non_empty = [l for l in src if l.strip()] + common_indent = min((len(l) - len(l.lstrip(" "))) for l in non_empty) if non_empty else 0 + + # Parse a DEDENTED copy to get stable AST coords + dedented = textwrap.dedent("".join(src)) + mod = ast.parse(dedented, filename=filename, mode="exec", type_comments=True) + + fn = next((n for n in mod.body if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))), None) + if fn is None: + return {} + + def _collect_args(a: ast.arguments) -> list[tuple[ast.arg, ast.expr | None]]: + out: list[tuple[ast.arg, ast.expr | None]] = [] + for ar in getattr(a, "posonlyargs", []): + out.append((ar, ar.annotation)) + for ar in a.args: + out.append((ar, ar.annotation)) + if a.vararg: + out.append((a.vararg, a.vararg.annotation)) + for ar in a.kwonlyargs: + out.append((ar, ar.annotation)) + if a.kwarg: + out.append((a.kwarg, a.kwarg.annotation)) + return out + + args = _collect_args(fn.args) + + def _line_text_file(lineno_file: int) -> str: + idx = lineno_file - start_line + if 0 <= idx < len(src): + return src[idx].rstrip("\n") + return "" + + spans: dict[str, tuple[int, int, int, int, str]] = {} + + for ar, ann in args: + name = ar.arg + + # AST positions are snippet-relative: lineno 1-based, col_offset 0-based + ln_snip = getattr(ar, "lineno", 1) + col0_snip = getattr(ar, "col_offset", 0) + + # Prefer annotation end if present; otherwise end at end of the name + if ann is not None and hasattr(ann, "end_lineno") and hasattr(ann, "end_col_offset"): + end_ln_snip = ann.end_lineno + end_col0_snip = ann.end_col_offset + else: + end_ln_snip = ln_snip + end_col0_snip = col0_snip + len(name) + + # Convert SNIPPET positions -> FILE positions + ln_file = start_line + (ln_snip - 1) + end_ln_file = start_line + (end_ln_snip - 1) + + # Add back the common indent that dedent removed; convert to 1-based + col_1b_file = col0_snip + 1 + common_indent + end_col_1b_file = end_col0_snip + 1 + common_indent + + line_text = _line_text_file(ln_file) + # Guard: keep columns within the line + line_len_1b = len(line_text) + 1 + col_1b_file = max(1, min(col_1b_file, line_len_1b)) + end_col_1b_file = max(col_1b_file, min(end_col_1b_file, line_len_1b)) + + spans[name] = (ln_file, col_1b_file, end_ln_file, end_col_1b_file, line_text) + + return spans + + +def _unwrap_partial(func: Any) -> Any: + while isinstance(func, functools.partial): + func = func.func + return func + + +@overload +def get_annotations( + obj: Any, + *, + globals: dict[str, Any] | None = None, + locals: dict[str, Any] | None = None, + eval_str: bool = False, + expected_types: None = None, + custom_error: None = None, +) -> dict[str, Any]: ... + + +@overload +def get_annotations( + obj: Any, + *, + globals: dict[str, Any] | None = None, + locals: dict[str, Any] | None = None, + eval_str: bool = False, + expected_types: dict[int, type], + custom_error: str | None = None, +) -> dict[str, Any]: ... + + +def get_annotations( + obj: Any, + *, + globals: dict[str, Any] | None = None, + locals: dict[str, Any] | None = None, + eval_str: bool = False, + expected_types: dict[int, type] | None = None, + custom_error: str | None = None, +) -> dict[str, Any]: + """ + Get the type annotations of a function or method, with optional type checking. + + This function unwraps `functools.partial` objects to access the original function. + + This function is a modified version of `inspect.get_annotations` that adds the ability to check parameter types. + + .. note:: + This function is not intended to be used by end-users. + + Parameters + ---------- + obj: + The function or method to inspect. + globals: + The global namespace to use for evaluating string annotations. + locals: + The local namespace to use for evaluating string annotations. + eval_str: + Whether to evaluate string annotations. + expected_types: + A mapping of parameter index to expected type for type checking. + custom_error: + A custom error message format for type mismatches. Supports the following format fields: + - parameter: The name of the parameter with the mismatch. + - expected: The expected type. + - got: The actual type found. + + Returns + ------- + dict[str, Any] + A mapping of parameter names to their type annotations. + """ + unwrapped_obj = _unwrap_partial(obj) + r = inspect.get_annotations(unwrapped_obj, globals=globals, locals=locals, eval_str=eval_str) + + if expected_types is not None: + for i, (k, v) in enumerate(r.items()): + if i in expected_types and not isinstance(v, expected_types[i]): + error = AnnotationMismatch( + ( + custom_error + or 'Type annotation mismatch for parameter "{parameter}": expected {expected}, got {got}' + ).format( + parameter=k, + expected=repr(expected_types[i]), + got=repr(r[k]), + ) + ) + spans = _param_spans(unwrapped_obj) + + if k in spans: + ln, col_1b, end_ln, end_col_1b, line_text = spans[k] + else: + ln = unwrapped_obj.__code__.co_firstlineno + line_text = inspect.getsource(unwrapped_obj).splitlines()[0] + col_1b, end_ln, end_col_1b = 1, ln, len(line_text) + 1 + error.filename = unwrapped_obj.__code__.co_filename + error.lineno = ln + error.offset = col_1b + error.end_lineno = end_ln + error.end_offset = end_col_1b + error.text = line_text + raise error + + return r + + +__all__ = ("get_annotations", "AnnotationMismatch") diff --git a/discord/utils/hybridmethod.py b/discord/utils/hybridmethod.py new file mode 100644 index 0000000000..8e774aeb67 --- /dev/null +++ b/discord/utils/hybridmethod.py @@ -0,0 +1,49 @@ +# Source - https://stackoverflow.com/questions/28237955/same-name-for-classmethod-and-instancemethod +# Posted by Martijn Pieters +# Retrieved 11/5/2025, License - CC-BY-SA 4.0 + +from typing import Callable, Generic, Protocol, TypeVar, overload + +from typing_extensions import Concatenate, ParamSpec, Self, override + +_T = TypeVar("_T") +_R1_co = TypeVar("_R1_co", covariant=True) +_R2_co = TypeVar("_R2_co", covariant=True) +_P = ParamSpec("_P") + + +class hybridmethod(Generic[_T, _P, _R1_co, _R2_co]): + fclass: Callable[Concatenate[type[_T], _P], _R1_co] + finstance: Callable[Concatenate[_T, _P], _R2_co] | None + __doc__: str | None + __isabstractmethod__: bool + + def __init__( + self, + fclass: Callable[Concatenate[type[_T], _P], _R1_co], + finstance: Callable[Concatenate[_T, _P], _R2_co] | None = None, + doc: str | None = None, + ): + self.fclass = fclass + self.finstance = finstance + self.__doc__ = doc or fclass.__doc__ + # support use on abstract base classes + self.__isabstractmethod__ = bool(getattr(fclass, "__isabstractmethod__", False)) + + def classmethod(self, fclass: Callable[Concatenate[type[_T], _P], _R1_co]) -> Self: + return type(self)(fclass, self.finstance, None) + + def instancemethod(self, finstance: Callable[Concatenate[_T, _P], _R2_co]) -> Self: + return type(self)(self.fclass, finstance, self.__doc__) + + @overload + def __get__(self, instance: None, cls: type[_T]) -> Callable[_P, _R1_co]: ... + + @overload + def __get__(self, instance: _T, cls: type[_T] | None = ...) -> Callable[_P, _R1_co] | Callable[_P, _R2_co]: ... + + def __get__(self, instance: _T | None, cls: type[_T] | None = None) -> Callable[_P, _R1_co] | Callable[_P, _R2_co]: + if instance is None or self.finstance is None: + # either bound to the class, or no instance method available + return self.fclass.__get__(cls, None) + return self.finstance.__get__(instance, cls) diff --git a/discord/utils/private.py b/discord/utils/private.py index 951debc6d0..7e28aafeb0 100644 --- a/discord/utils/private.py +++ b/discord/utils/private.py @@ -33,6 +33,7 @@ ) from ..errors import HTTPException, InvalidArgument +from .hybridmethod import hybridmethod if TYPE_CHECKING: from ..invite import Invite @@ -554,3 +555,31 @@ def to_json(obj: Any) -> str: # type: ignore[reportUnusedFunction] return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) from_json = json.loads +__all__ = ( + "deprecated", + "flatten_literal_params", + "normalise_optional_params", + "evaluate_annotation", + "resolve_annotation", + "delay_task", + "async_all", + "maybe_awaitable", + "sane_wait_for", + "SnowflakeList", + "copy_doc", + "SequenceProxy", + "CachedSlotProperty", + "get_slots", + "cached_slot_property", + "to_json", + "from_json", + "get_as_snowflake", + "get_mime_type_for_file", + "bytes_to_base64_data", + "parse_ratelimit_header", + "string_width", + "resolve_template", + "warn_deprecated", + "hybridmethod", + "resolve_invite", +) diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index c5ad2d019c..2f0a320cac 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -53,7 +53,7 @@ from ..message import Attachment, Message from ..mixins import Hashable from ..object import Object -from ..threads import Thread +from ..channel.thread import Thread from ..user import BaseUser, User from ..utils.private import bytes_to_base64_data, get_as_snowflake, parse_ratelimit_header, to_json @@ -1028,7 +1028,7 @@ def __init__( self._state: ConnectionState | _WebhookState = state or _WebhookState(self, parent=state) self._update(data) - async def _update(self, data: WebhookPayload | FollowerWebhookPayload): + def _update(self, data: WebhookPayload | FollowerWebhookPayload): self.id = int(data["id"]) self.type = try_enum(WebhookType, int(data["type"])) self.channel_id = get_as_snowflake(data, "channel_id") diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index acc15eadde..a0b66a6525 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -52,7 +52,7 @@ from ..http import Route from ..message import Message from ..object import Object -from ..threads import Thread +from ..channel.thread import Thread from ..utils.private import bytes_to_base64_data, parse_ratelimit_header, to_json from .async_ import BaseWebhook, _WebhookState, handle_message_parameters diff --git a/discord/welcome_screen.py b/discord/welcome_screen.py index a41cbbe0a5..57286c1c70 100644 --- a/discord/welcome_screen.py +++ b/discord/welcome_screen.py @@ -128,7 +128,7 @@ def __init__(self, data: WelcomeScreenPayload, guild: Guild): def __repr__(self): return f" None: self.id: int = int(data["id"]) self.channels: list[WidgetChannel] = [] - for channel in data.get("channels", []): + for channel in data.get("channel", []): _id = int(channel["id"]) self.channels.append(WidgetChannel(id=_id, name=channel["name"], position=channel["position"])) diff --git a/docs/api/application_commands.rst b/docs/api/application_commands.rst index 2fe39f0760..e31e890c9b 100644 --- a/docs/api/application_commands.rst +++ b/docs/api/application_commands.rst @@ -64,11 +64,6 @@ Objects Options ------- -Shortcut Decorators -~~~~~~~~~~~~~~~~~~~ -.. autofunction:: discord.commands.option - :decorator: - Objects ~~~~~~~ diff --git a/docs/api/clients.rst b/docs/api/clients.rst index f5b03d8e6c..bfde02aac4 100644 --- a/docs/api/clients.rst +++ b/docs/api/clients.rst @@ -10,14 +10,11 @@ Bots .. autoclass:: Bot :members: :inherited-members: - :exclude-members: command, event, message_command, slash_command, user_command, listen + :exclude-members: command, message_command, slash_command, user_command, listen .. automethod:: Bot.command(**kwargs) :decorator: - .. automethod:: Bot.event() - :decorator: - .. automethod:: Bot.message_command(**kwargs) :decorator: @@ -27,7 +24,7 @@ Bots .. automethod:: Bot.user_command(**kwargs) :decorator: - .. automethod:: Bot.listen(name=None, once=False) + .. automethod:: Bot.listen(event, once=False) :decorator: .. attributetable:: AutoShardedBot @@ -41,15 +38,12 @@ Clients .. attributetable:: Client .. autoclass:: Client :members: - :exclude-members: fetch_guilds, event, listen - - .. automethod:: Client.event() - :decorator: + :exclude-members: fetch_guilds, listen .. automethod:: Client.fetch_guilds :async-for: - .. automethod:: Client.listen(name=None, once=False) + .. automethod:: Client.listen(event, once=False) :decorator: .. attributetable:: AutoShardedClient diff --git a/docs/api/events.rst b/docs/api/events.rst index c948fe972b..2ccef212d9 100644 --- a/docs/api/events.rst +++ b/docs/api/events.rst @@ -5,1493 +5,481 @@ Event Reference =============== -This section outlines the different types of events listened by :class:`Client`. +This section outlines the different types of events in Pycord. Events are class-based objects that inherit from +:class:`~discord.app.event_emitter.Event` and are dispatched by the Discord gateway when certain actions occur. -There are 3 ways to register an event, the first way is through the use of -:meth:`Client.event`. The second way is through subclassing :class:`Client` and -overriding the specific events. The third way is through the use of :meth:`Client.listen`, -which can be used to assign multiple event handlers instead of only one like in :meth:`Client.event`. -For example: +.. seealso:: -.. code-block:: python - :emphasize-lines: 17, 22 + For information about the Gears system and modular event handling, see :ref:`discord_api_gears`. - import discord +Listening to Events +------------------- - class MyClient(discord.Client): - async def on_message(self, message): - if message.author == self.user: - return +There are two main ways to listen to events in Pycord: - if message.content.startswith('$hello'): - await message.channel.send('Hello World!') +1. **Using** :meth:`Client.listen` **decorator** - This allows you to register typed event listeners directly on the client. +2. **Using Gears** - A modular event handling system that allows you to organize event listeners into reusable components. +Using the listen() Decorator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - intents = discord.Intents.default() - intents.message_content = True # Needed to see message content - client = MyClient(intents=intents) +The modern way to register event listeners is by using the :meth:`~Client.listen` decorator with typed event classes: - # Overrides the 'on_message' method defined in MyClient - @client.event - async def on_message(message: discord.Message): - print(f"Received {message.content}") +.. code-block:: python3 + + import discord + from discord.events import MessageCreate, Ready + + client = discord.Client(intents=discord.Intents.default()) - # Assigns an ADDITIONAL handler @client.listen() - async def on_message(message: discord.Message): - print(f"Received {message.content}") + async def on_message(event: MessageCreate) -> None: + if event.author == client.user: + return + if event.content.startswith('$hello'): + await event.channel.send('Hello World!') - # Runs only for the 1st event dispatch. Can be useful for listening to 'on_ready' @client.listen(once=True) - async def on_ready(): + async def on_ready(event: Ready) -> None: print("Client is ready!") + client.run("TOKEN") -If an event handler raises an exception, :func:`on_error` will be called -to handle it, which defaults to print a traceback and ignoring the exception. +Note that: -.. warning:: +- Event listeners use type annotations to specify which event they handle +- Event objects may inherit from domain models (e.g., ``MessageCreate`` inherits from ``Message``) +- The ``once=True`` parameter creates a one-time listener that is automatically removed after being called once +- All event listeners must be coroutines (``async def`` functions) - All the events must be a |coroutine_link|_. If they aren't, then you might get unexpected - errors. In order to turn a function into a coroutine they must be ``async def`` - functions. +Using Gears +~~~~~~~~~~~ -Application Commands --------------------- -.. function:: on_application_command(context) +For more organized code, especially in larger bots, you can use the Gears system: - Called when an application command is received. +.. code-block:: python3 - .. versionadded:: 2.0 + from discord.gears import Gear + from discord.events import Ready, MessageCreate - :param context: The ApplicationContext associated to the command being received. - :type context: :class:`ApplicationContext` + class MyGear(Gear): + @Gear.listen() + async def on_ready(self, event: Ready) -> None: + print("Bot is ready!") -.. function:: on_application_command_completion(context) + @Gear.listen() + async def on_message(self, event: MessageCreate) -> None: + print(f"Message: {event.content}") - Called when an application command is completed, after any checks have finished. + bot = discord.Bot() + bot.attach_gear(MyGear()) + bot.run("TOKEN") - .. versionadded:: 2.0 +See :ref:`discord_api_gears` for more information on using Gears. - :param context: The ApplicationContext associated to the command that was completed. - :type context: :class:`ApplicationContext` +.. warning:: -.. function:: on_application_command_error(context, exception) + All event listeners must be |coroutine_link|_. If they aren't, then you might get unexpected + errors. In order to turn a function into a coroutine they must be ``async def`` functions. - Called when an application command has an error. +Event Classes +------------- - .. versionadded:: 2.0 +All events inherit from the base :class:`~discord.app.event_emitter.Event` class. Events are typed objects that +contain data related to the specific Discord gateway event that occurred. - :param context: The ApplicationContext associated to the command that has an error. - :type context: :class:`ApplicationContext` +Some event classes inherit from domain models, meaning they have all the attributes and methods of that model. +For example: - :param exception: The DiscordException associated to the error. - :type exception: :class:`DiscordException` +- :class:`~discord.events.MessageCreate` inherits from :class:`Message` +- :class:`~discord.events.GuildMemberJoin` inherits from :class:`Member` +- :class:`~discord.events.GuildJoin` inherits from :class:`Guild` -.. function:: on_unknown_application_command(interaction) +Events that don't inherit from a domain model will have specific attributes for accessing event data. - Called when an application command was not found in the bot's internal cache. +Many events also include a ``raw`` attribute that contains the raw event payload data from Discord, which can be +useful for accessing data that may not be in the cache. - .. versionadded:: 2.0 +Available Events +---------------- - :param interaction: The interaction associated to the unknown command. - :type interaction: :class:`Interaction` +Below is a comprehensive list of all events available in Pycord, organized by category. Audit Logs ----------- - -.. function:: on_audit_log_entry(entry) - - Called when an audit log entry is created. - - The bot must have :attr:`~Permissions.view_audit_log` to receive this, and - :attr:`Intents.moderation` must be enabled. - - .. versionadded:: 2.5 +~~~~~~~~~~ - :param entry: The audit log entry that was created. - :type entry: :class:`AuditLogEntry` - -.. function:: on_raw_audit_log_entry(payload) - - Called when an audit log entry is created. Unlike - :func:`on_audit_log_entry`, this is called regardless of the state of the internal - user cache. - - The bot must have :attr:`~Permissions.view_audit_log` to receive this, and - :attr:`Intents.moderation` must be enabled. - - .. versionadded:: 2.5 - - :param payload: The raw event payload data. - :type payload: :class:`RawAuditLogEntryEvent` +.. autoclass:: discord.events.GuildAuditLogEntryCreate() + :members: + :inherited-members: AutoMod -------- -.. function:: on_auto_moderation_rule_create(rule) - - Called when an auto moderation rule is created. - - The bot must have :attr:`~Permissions.manage_guild` to receive this, and - :attr:`Intents.auto_moderation_configuration` must be enabled. - - :param rule: The newly created rule. - :type rule: :class:`AutoModRule` - -.. function:: on_auto_moderation_rule_update(rule) - - Called when an auto moderation rule is updated. - - The bot must have :attr:`~Permissions.manage_guild` to receive this, and - :attr:`Intents.auto_moderation_configuration` must be enabled. - - :param rule: The updated rule. - :type rule: :class:`AutoModRule` - -.. function:: on_auto_moderation_rule_delete(rule) +~~~~~~~ - Called when an auto moderation rule is deleted. +.. autoclass:: discord.events.AutoModRuleCreate() + :members: + :inherited-members: - The bot must have :attr:`~Permissions.manage_guild` to receive this, and - :attr:`Intents.auto_moderation_configuration` must be enabled. +.. autoclass:: discord.events.AutoModRuleUpdate() + :members: + :inherited-members: - :param rule: The deleted rule. - :type rule: :class:`AutoModRule` +.. autoclass:: discord.events.AutoModRuleDelete() + :members: + :inherited-members: -.. function:: on_auto_moderation_action_execution(payload) - - Called when an auto moderation action is executed. - - The bot must have :attr:`~Permissions.manage_guild` to receive this, and - :attr:`Intents.auto_moderation_execution` must be enabled. - - :param payload: The event's data. - :type payload: :class:`AutoModActionExecutionEvent` - -Bans ----- -.. function:: on_member_ban(guild, user) - - Called when user gets banned from a :class:`Guild`. - - This requires :attr:`Intents.moderation` to be enabled. - - :param guild: The guild the user got banned from. - :type guild: :class:`Guild` - :param user: The user that got banned. - Can be either :class:`User` or :class:`Member` depending if - the user was in the guild or not at the time of removal. - :type user: Union[:class:`User`, :class:`Member`] - -.. function:: on_member_unban(guild, user) - - Called when a :class:`User` gets unbanned from a :class:`Guild`. - - This requires :attr:`Intents.moderation` to be enabled. - - :param guild: The guild the user got unbanned from. - :type guild: :class:`Guild` - :param user: The user that got unbanned. - :type user: :class:`User` +.. autoclass:: discord.events.AutoModActionExecution() + :members: + :inherited-members: Channels --------- -.. function:: on_private_channel_update(before, after) - - Called whenever a private group DM is updated. e.g. changed name or topic. - - This requires :attr:`Intents.messages` to be enabled. - - :param before: The updated group channel's old info. - :type before: :class:`GroupChannel` - :param after: The updated group channel's new info. - :type after: :class:`GroupChannel` - -.. function:: on_private_channel_pins_update(channel, last_pin) - - Called whenever a message is pinned or unpinned from a private channel. - - :param channel: The private channel that had its pins updated. - :type channel: :class:`abc.PrivateChannel` - :param last_pin: The latest message that was pinned as an aware datetime in UTC. Could be ``None``. - :type last_pin: Optional[:class:`datetime.datetime`] - -.. function:: on_guild_channel_update(before, after) - - Called whenever a guild channel is updated. e.g. changed name, topic, permissions. - - This requires :attr:`Intents.guilds` to be enabled. - - :param before: The updated guild channel's old info. - :type before: :class:`abc.GuildChannel` - :param after: The updated guild channel's new info. - :type after: :class:`abc.GuildChannel` - -.. function:: on_guild_channel_pins_update(channel, last_pin) - - Called whenever a message is pinned or unpinned from a guild channel. - - This requires :attr:`Intents.guilds` to be enabled. - - :param channel: The guild channel that had its pins updated. - :type channel: Union[:class:`abc.GuildChannel`, :class:`Thread`] - :param last_pin: The latest message that was pinned as an aware datetime in UTC. Could be ``None``. - :type last_pin: Optional[:class:`datetime.datetime`] - -.. function:: on_guild_channel_delete(channel) - on_guild_channel_create(channel) - - Called whenever a guild channel is deleted or created. +~~~~~~~~ - Note that you can get the guild from :attr:`~abc.GuildChannel.guild`. +.. autoclass:: discord.events.ChannelCreate() + :members: + :inherited-members: - This requires :attr:`Intents.guilds` to be enabled. +.. autoclass:: discord.events.ChannelDelete() + :members: + :inherited-members: - :param channel: The guild channel that got created or deleted. - :type channel: :class:`abc.GuildChannel` +.. autoclass:: discord.events.ChannelUpdate() + :members: + :inherited-members: -Connection ----------- -.. function:: on_error(event, *args, **kwargs) +.. autoclass:: discord.events.GuildChannelUpdate() + :members: + :inherited-members: - Usually when an event raises an uncaught exception, a traceback is - printed to stderr and the exception is ignored. If you want to - change this behaviour and handle the exception for whatever reason - yourself, this event can be overridden. Which, when done, will - suppress the default action of printing the traceback. +.. autoclass:: discord.events.PrivateChannelUpdate() + :members: + :inherited-members: - The information of the exception raised and the exception itself can - be retrieved with a standard call to :func:`sys.exc_info`. +.. autoclass:: discord.events.ChannelPinsUpdate() + :members: + :inherited-members: - If you want exception to propagate out of the :class:`Client` class - you can define an ``on_error`` handler consisting of a single empty - :ref:`raise statement `. Exceptions raised by ``on_error`` will not be - handled in any way by :class:`Client`. +Connection & Gateway +~~~~~~~~~~~~~~~~~~~~ - .. note:: +.. autoclass:: discord.events.Ready() + :members: + :inherited-members: - ``on_error`` will only be dispatched to :meth:`Client.event`. +.. autoclass:: discord.events.Resumed() + :members: + :inherited-members: - It will not be received by :meth:`Client.wait_for`, or, if used, - :ref:`ext_commands_api_bot` listeners such as - :meth:`~ext.commands.Bot.listen` or :meth:`~ext.commands.Cog.listener`. +Entitlements & Monetization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - :param event: The name of the event that raised the exception. - :type event: :class:`str` +.. autoclass:: discord.events.EntitlementCreate() + :members: + :inherited-members: - :param args: The positional arguments for the event that raised the - exception. - :param kwargs: The keyword arguments for the event that raised the - exception. +.. autoclass:: discord.events.EntitlementUpdate() + :members: + :inherited-members: -.. function:: on_connect() +.. autoclass:: discord.events.EntitlementDelete() + :members: + :inherited-members: - Called when the client has successfully connected to Discord. This is not - the same as the client being fully prepared, see :func:`on_ready` for that. +.. autoclass:: discord.events.SubscriptionCreate() + :members: + :inherited-members: - The warnings on :func:`on_ready` also apply. +.. autoclass:: discord.events.SubscriptionUpdate() + :members: + :inherited-members: - .. warning:: - - Overriding this event will not call :meth:`Bot.sync_commands`. - As a result, :class:`ApplicationCommand` will not be registered. - -.. function:: on_shard_connect(shard_id) - - Similar to :func:`on_connect` except used by :class:`AutoShardedClient` - to denote when a particular shard ID has connected to Discord. - - .. versionadded:: 1.4 - - :param shard_id: The shard ID that has connected. - :type shard_id: :class:`int` - -.. function:: on_disconnect() - - Called when the client has disconnected from Discord, or a connection attempt to Discord has failed. - This could happen either through the internet being disconnected, explicit calls to close, - or Discord terminating the connection one way or the other. - - This function can be called many times without a corresponding :func:`on_connect` call. - - -.. function:: on_shard_disconnect(shard_id) - - Similar to :func:`on_disconnect` except used by :class:`AutoShardedClient` - to denote when a particular shard ID has disconnected from Discord. - - .. versionadded:: 1.4 - - :param shard_id: The shard ID that has disconnected. - :type shard_id: :class:`int` - -.. function:: on_ready() - - Called when the client is done preparing the data received from Discord. Usually after login is successful - and the :func:`Client.get_guilds` and co. are filled up. - - .. warning:: - - This function is not guaranteed to be the first event called. - Likewise, this function is **not** guaranteed to only be called - once. This library implements reconnection logic and thus will - end up calling this event whenever a RESUME request fails. - -.. function:: on_shard_ready(shard_id) - - Similar to :func:`on_ready` except used by :class:`AutoShardedClient` - to denote when a particular shard ID has become ready. - - :param shard_id: The shard ID that is ready. - :type shard_id: :class:`int` - -.. function:: on_resumed() - - Called when the client has resumed a session. - -.. function:: on_shard_resumed(shard_id) - - Similar to :func:`on_resumed` except used by :class:`AutoShardedClient` - to denote when a particular shard ID has resumed a session. - - .. versionadded:: 1.4 - - :param shard_id: The shard ID that has resumed. - :type shard_id: :class:`int` - -.. function:: on_socket_event_type(event_type) - - Called whenever a WebSocket event is received from the WebSocket. - - This is mainly useful for logging how many events you are receiving - from the Discord gateway. - - .. versionadded:: 2.0 - - :param event_type: The event type from Discord that is received, e.g. ``'READY'``. - :type event_type: :class:`str` - -.. function:: on_socket_raw_receive(msg) - - Called whenever a message is completely received from the WebSocket, before - it's processed and parsed. This event is always dispatched when a - complete message is received and the passed data is not parsed in any way. - - This is only really useful for grabbing the WebSocket stream and - debugging purposes. - - This requires setting the ``enable_debug_events`` setting in the :class:`Client`. - - .. note:: - - This is only for the messages received from the client - WebSocket. The voice WebSocket will not trigger this event. - - :param msg: The message passed in from the WebSocket library. - :type msg: :class:`str` - -.. function:: on_socket_raw_send(payload) - - Called whenever a send operation is done on the WebSocket before the - message is sent. The passed parameter is the message that is being - sent to the WebSocket. - - This is only really useful for grabbing the WebSocket stream and - debugging purposes. - - This requires setting the ``enable_debug_events`` setting in the :class:`Client`. - - .. note:: - - This is only for the messages sent from the client - WebSocket. The voice WebSocket will not trigger this event. - - :param payload: The message that is about to be passed on to the - WebSocket library. It can be :class:`bytes` to denote a binary - message or :class:`str` to denote a regular text message. +.. autoclass:: discord.events.SubscriptionDelete() + :members: + :inherited-members: Guilds ------- -.. function:: on_guild_join(guild) - - Called when a :class:`Guild` is either created by the :class:`Client` or when the - :class:`Client` joins a guild. - - This requires :attr:`Intents.guilds` to be enabled. - - :param guild: The guild that was joined. - :type guild: :class:`Guild` - -.. function:: on_guild_remove(guild) - - Called when a :class:`Guild` is removed from the :class:`Client`. - - This happens through, but not limited to, these circumstances: - - - The client got banned. - - The client got kicked. - - The client left the guild. - - The client or the guild owner deleted the guild. - - In order for this event to be invoked then the :class:`Client` must have - been part of the guild to begin with. (i.e. it is part of :func:`Client.get_guilds`) - - This requires :attr:`Intents.guilds` to be enabled. - - :param guild: The guild that got removed. - :type guild: :class:`Guild` - -.. function:: on_guild_update(before, after) - - Called when a :class:`Guild` is updated, for example: - - - Changed name - - Changed AFK channel - - Changed AFK timeout - - etc. - - This requires :attr:`Intents.guilds` to be enabled. - - :param before: The guild prior to being updated. - :type before: :class:`Guild` - :param after: The guild after being updated. - :type after: :class:`Guild` - -.. function:: on_guild_role_create(role) - on_guild_role_delete(role) - - Called when a :class:`Guild` creates or deletes a :class:`Role`. - - To get the guild it belongs to, use :attr:`Role.guild`. - - This requires :attr:`Intents.guilds` to be enabled. - - :param role: The role that was created or deleted. - :type role: :class:`Role` - -.. function:: on_guild_role_update(before, after) - - Called when a :class:`Role` is changed guild-wide. - - This requires :attr:`Intents.guilds` to be enabled. - - :param before: The updated role's old info. - :type before: :class:`Role` - :param after: The updated role's updated info. - :type after: :class:`Role` - -.. function:: on_guild_emojis_update(guild, before, after) - - Called when a :class:`Guild` adds or removes an :class:`GuildEmoji`. - - This requires :attr:`Intents.emojis_and_stickers` to be enabled. +~~~~~~ - :param guild: The guild who got their emojis updated. - :type guild: :class:`Guild` - :param before: A list of emojis before the update. - :type before: Sequence[:class:`GuildEmoji`] - :param after: A list of emojis after the update. - :type after: Sequence[:class:`GuildEmoji`] +.. autoclass:: discord.events.GuildJoin() + :members: + :inherited-members: -.. function:: on_guild_stickers_update(guild, before, after) +.. autoclass:: discord.events.GuildCreate() + :members: + :inherited-members: - Called when a :class:`Guild` adds or removes a sticker. +.. autoclass:: discord.events.GuildDelete() + :members: + :inherited-members: - This requires :attr:`Intents.emojis_and_stickers` to be enabled. +.. autoclass:: discord.events.GuildUpdate() + :members: + :inherited-members: - .. versionadded:: 2.0 +.. autoclass:: discord.events.GuildAvailable() + :members: + :inherited-members: - :param guild: The guild who got their stickers updated. - :type guild: :class:`Guild` - :param before: A list of stickers before the update. - :type before: Sequence[:class:`GuildSticker`] - :param after: A list of stickers after the update. - :type after: Sequence[:class:`GuildSticker`] +.. autoclass:: discord.events.GuildUnavailable() + :members: + :inherited-members: -.. function:: on_guild_available(guild) - on_guild_unavailable(guild) +.. autoclass:: discord.events.GuildBanAdd() + :members: + :inherited-members: - Called when a guild becomes available or unavailable. The guild must have - existed in the :func:`Client.get_guilds` cache. +.. autoclass:: discord.events.GuildBanRemove() + :members: + :inherited-members: - This requires :attr:`Intents.guilds` to be enabled. +.. autoclass:: discord.events.GuildEmojisUpdate() + :members: + :inherited-members: - :param guild: The guild that has changed availability. - :type guild: :class:`Guild` +.. autoclass:: discord.events.GuildStickersUpdate() + :members: + :inherited-members: -.. function:: on_webhooks_update(channel) +Roles +^^^^^ - Called whenever a webhook is created, modified, or removed from a guild channel. +.. autoclass:: discord.events.GuildRoleCreate() + :members: + :inherited-members: - This requires :attr:`Intents.webhooks` to be enabled. +.. autoclass:: discord.events.GuildRoleUpdate() + :members: + :inherited-members: - :param channel: The channel that had its webhooks updated. - :type channel: :class:`abc.GuildChannel` +.. autoclass:: discord.events.GuildRoleDelete() + :members: + :inherited-members: Integrations ------------- -.. function:: on_guild_integrations_update(guild) - - Called whenever an integration is created, modified, or removed from a guild. - - This requires :attr:`Intents.integrations` to be enabled. - - .. versionadded:: 1.4 - - :param guild: The guild that had its integrations updated. - :type guild: :class:`Guild` - -.. function:: on_integration_create(integration) - - Called when an integration is created. - - This requires :attr:`Intents.integrations` to be enabled. - - .. versionadded:: 2.0 +~~~~~~~~~~~~ - :param integration: The integration that was created. - :type integration: :class:`Integration` +.. autoclass:: discord.events.GuildIntegrationsUpdate() + :members: + :inherited-members: -.. function:: on_integration_update(integration) +.. autoclass:: discord.events.IntegrationCreate() + :members: + :inherited-members: - Called when an integration is updated. +.. autoclass:: discord.events.IntegrationUpdate() + :members: + :inherited-members: - This requires :attr:`Intents.integrations` to be enabled. - - .. versionadded:: 2.0 - - :param integration: The integration that was created. - :type integration: :class:`Integration` - -.. function:: on_raw_integration_delete(payload) - - Called when an integration is deleted. - - This requires :attr:`Intents.integrations` to be enabled. - - .. versionadded:: 2.0 - - :param payload: The raw event payload data. - :type payload: :class:`RawIntegrationDeleteEvent` +.. autoclass:: discord.events.IntegrationDelete() + :members: + :inherited-members: Interactions ------------- -.. function:: on_interaction(interaction) +~~~~~~~~~~~~ - Called when an interaction happened. - - This currently happens due to application command invocations or components being used. - - .. warning:: - - This is a low level function that is not generally meant to be used. - If you are working with components, consider using the callbacks associated - with the :class:`~discord.ui.View` instead as it provides a nicer user experience. - - .. versionadded:: 2.0 - - :param interaction: The interaction data. - :type interaction: :class:`Interaction` +.. autoclass:: discord.events.InteractionCreate() + :members: + :inherited-members: Invites -------- -.. function:: on_invite_create(invite) - - Called when an :class:`Invite` is created. - You must have the :attr:`~Permissions.manage_channels` permission to receive this. - - .. versionadded:: 1.3 - - .. note:: - - There is a rare possibility that the :attr:`Invite.guild` and :attr:`Invite.channel` - attributes will be of :class:`Object` rather than the respective models. - - This requires :attr:`Intents.invites` to be enabled. - - :param invite: The invite that was created. - :type invite: :class:`Invite` - -.. function:: on_invite_delete(invite) - - Called when an :class:`Invite` is deleted. - You must have the :attr:`~Permissions.manage_channels` permission to receive this. - - .. versionadded:: 1.3 - - .. note:: - - There is a rare possibility that the :attr:`Invite.guild` and :attr:`Invite.channel` - attributes will be of :class:`Object` rather than the respective models. - - Outside of those two attributes, the only other attribute guaranteed to be - filled by the Discord gateway for this event is :attr:`Invite.code`. - - This requires :attr:`Intents.invites` to be enabled. - - :param invite: The invite that was deleted. - :type invite: :class:`Invite` - -Members/Users -------------- -.. function:: on_member_join(member) - - Called when a :class:`Member` joins a :class:`Guild`. - - This requires :attr:`Intents.members` to be enabled. - - :param member: The member who joined. - :type member: :class:`Member` - -.. function:: on_member_remove(member) - - Called when a :class:`Member` leaves a :class:`Guild`. - - If the guild or member could not be found in the internal cache, this event will not - be called. Alternatively, :func:`on_raw_member_remove` is called regardless of the - internal cache. - - This requires :attr:`Intents.members` to be enabled. - - :param member: The member who left. - :type member: :class:`Member` - -.. function:: on_raw_member_remove(payload) - - Called when a :class:`Member` leaves a :class:`Guild`. Unlike - :func:`on_member_remove`, this is called regardless of the state of the internal - member cache. - - This requires :attr:`Intents.members` to be enabled. - - .. versionadded:: 2.4 - - :param payload: The raw event payload data. - :type payload: :class:`RawMemberRemoveEvent` - -.. function:: on_member_update(before, after) - - Called when a :class:`Member` updates their profile. - - This is called when one or more of the following things change: - - - nickname - - roles - - pending - - communication_disabled_until - - timed_out - - This requires :attr:`Intents.members` to be enabled. - - :param before: The updated member's old info. - :type before: :class:`Member` - :param after: The updated member's updated info. - :type after: :class:`Member` - -.. function:: on_presence_update(before, after) - - Called when a :class:`Member` updates their presence. - - This is called when one or more of the following things change: - - - status - - activity - - This requires :attr:`Intents.presences` and :attr:`Intents.members` to be enabled. - - .. versionadded:: 2.0 - - :param before: The updated member's old info. - :type before: :class:`Member` - :param after: The updated member's updated info. - :type after: :class:`Member` +~~~~~~~ -.. function:: on_voice_state_update(member, before, after) +.. autoclass:: discord.events.InviteCreate() + :members: + :inherited-members: - Called when a :class:`Member` changes their :class:`VoiceState`. +.. autoclass:: discord.events.InviteDelete() + :members: + :inherited-members: - The following, but not limited to, examples illustrate when this event is called: +Members & Users +~~~~~~~~~~~~~~~ - - A member joins a voice or stage channel. - - A member leaves a voice or stage channel. - - A member is muted or deafened by their own accord. - - A member is muted or deafened by a guild administrator. +.. autoclass:: discord.events.GuildMemberJoin() + :members: + :inherited-members: - This requires :attr:`Intents.voice_states` to be enabled. +.. autoclass:: discord.events.GuildMemberRemove() + :members: + :inherited-members: - :param member: The member whose voice states changed. - :type member: :class:`Member` - :param before: The voice state prior to the changes. - :type before: :class:`VoiceState` - :param after: The voice state after the changes. - :type after: :class:`VoiceState` +.. autoclass:: discord.events.GuildMemberUpdate() + :members: + :inherited-members: -.. function:: on_user_update(before, after) +.. autoclass:: discord.events.UserUpdate() + :members: + :inherited-members: - Called when a :class:`User` updates their profile. - - This is called when one or more of the following things change: - - - avatar - - username - - discriminator - - global_name - - This requires :attr:`Intents.members` to be enabled. - - :param before: The updated user's old info. - :type before: :class:`User` - :param after: The updated user's updated info. - :type after: :class:`User` +.. autoclass:: discord.events.PresenceUpdate() + :members: + :inherited-members: Messages --------- -.. function:: on_message(message) - - Called when a :class:`Message` is created and sent. - - This requires :attr:`Intents.messages` to be enabled. - - .. warning:: - - Your bot's own messages and private messages are sent through this - event. This can lead cases of 'recursion' depending on how your bot was - programmed. If you want the bot to not reply to itself, consider - checking the user IDs. Note that :class:`~ext.commands.Bot` does not - have this problem. - - :param message: The current message. - :type message: :class:`Message` - -.. function:: on_message_delete(message) - - Called when a message is deleted. If the message is not found in the - internal message cache, then this event will not be called. - Messages might not be in cache if the message is too old - or the client is participating in high traffic guilds. - - If this occurs increase the :class:`max_messages ` parameter - or use the :func:`on_raw_message_delete` event instead. - - This requires :attr:`Intents.messages` to be enabled. - - :param message: The deleted message. - :type message: :class:`Message` - -.. function:: on_bulk_message_delete(messages) - - Called when messages are bulk deleted. If none of the messages deleted - are found in the internal message cache, then this event will not be called. - If individual messages were not found in the internal message cache, - this event will still be called, but the messages not found will not be included in - the messages list. Messages might not be in cache if the message is too old - or the client is participating in high traffic guilds. - - If this occurs increase the :class:`max_messages ` parameter - or use the :func:`on_raw_bulk_message_delete` event instead. - - This requires :attr:`Intents.messages` to be enabled. - - :param messages: The messages that have been deleted. - :type messages: List[:class:`Message`] - -.. function:: on_raw_message_delete(payload) - - Called when a message is deleted. Unlike :func:`on_message_delete`, this is - called regardless of the message being in the internal message cache or not. - - If the message is found in the message cache, - it can be accessed via :attr:`RawMessageDeleteEvent.cached_message` - - This requires :attr:`Intents.messages` to be enabled. - - :param payload: The raw event payload data. - :type payload: :class:`RawMessageDeleteEvent` - -.. function:: on_raw_bulk_message_delete(payload) - - Called when a bulk delete is triggered. Unlike :func:`on_bulk_message_delete`, this is - called regardless of the messages being in the internal message cache or not. - - If the messages are found in the message cache, - they can be accessed via :attr:`RawBulkMessageDeleteEvent.cached_messages` +~~~~~~~~ - This requires :attr:`Intents.messages` to be enabled. +.. autoclass:: discord.events.MessageCreate() + :members: + :inherited-members: - :param payload: The raw event payload data. - :type payload: :class:`RawBulkMessageDeleteEvent` +.. autoclass:: discord.events.MessageUpdate() + :members: + :inherited-members: -.. function:: on_message_edit(before, after) +.. autoclass:: discord.events.MessageDelete() + :members: + :inherited-members: - Called when a :class:`Message` receives an update event. If the message is not found - in the internal message cache, then these events will not be called. - Messages might not be in cache if the message is too old - or the client is participating in high traffic guilds. - - If this occurs increase the :class:`max_messages ` parameter - or use the :func:`on_raw_message_edit` event instead. - - The following non-exhaustive cases trigger this event: - - - A message has been pinned or unpinned. - - The message content has been changed. - - The message has received an embed. - - - For performance reasons, the embed server does not do this in a "consistent" manner. - - - The message's embeds were suppressed or unsuppressed. - - A call message has received an update to its participants or ending time. - - A poll has ended and the results have been finalized. - - This requires :attr:`Intents.messages` to be enabled. - - :param before: The previous version of the message. - :type before: :class:`Message` - :param after: The current version of the message. - :type after: :class:`Message` - -.. function:: on_raw_message_edit(payload) - - Called when a message is edited. Unlike :func:`on_message_edit`, this is called - regardless of the state of the internal message cache. - - If the message is found in the message cache, - it can be accessed via :attr:`RawMessageUpdateEvent.cached_message`. The cached message represents - the message before it has been edited. For example, if the content of a message is modified and - triggers the :func:`on_raw_message_edit` coroutine, the :attr:`RawMessageUpdateEvent.cached_message` - will return a :class:`Message` object that represents the message before the content was modified. - - Due to the inherently raw nature of this event, the data parameter coincides with - the raw data given by the `gateway `_. - - Since the data payload can be partial, care must be taken when accessing stuff in the dictionary. - One example of a common case of partial data is when the ``'content'`` key is inaccessible. This - denotes an "embed" only edit, which is an edit in which only the embeds are updated by the Discord - embed server. - - This requires :attr:`Intents.messages` to be enabled. - - :param payload: The raw event payload data. - :type payload: :class:`RawMessageUpdateEvent` - -Polls -~~~~~~ -.. function:: on_poll_vote_add(poll, user, answer) - - Called when a vote is cast on a poll. If multiple answers were selected, this fires multiple times. - if the poll was not found in the internal poll cache, then this - event will not be called. Consider using :func:`on_raw_poll_vote_add` instead. - - This requires :attr:`Intents.polls` to be enabled. - - :param poll: The current state of the poll. - :type poll: :class:`Poll` - :param user: The user who added the vote. - :type user: Union[:class:`Member`, :class:`User`] - :param answer: The answer that was voted. - :type answer: :class:`PollAnswer` - -.. function:: on_raw_poll_vote_add(payload) - - Called when a vote is cast on a poll. Unlike :func:`on_poll_vote_add`, this is - called regardless of the state of the internal poll cache. - - This requires :attr:`Intents.polls` to be enabled. - - :param payload: The raw event payload data. - :type payload: :class:`RawMessagePollVoteEvent` - -.. function:: on_poll_vote_remove(message, user, answer) - - Called when a vote is removed from a poll. If multiple answers were removed, this fires multiple times. - if the poll is not found in the internal poll cache, then this - event will not be called. Consider using :func:`on_raw_poll_vote_remove` instead. - - This requires :attr:`Intents.polls` to be enabled. - - :param poll: The current state of the poll. - :type poll: :class:`Poll` - :param user: The user who removed the vote. - :type user: Union[:class:`Member`, :class:`User`] - :param answer: The answer that was voted. - :type answer: :class:`PollAnswer` - -.. function:: on_raw_poll_vote_remove(payload) - - Called when a vote is removed from a poll. Unlike :func:`on_poll_vote_remove`, this is - called regardless of the state of the internal message cache. - - This requires :attr:`Intents.polls` to be enabled. - - :param payload: The raw event payload data. - :type payload: :class:`RawMessagePollVoteEvent` +.. autoclass:: discord.events.MessageDeleteBulk() + :members: + :inherited-members: Reactions -~~~~~~~~~ -.. function:: on_reaction_add(reaction, user) - - Called when a message has a reaction added to it. Similar to :func:`on_message_edit`, - if the message is not found in the internal message cache, then this - event will not be called. Consider using :func:`on_raw_reaction_add` instead. - - .. note:: - - To get the :class:`Message` being reacted, access it via :attr:`Reaction.message`. - - This requires :attr:`Intents.reactions` to be enabled. - - .. note:: - - This doesn't require :attr:`Intents.members` within a guild context, - but due to Discord not providing updated user information in a direct message - it's required for direct messages to receive this event. - Consider using :func:`on_raw_reaction_add` if you need this and do not otherwise want - to enable the members intent. - - :param reaction: The current state of the reaction. - :type reaction: :class:`Reaction` - :param user: The user who added the reaction. - :type user: Union[:class:`Member`, :class:`User`] - -.. function:: on_raw_reaction_add(payload) - - Called when a message has a reaction added. Unlike :func:`on_reaction_add`, this is - called regardless of the state of the internal message cache. - - This requires :attr:`Intents.reactions` to be enabled. - - :param payload: The raw event payload data. - :type payload: :class:`RawReactionActionEvent` - -.. function:: on_reaction_remove(reaction, user) - - Called when a message has a reaction removed from it. Similar to on_message_edit, - if the message is not found in the internal message cache, then this event - will not be called. - - .. note:: - - To get the message being reacted, access it via :attr:`Reaction.message`. - - This requires both :attr:`Intents.reactions` and :attr:`Intents.members` to be enabled. - - .. note:: - - Consider using :func:`on_raw_reaction_remove` if you need this and do not want - to enable the members intent. - - :param reaction: The current state of the reaction. - :type reaction: :class:`Reaction` - :param user: The user who added the reaction. - :type user: Union[:class:`Member`, :class:`User`] - -.. function:: on_raw_reaction_remove(payload) - - Called when a message has a reaction removed. Unlike :func:`on_reaction_remove`, this is - called regardless of the state of the internal message cache. - - This requires :attr:`Intents.reactions` to be enabled. +^^^^^^^^^ - :param payload: The raw event payload data. - :type payload: :class:`RawReactionActionEvent` +.. autoclass:: discord.events.ReactionAdd() + :members: + :inherited-members: -.. function:: on_reaction_clear(message, reactions) +.. autoclass:: discord.events.ReactionRemove() + :members: + :inherited-members: - Called when a message has all its reactions removed from it. Similar to :func:`on_message_edit`, - if the message is not found in the internal message cache, then this event - will not be called. Consider using :func:`on_raw_reaction_clear` instead. +.. autoclass:: discord.events.ReactionClear() + :members: + :inherited-members: - This requires :attr:`Intents.reactions` to be enabled. +.. autoclass:: discord.events.ReactionRemoveEmoji() + :members: + :inherited-members: - :param message: The message that had its reactions cleared. - :type message: :class:`Message` - :param reactions: The reactions that were removed. - :type reactions: List[:class:`Reaction`] - -.. function:: on_raw_reaction_clear(payload) - - Called when a message has all its reactions removed. Unlike :func:`on_reaction_clear`, - this is called regardless of the state of the internal message cache. - - This requires :attr:`Intents.reactions` to be enabled. - - :param payload: The raw event payload data. - :type payload: :class:`RawReactionClearEvent` - -.. function:: on_reaction_clear_emoji(reaction) - - Called when a message has a specific reaction removed from it. Similar to :func:`on_message_edit`, - if the message is not found in the internal message cache, then this event - will not be called. Consider using :func:`on_raw_reaction_clear_emoji` instead. - - This requires :attr:`Intents.reactions` to be enabled. - - .. versionadded:: 1.3 - - :param reaction: The reaction that got cleared. - :type reaction: :class:`Reaction` - -.. function:: on_raw_reaction_clear_emoji(payload) - - Called when a message has a specific reaction removed from it. Unlike :func:`on_reaction_clear_emoji` this is called - regardless of the state of the internal message cache. - - This requires :attr:`Intents.reactions` to be enabled. - - .. versionadded:: 1.3 - - :param payload: The raw event payload data. - :type payload: :class:`RawReactionClearEmojiEvent` - -Monetization ------------- -.. function:: on_entitlement_create(entitlement) - - Called when a user subscribes to an SKU. - - .. versionadded:: 2.5 - - :param entitlement: The entitlement that was created as a result of the subscription. - :type entitlement: :class:`Entitlement` - -.. function:: on_entitlement_update(entitlement) - - Called when a user's subscription to an Entitlement is cancelled. - - .. versionadded:: 2.5 - - .. note:: - - Before October 1, 2024, this event was called when a user's subscription was renewed. - - Entitlements that no longer follow this behavior will have a type of :attr:`EntitlementType.purchase`. - Those that follow the old behavior will have a type of :attr:`EntitlementType.application_subscription`. - - `See the Discord changelog. `_ - - :param entitlement: The entitlement that was updated. - :type entitlement: :class:`Entitlement` - -.. function:: on_entitlement_delete(entitlement) - - Called when a user's entitlement is deleted. - - Entitlements are usually only deleted when Discord issues a refund for a subscription, - or manually removes an entitlement from a user. - - .. note:: - - This is not called when a user's subscription is cancelled. - - .. versionadded:: 2.5 - - :param entitlement: The entitlement that was deleted. - :type entitlement: :class:`Entitlement` - -.. function:: on_subscription_create(subscription) - - Called when a subscription is created for the application. - - .. versionadded:: 2.7 - - :param subscription: The subscription that was created. - :type subscription: :class:`Subscription` - -.. function:: on_subscription_update(subscription) - - Called when a subscription has been updated. This could be a renewal, cancellation, or other payment related update. - - .. versionadded:: 2.7 - - :param subscription: The subscription that was updated. - :type subscription: :class:`Subscription` - -.. function:: on_subscription_delete(subscription) - - Called when a subscription has been deleted. +Polls +^^^^^ - .. versionadded:: 2.7 +.. autoclass:: discord.events.PollVoteAdd() + :members: + :inherited-members: - :param subscription: The subscription that was deleted. - :type subscription: :class:`Subscription` +.. autoclass:: discord.events.PollVoteRemove() + :members: + :inherited-members: Scheduled Events ----------------- -.. function:: on_scheduled_event_create(event) - - Called when an :class:`ScheduledEvent` is created. - - This requires :attr:`Intents.scheduled_events` to be enabled. - - :param event: The newly created scheduled event. - :type event: :class:`ScheduledEvent` - -.. function:: on_scheduled_event_update(before, after) - - Called when a scheduled event is updated. - - This requires :attr:`Intents.scheduled_events` to be enabled. - - :param before: The old scheduled event. - :type before: :class:`ScheduledEvent` - :param after: The updated scheduled event. - :type after: :class:`ScheduledEvent` - -.. function:: on_scheduled_event_delete(event) - - Called when a scheduled event is deleted. +~~~~~~~~~~~~~~~~ - This requires :attr:`Intents.scheduled_events` to be enabled. +.. autoclass:: discord.events.GuildScheduledEventCreate() + :members: + :inherited-members: - :param event: The deleted scheduled event. - :type event: :class:`ScheduledEvent` +.. autoclass:: discord.events.GuildScheduledEventUpdate() + :members: + :inherited-members: -.. function:: on_scheduled_event_user_add(event, member) +.. autoclass:: discord.events.GuildScheduledEventDelete() + :members: + :inherited-members: - Called when a user subscribes to an event. If the member or event - is not found in the internal cache, then this event will not be - called. Consider using :func:`on_raw_scheduled_event_user_add` instead. +.. autoclass:: discord.events.GuildScheduledEventUserAdd() + :members: + :inherited-members: - This requires :attr:`Intents.scheduled_events` to be enabled. +.. autoclass:: discord.events.GuildScheduledEventUserRemove() + :members: + :inherited-members: - :param event: The scheduled event subscribed to. - :type event: :class:`ScheduledEvent` - :param member: The member who subscribed. - :type member: :class:`Member` +Soundboard +~~~~~~~~~~ -.. function:: on_raw_scheduled_event_user_add(payload) +.. autoclass:: discord.events.GuildSoundboardSoundCreate() + :members: + :inherited-members: - Called when a user subscribes to an event. Unlike - :meth:`on_scheduled_event_user_add`, this will be called - regardless of the state of the internal cache. +.. autoclass:: discord.events.GuildSoundboardSoundUpdate() + :members: + :inherited-members: - This requires :attr:`Intents.scheduled_events` to be enabled. +.. autoclass:: discord.events.GuildSoundboardSoundDelete() + :members: + :inherited-members: - :param payload: The raw event payload data. - :type payload: :class:`RawScheduledEventSubscription` - -.. function:: on_scheduled_event_user_remove(event, member) - - Called when a user unsubscribes to an event. If the member or event is - not found in the internal cache, then this event will not be called. - Consider using :func:`on_raw_scheduled_event_user_remove` instead. - - This requires :attr:`Intents.scheduled_events` to be enabled. - - :param event: The scheduled event unsubscribed from. - :type event: :class:`ScheduledEvent` - :param member: The member who unsubscribed. - :type member: :class:`Member` - -.. function:: on_raw_scheduled_event_user_remove(payload) - - Called when a user unsubscribes to an event. Unlike - :meth:`on_scheduled_event_user_remove`, this will be called - regardless of the state of the internal cache. - - This requires :attr:`Intents.scheduled_events` to be enabled. - - :param payload: The raw event payload data. - :type payload: :class:`RawScheduledEventSubscription` +.. autoclass:: discord.events.GuildSoundboardSoundsUpdate() + :members: + :inherited-members: Stage Instances ---------------- -.. function:: on_stage_instance_create(stage_instance) - on_stage_instance_delete(stage_instance) - - Called when a :class:`StageInstance` is created or deleted for a :class:`StageChannel`. - - .. versionadded:: 2.0 - - :param stage_instance: The stage instance that was created or deleted. - :type stage_instance: :class:`StageInstance` - -.. function:: on_stage_instance_update(before, after) - - Called when a :class:`StageInstance` is updated. - - The following, but not limited to, examples illustrate when this event is called: +~~~~~~~~~~~~~~~ - - The topic is changed. - - The privacy level is changed. +.. autoclass:: discord.events.StageInstanceCreate() + :members: + :inherited-members: - .. versionadded:: 2.0 +.. autoclass:: discord.events.StageInstanceUpdate() + :members: + :inherited-members: - :param before: The stage instance before the update. - :type before: :class:`StageInstance` - :param after: The stage instance after the update. - :type after: :class:`StageInstance` +.. autoclass:: discord.events.StageInstanceDelete() + :members: + :inherited-members: Threads -------- -.. function:: on_thread_join(thread) - - Called whenever a thread is joined. - - Note that you can get the guild from :attr:`Thread.guild`. - - This requires :attr:`Intents.guilds` to be enabled. - - .. versionadded:: 2.0 - - :param thread: The thread that got joined. - :type thread: :class:`Thread` - -.. function:: on_thread_create(thread) - - Called whenever a thread is created. - - Note that you can get the guild from :attr:`Thread.guild`. - - This requires :attr:`Intents.guilds` to be enabled. - - .. versionadded:: 2.0 - - :param thread: The thread that got created. - :type thread: :class:`Thread` - -.. function:: on_thread_remove(thread) - - Called whenever a thread is removed. This is different from a thread being deleted. - - Note that you can get the guild from :attr:`Thread.guild`. - - This requires :attr:`Intents.guilds` to be enabled. - - .. warning:: - - Due to technical limitations, this event might not be called - as soon as one expects. Since the library tracks thread membership - locally, the API only sends updated thread membership status upon being - synced by joining a thread. - - .. versionadded:: 2.0 - - :param thread: The thread that got removed. - :type thread: :class:`Thread` - -.. function:: on_thread_delete(thread) - - Called whenever a thread is deleted. If the deleted thread isn't found in internal cache - then this will not be called. Archived threads are not in the cache. Consider using :func:`on_raw_thread_delete` - - - Note that you can get the guild from :attr:`Thread.guild`. - - This requires :attr:`Intents.guilds` to be enabled. - - .. versionadded:: 2.0 - - :param thread: The thread that got deleted. - :type thread: :class:`Thread` - -.. function:: on_raw_thread_delete(payload) - - Called whenever a thread is deleted. Unlike :func:`on_thread_delete` this is called - regardless of the state of the internal cache. - - :param payload: The raw event payload data. - :type payload: :class:`RawThreadDeleteEvent` - -.. function:: on_thread_member_join(member) - on_thread_member_remove(member) - - Called when a :class:`ThreadMember` leaves or joins a :class:`Thread`. - - You can get the thread a member belongs in by accessing :attr:`ThreadMember.thread`. - - This requires :attr:`Intents.members` to be enabled. - - .. versionadded:: 2.0 - - :param member: The member who joined or left. - :type member: :class:`ThreadMember` - - -.. function:: on_raw_thread_member_remove(payload) - - Called when a :class:`ThreadMember` leaves a :class:`Thread`. Unlike :func:`on_thread_member_remove` this - is called regardless of the member being in the thread's internal cache of members or not. - - This requires :attr:`Intents.members` to be enabled. - - .. versionadded:: 2.4 +~~~~~~~ - :param payload: The raw event payload data. - :type member: :class:`RawThreadMembersUpdateEvent` +.. autoclass:: discord.events.ThreadCreate() + :members: + :inherited-members: +.. autoclass:: discord.events.ThreadUpdate() + :members: + :inherited-members: +.. autoclass:: discord.events.ThreadDelete() + :members: + :inherited-members: -.. function:: on_thread_update(before, after) +.. autoclass:: discord.events.ThreadJoin() + :members: + :inherited-members: - Called whenever a thread is updated. +.. autoclass:: discord.events.ThreadRemove() + :members: + :inherited-members: - This requires :attr:`Intents.guilds` to be enabled. +.. autoclass:: discord.events.ThreadMemberJoin() + :members: + :inherited-members: - If the thread could not be found in the internal cache, this event will not be called. - Threads will not be in the cache if they are archived. Alternatively, - :func:`on_raw_thread_update` is called regardless of the internal cache. - - .. versionadded:: 2.0 - - :param before: The updated thread's old info. - :type before: :class:`Thread` - :param after: The updated thread's new info. - :type after: :class:`Thread` - - -.. function:: on_raw_thread_update(payload) - - Called whenever a thread is updated. - - Unlike :func:`on_thread_update` this is called regardless of if the thread is in the - internal thread cache or not. - - This requires :attr:`Intents.guilds` to be enabled. - - .. versionadded:: 2.4 - - :param payload: The raw event payload data. - :type payload: :class:`RawThreadUpdateEvent` +.. autoclass:: discord.events.ThreadMemberRemove() + :members: + :inherited-members: Typing ------- -.. function:: on_typing(channel, user, when) - - Called when someone begins typing a message. - - The ``channel`` parameter can be a :class:`abc.Messageable` instance. - Which could either be :class:`TextChannel`, :class:`GroupChannel`, or - :class:`DMChannel`. - - If the ``channel`` is a :class:`TextChannel` then the ``user`` parameter - is a :class:`Member`, otherwise it is a :class:`User`. - - This requires :attr:`Intents.typing` to be enabled. - - :param channel: The location where the typing originated from. - :type channel: :class:`abc.Messageable` - :param user: The user that started typing. - :type user: Union[:class:`User`, :class:`Member`] - :param when: When the typing started as an aware datetime in UTC. - :type when: :class:`datetime.datetime` - -.. function:: on_raw_typing(payload) - - Called when someone begins typing a message. Unlike :func:`on_typing`, this is - called regardless if the user can be found in the bot's cache or not. - - If the typing event is occurring in a guild, - the member that started typing can be accessed via :attr:`RawTypingEvent.member` - - This requires :attr:`Intents.typing` to be enabled. - - :param payload: The raw typing payload. - :type payload: :class:`RawTypingEvent` - - -Voice Channel Status Update ---------------------------- -.. function:: on_voice_channel_status_update(channel, before, after) - - Called when someone updates a voice channel status. - - .. versionadded:: 2.5 - - :param channel: The channel where the voice channel status update originated from. - :type channel: :class:`abc.GuildChannel` - :param before: The old voice channel status. - :type before: Optional[:class:`str`] - :param after: The new voice channel status. - :type after: Optional[:class:`str`] - -.. function:: on_raw_voice_channel_status_update(payload) - - Called when someone updates a voice channels status. - - .. versionadded:: 2.5 - - :param payload: The raw voice channel status update payload. - :type payload: :class:`RawVoiceChannelStatusUpdateEvent` - -Voice Channel Effects ---------------------- -.. function:: on_voice_channel_effect_send(event) - - Called when a voice channel effect is sent. - - .. versionadded:: 2.7 - - :param event: The voice channel effect event. - :type event: :class:`VoiceChannelEffectSendEvent` - -Soundboard Sound ----------------- -.. function:: on_soundboard_sounds_update(before, after) - - Called when multiple guild soundboard sounds are updated at once and they were all already in the cache. - This is called, for example, when a guild loses a boost level and some sounds become unavailable. - - .. versionadded:: 2.7 - - :param before: The soundboard sounds prior to being updated. - :type before: List[:class:`SoundboardSound`] - :param after: The soundboard sounds after being updated. - :type after: List[:class:`SoundboardSound`] - -.. function:: on_raw_soundboard_sounds_update(after) - - Called when multiple guild soundboard sounds are updated at once. - This is called, for example, when a guild loses a boost level and some sounds become unavailable. - - .. versionadded:: 2.7 - - :param after: The soundboard sounds after being updated. - :type after: List[:class:`SoundboardSound`] - -.. function:: on_soundboard_sound_update(before, after) - - Called when a soundboard sound is updated and it was already in the cache. - - .. versionadded:: 2.7 - - :param before: The soundboard sound prior to being updated. - :type before: :class:`Soundboard` - :param after: The soundboard sound after being updated. - :type after: :class:`Soundboard` - -.. function:: on_raw_soundboard_sound_update(after) - - Called when a soundboard sound is updated. - - .. versionadded:: 2.7 - - :param after: The soundboard sound after being updated. - :type after: :class:`SoundboardSound` - -.. function:: on_soundboard_sound_delete(sound) - - Called when a soundboard sound is deleted. - - .. versionadded:: 2.7 - - :param sound: The soundboard sound that was deleted. - :type sound: :class:`SoundboardSound` +~~~~~~ -.. function:: on_raw_soundboard_sound_delete(payload) +.. autoclass:: discord.events.TypingStart() + :members: + :inherited-members: - Called when a soundboard sound is deleted. +Voice +~~~~~ - .. versionadded:: 2.7 +.. autoclass:: discord.events.VoiceStateUpdate() + :members: + :inherited-members: - :param payload: The raw event payload data. - :type payload: :class:`RawSoundboardSoundDeleteEvent` +.. autoclass:: discord.events.VoiceServerUpdate() + :members: + :inherited-members: -.. function:: on_soundboard_sound_create(sound) +.. autoclass:: discord.events.VoiceChannelStatusUpdate() + :members: + :inherited-members: - Called when a soundboard sound is created. +.. autoclass:: discord.events.VoiceChannelEffectSend() + :members: + :inherited-members: - .. versionadded:: 2.7 +Webhooks +~~~~~~~~ - :param sound: The soundboard sound that was created. - :type sound: :class:`SoundboardSound` +.. autoclass:: discord.events.WebhooksUpdate() + :members: + :inherited-members: diff --git a/docs/api/gears.rst b/docs/api/gears.rst new file mode 100644 index 0000000000..160aaaa17c --- /dev/null +++ b/docs/api/gears.rst @@ -0,0 +1,168 @@ +.. currentmodule:: discord + +.. _discord_api_gears: + +Gears +===== + +Gears are a modular event handling system in Pycord that allow you to organize your event listeners +into reusable components. They provide a clean way to structure event-driven code and enable +composition by allowing gears to be attached to other gears or to the bot itself. + +Gear +---- + +.. attributetable:: discord.gears.Gear + +.. autoclass:: discord.gears.Gear + :members: + :exclude-members: listen + + .. automethod:: discord.gears.Gear.listen(event, once=False) + :decorator: + +Basic Usage +----------- + +Creating a Gear +~~~~~~~~~~~~~~~ + +You can create a gear by subclassing :class:`discord.gears.Gear` and using the :meth:`~discord.gears.Gear.listen` +decorator to register event listeners: + +.. code-block:: python3 + + from discord.gears import Gear + from discord.events import Ready, MessageCreate + + class MyGear(Gear): + @Gear.listen() + async def on_ready(self, event: Ready) -> None: + print(f"Bot is ready!") + + @Gear.listen() + async def on_message(self, event: MessageCreate) -> None: + print(f"Message: {event.content}") + +Attaching Gears +~~~~~~~~~~~~~~~ + +Gears can be attached to a :class:`Client` or :class:`Bot` using the :meth:`~Client.attach_gear` method: + +.. code-block:: python3 + + bot = discord.Bot() + my_gear = MyGear() + bot.attach_gear(my_gear) + +You can also attach gears to other gears, creating a hierarchy: + +.. code-block:: python3 + + parent_gear = MyGear() + child_gear = AnotherGear() + parent_gear.attach_gear(child_gear) + +Instance Listeners +~~~~~~~~~~~~~~~~~~ + +You can also add listeners to a gear instance dynamically: + +.. code-block:: python3 + + my_gear = MyGear() + + @my_gear.listen() + async def on_guild_join(event: GuildJoin) -> None: + print(f"Joined guild: {event.guild.name}") + +Advanced Usage +-------------- + +One-Time Listeners +~~~~~~~~~~~~~~~~~~ + +Use the ``once`` parameter to create listeners that are automatically removed after being called once: + +.. code-block:: python3 + + class MyGear(Gear): + @Gear.listen(once=True) + async def on_first_message(self, event: MessageCreate) -> None: + print("This will only run once!") + +Manual Listener Management +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can manually add and remove listeners using :meth:`~discord.gears.Gear.add_listener` and +:meth:`~discord.gears.Gear.remove_listener`: + +.. code-block:: python3 + + from discord.events import MessageCreate + + async def my_listener(event: MessageCreate) -> None: + print(f"Message: {event.content}") + + gear = MyGear() + gear.add_listener(my_listener, event=MessageCreate) + + # Later, remove it + gear.remove_listener(my_listener, event=MessageCreate) + +Detaching Gears +~~~~~~~~~~~~~~~ + +Remove a gear using :meth:`~discord.gears.Gear.detach_gear`: + +.. code-block:: python3 + + bot.detach_gear(my_gear) + +Client and Bot Integration +--------------------------- + +Both :class:`Client` and :class:`Bot` provide gear-related methods: + +- :meth:`Client.attach_gear` - Attach a gear to the client +- :meth:`Client.detach_gear` - Detach a gear from the client +- :meth:`Client.add_listener` - Add an event listener directly +- :meth:`Client.remove_listener` - Remove an event listener +- :meth:`Client.listen` - Decorator to add listeners to the client + +These methods work identically to their :class:`~discord.gears.Gear` counterparts. + +Example: Modular Bot Structure +------------------------------- + +Here's an example of using gears to create a modular bot: + +.. code-block:: python3 + + from discord import Bot + from discord.gears import Gear + from discord.events import Ready, MessageCreate, GuildJoin + + class LoggingGear(Gear): + @Gear.listen() + async def log_ready(self, event: Ready) -> None: + print("Bot started!") + + @Gear.listen() + async def log_messages(self, event: MessageCreate) -> None: + print(f"[{event.channel.name}] {event.author}: {event.content}") + + class ModerationGear(Gear): + @Gear.listen() + async def welcome_new_guilds(self, event: GuildJoin) -> None: + system_channel = event.guild.system_channel + if system_channel: + await system_channel.send("Thanks for adding me!") + + bot = Bot() + + # Attach gears to the bot + bot.attach_gear(LoggingGear()) + bot.attach_gear(ModerationGear()) + + bot.run("TOKEN") diff --git a/docs/api/index.rst b/docs/api/index.rst index 0c8624774f..9ab2552a50 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -21,6 +21,7 @@ The following section outlines the API of Pycord. clients application_commands cogs + gears application_info voice events diff --git a/docs/api/models.rst b/docs/api/models.rst index 1aeb7bb2b3..d24b587b01 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -505,6 +505,19 @@ Channels :members: :inherited-members: +.. attributetable:: NewsChannel + +.. autoclass:: NewsChannel() + :members: + :inherited-members: + :exclude-members: history, typing + + .. automethod:: history + :async-for: + + .. automethod:: typing + :async-with: + .. attributetable:: CategoryChannel .. autoclass:: CategoryChannel() @@ -662,11 +675,6 @@ Events .. autoclass:: RawVoiceChannelStatusUpdateEvent() :members: -.. attributetable:: VoiceChannelEffectSendEvent - -.. autoclass:: VoiceChannelEffectSendEvent() - :members: - Webhooks diff --git a/docs/api/utils.rst b/docs/api/utils.rst index db6930cf0d..2ee4104c27 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -7,9 +7,6 @@ Utility Functions .. autofunction:: discord.utils.find - -.. autofunction:: discord.utils.get_or_fetch - .. autofunction:: discord.utils.oauth_url .. autofunction:: discord.utils.remove_markdown diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index 2d9af7f378..c1f689d2bd 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -23,7 +23,7 @@ Bot .. autoclass:: discord.ext.commands.Bot :members: :inherited-members: - :exclude-members: after_invoke, before_invoke, check, check_once, command, event, group, listen + :exclude-members: after_invoke, before_invoke, check, check_once, command, group, listen .. automethod:: Bot.after_invoke() :decorator: @@ -40,13 +40,10 @@ Bot .. automethod:: Bot.command(*args, **kwargs) :decorator: - .. automethod:: Bot.event() - :decorator: - .. automethod:: Bot.group(*args, **kwargs) :decorator: - .. automethod:: Bot.listen(name=None, once=False) + .. automethod:: Bot.listen(event, once=False) :decorator: AutoShardedBot diff --git a/pyproject.toml b/pyproject.toml index 05f507b24b..9ee26fb00a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,7 @@ exclude = [ [tool.ruff.lint] select = ["ALL"] -per-file-ignores = {} + # When ignoring a rule globally, please consider if it can be ignored in a more specific way. # Also, leave a comment explaining why the rule is ignored. extend-ignore = [ @@ -301,6 +301,9 @@ extend-ignore = [ "D203" # conflicts with formatter ] +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["PLC0415"] # Allow non top-level imports in tests + [tool.mypy] namespace_packages = true install_types = true diff --git a/scripts/check_license/__init__.py b/scripts/check_license/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/scripts/check_license/__main__.py b/scripts/check_license/__main__.py new file mode 100644 index 0000000000..25998a536a --- /dev/null +++ b/scripts/check_license/__main__.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +"""Check copyright headers in Python files.""" + +import sys +from pathlib import Path + +MIT_HEADER = "The MIT License (MIT)\n\nCopyright (c) 2021-present Pycord Development" + +# Files with non-MIT licenses +EXCEPTIONS = { + "discord/utils/private.py": "CC-BY-SA 4.0", # hybridmethod + # Add more exceptions as needed +} + + +def check_file(filepath: Path) -> tuple[bool, str]: + """ + Check if file has appropriate header. + + Returns: + (is_valid, message) + """ + relative_path = str(filepath.relative_to(Path.cwd())) + + try: + content = filepath.read_text(encoding="utf-8") + except Exception as e: + return False, f"Error reading file: {e}" + + # Check if this is an exception file + if relative_path in EXCEPTIONS: + expected_license = EXCEPTIONS[relative_path] + # Just verify it has SOME license header + if "License" in content[:500] or "Copyright" in content[:500]: + return True, f"OK (Exception: {expected_license})" + return False, f"Missing license header (Expected: {expected_license})" + + # Check for standard MIT header + if MIT_HEADER in content: + return True, "OK (MIT)" + + return False, "Missing MIT license header" + + +def main(): + errors = [] + warnings = [] + + print("Checking copyright headers...\n") + + for filepath in sorted(Path("discord").rglob("*.py")): + # Skip common excluded directories + if any(part in ["__pycache__", ".git", "venv", ".venv"] for part in filepath.parts): + continue + + is_valid, message = check_file(filepath) + relative_path = filepath.relative_to(Path.cwd()) + + if not is_valid: + errors.append((relative_path, message)) + print(f"❌ {relative_path}: {message}") + elif "Exception" in message: + warnings.append((relative_path, message)) + print(f"⚠️ {relative_path}: {message}") + else: + print(f"✓ {relative_path}: {message}") + + print("\n" + "=" * 60) + + if warnings: + print(f"\n⚠️ {len(warnings)} file(s) with non-MIT licenses:") + for path, msg in warnings: + print(f" - {path}") + + if errors: + print(f"\n❌ {len(errors)} file(s) with issues:") + for path, msg in errors: + print(f" - {path}: {msg}") + sys.exit(1) + else: + print(f"\n✓ All {len(list(Path('discord').rglob('*.py')))} files have valid headers!") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/tests/event_helpers.py b/tests/event_helpers.py new file mode 100644 index 0000000000..ff69f06d01 --- /dev/null +++ b/tests/event_helpers.py @@ -0,0 +1,107 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from typing import Any +from unittest.mock import AsyncMock + +from discord.app.event_emitter import Event, EventEmitter +from discord.app.state import ConnectionState + + +class EventCapture: + """Helper class to capture events emitted by the EventEmitter.""" + + def __init__(self): + self.events: list[Event] = [] + self.call_count = 0 + + async def __call__(self, event: Event) -> None: + """Called when an event is received.""" + self.events.append(event) + self.call_count += 1 + + def assert_called_once(self): + """Assert that the event was received exactly once.""" + assert self.call_count == 1, f"Expected 1 event, got {self.call_count}" + + def assert_called_with_event_type(self, event_type: type[Event]): + """Assert that the event received is of the expected type.""" + assert len(self.events) > 0, "No events were captured" + event = self.events[-1] + assert isinstance(event, event_type), f"Expected {event_type.__name__}, got {type(event).__name__}" + + def assert_not_called(self): + """Assert that no events were received.""" + assert self.call_count == 0, f"Expected 0 events, got {self.call_count}" + + def get_last_event(self) -> Event | None: + """Get the last event that was captured.""" + return self.events[-1] if self.events else None + + def reset(self): + """Reset the capture state.""" + self.events.clear() + self.call_count = 0 + + +async def emit_and_capture( + state: ConnectionState, + event_name: str, + payload: Any, +) -> EventCapture: + """ + Emit an event and capture it using an EventCapture receiver. + + Args: + state: The ConnectionState to use for emission + event_name: The name of the event to emit + payload: The payload to emit + + Returns: + EventCapture instance containing captured events + """ + capture = EventCapture() + state.emitter.add_receiver(capture) + + try: + await state.emitter.emit(event_name, payload) + finally: + state.emitter.remove_receiver(capture) + + return capture + + +async def populate_guild_cache(state: ConnectionState, guild_id: int, guild_data: dict[str, Any]): + """ + Populate the cache with a guild. + + Args: + state: The ConnectionState to populate + guild_id: The ID of the guild + guild_data: The guild data payload + """ + from discord.guild import Guild + + guild = await Guild._from_data(guild_data, state) + await state.cache.add_guild(guild) diff --git a/tests/events/__init__.py b/tests/events/__init__.py new file mode 100644 index 0000000000..30cf39b9c0 --- /dev/null +++ b/tests/events/__init__.py @@ -0,0 +1 @@ +"""Event tests for py-cord.""" diff --git a/tests/events/test_events_channel.py b/tests/events/test_events_channel.py new file mode 100644 index 0000000000..16aad2e5f4 --- /dev/null +++ b/tests/events/test_events_channel.py @@ -0,0 +1,172 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import pytest + +from discord.events.channel import ( + ChannelCreate, + ChannelDelete, + ChannelPinsUpdate, + GuildChannelUpdate, +) +from tests.event_helpers import emit_and_capture, populate_guild_cache +from tests.fixtures import create_channel_payload, create_guild_payload, create_mock_state + + +@pytest.mark.asyncio +async def test_channel_create(): + """Test that CHANNEL_CREATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Create channel payload + channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="test-channel") + + # Emit event and capture + capture = await emit_and_capture(state, "CHANNEL_CREATE", channel_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(ChannelCreate) + + event = capture.get_last_event() + assert event is not None + assert event.id == channel_id + assert event.name == "test-channel" + + +@pytest.mark.asyncio +async def test_channel_delete(): + """Test that CHANNEL_DELETE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + + # Populate cache with guild and channel + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Create channel first + channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="test-channel") + await state.emitter.emit("CHANNEL_CREATE", channel_data) + + # Now delete it + capture = await emit_and_capture(state, "CHANNEL_DELETE", channel_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(ChannelDelete) + + event = capture.get_last_event() + assert event is not None + assert event.id == channel_id + assert event.name == "test-channel" + + +@pytest.mark.asyncio +async def test_channel_pins_update(): + """Test that CHANNEL_PINS_UPDATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + + # Populate cache with guild and channel + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="test-channel") + await state.emitter.emit("CHANNEL_CREATE", channel_data) + + # Create pins update payload + pins_data = { + "guild_id": str(guild_id), + "channel_id": str(channel_id), + "last_pin_timestamp": "2024-01-01T00:00:00+00:00", + } + + # Emit event and capture + capture = await emit_and_capture(state, "CHANNEL_PINS_UPDATE", pins_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(ChannelPinsUpdate) + + event = capture.get_last_event() + assert event is not None + assert event.channel.id == channel_id + assert event.last_pin is not None + + +@pytest.mark.asyncio +async def test_channel_update(): + """Test that CHANNEL_UPDATE event triggers GUILD_CHANNEL_UPDATE.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + + # Populate cache with guild and channel + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="test-channel") + await state.emitter.emit("CHANNEL_CREATE", channel_data) + + # Update channel + updated_channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="updated-channel") + + # Emit event and capture + capture = await emit_and_capture(state, "CHANNEL_UPDATE", updated_channel_data) + + # Assertions - CHANNEL_UPDATE dispatches GUILD_CHANNEL_UPDATE + # The original event doesn't return anything but emits a sub-event + assert capture.call_count >= 0 # May emit GUILD_CHANNEL_UPDATE + + +@pytest.mark.asyncio +async def test_channel_create_without_guild(): + """Test that CHANNEL_CREATE returns None when guild is not found.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + + # Don't populate cache with guild + + # Create channel payload + channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="test-channel") + + # Emit event and capture + capture = await emit_and_capture(state, "CHANNEL_CREATE", channel_data) + + # Assertions - should not emit event if guild not found + capture.assert_not_called() diff --git a/tests/events/test_events_guild.py b/tests/events/test_events_guild.py new file mode 100644 index 0000000000..0d8578f1e6 --- /dev/null +++ b/tests/events/test_events_guild.py @@ -0,0 +1,408 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import pytest + +from discord.events.guild import ( + GuildBanAdd, + GuildBanRemove, + GuildDelete, + GuildMemberJoin, + GuildMemberRemove, + GuildMemberUpdate, + GuildRoleCreate, + GuildRoleDelete, + GuildRoleUpdate, + GuildUpdate, +) +from discord.guild import Guild +from discord.member import Member +from tests.event_helpers import emit_and_capture, populate_guild_cache +from tests.fixtures import ( + create_guild_payload, + create_member_payload, + create_mock_state, + create_user_payload, +) + + +@pytest.mark.asyncio +async def test_guild_member_join(): + """Test that GUILD_MEMBER_JOIN event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + user_id = 123456789 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Create member payload + member_data = create_member_payload(user_id, guild_id, "NewMember") + member_data["guild_id"] = str(guild_id) + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_MEMBER_JOIN", member_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildMemberJoin) + + event = capture.get_last_event() + assert event is not None + assert isinstance(event, Member) + assert event.id == user_id + + +@pytest.mark.asyncio +async def test_guild_member_remove(): + """Test that GUILD_MEMBER_REMOVE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + user_id = 123456789 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Add member first + member_data = create_member_payload(user_id, guild_id, "TestMember") + member_data["guild_id"] = str(guild_id) + await state.emitter.emit("GUILD_MEMBER_JOIN", member_data) + + # Create remove payload + remove_data = { + "guild_id": str(guild_id), + "user": create_user_payload(user_id, "TestMember"), + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_MEMBER_REMOVE", remove_data) + + # Assertions + # Event may or may not be emitted depending on whether member exists + assert capture.call_count >= 0 + + +@pytest.mark.asyncio +async def test_guild_member_update(): + """Test that GUILD_MEMBER_UPDATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + user_id = 123456789 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Add member first + member_data = create_member_payload(user_id, guild_id, "TestMember") + member_data["guild_id"] = str(guild_id) + await state.emitter.emit("GUILD_MEMBER_JOIN", member_data) + + # Update member + updated_data = create_member_payload(user_id, guild_id, "TestMember") + updated_data["guild_id"] = str(guild_id) + updated_data["nick"] = "NewNick" + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_MEMBER_UPDATE", updated_data) + + # Assertions + # Event may or may not be emitted depending on cache state + assert capture.call_count >= 0 + + +@pytest.mark.asyncio +async def test_guild_role_create(): + """Test that GUILD_ROLE_CREATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + role_id = 555555555 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Create role payload + role_data = { + "guild_id": str(guild_id), + "role": { + "id": str(role_id), + "name": "Test Role", + "colors": { + "primary_color": 0xFF0000, + }, + "hoist": False, + "position": 1, + "permissions": "0", + "managed": False, + "mentionable": True, + }, + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_ROLE_CREATE", role_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildRoleCreate) + + event = capture.get_last_event() + assert event is not None + assert event.id == role_id + assert event.name == "Test Role" + + +@pytest.mark.asyncio +async def test_guild_role_update(): + """Test that GUILD_ROLE_UPDATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + role_id = 555555555 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Create role first + role_data = { + "guild_id": str(guild_id), + "role": { + "id": str(role_id), + "name": "Test Role", + "colors": { + "primary_color": 0xFF0000, + "secondary_color": 0x00FF00, + }, + "hoist": False, + "position": 1, + "permissions": "0", + "managed": False, + "mentionable": True, + }, + } + await state.emitter.emit("GUILD_ROLE_CREATE", role_data) + + # Update role + updated_role_data = { + "guild_id": str(guild_id), + "role": { + "id": str(role_id), + "name": "Updated Role", + "colors": { + "primary_color": 0x0000FF, + "secondary_color": 0xFFFF00, + }, + "hoist": True, + "position": 2, + "permissions": "8", + "managed": False, + "mentionable": True, + }, + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_ROLE_UPDATE", updated_role_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildRoleUpdate) + + event = capture.get_last_event() + assert event is not None + assert event.id == role_id + assert event.name == "Updated Role" + assert event.old.name == "Test Role" + + +@pytest.mark.asyncio +async def test_guild_role_delete(): + """Test that GUILD_ROLE_DELETE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + role_id = 555555555 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Create role first + role_data = { + "guild_id": str(guild_id), + "role": { + "id": str(role_id), + "name": "Test Role", + "colors": { + "primary_color": 0xFF0000, + "secondary_color": 0x00FF00, + }, + "hoist": False, + "position": 1, + "permissions": "0", + "managed": False, + "mentionable": True, + }, + } + await state.emitter.emit("GUILD_ROLE_CREATE", role_data) + + # Delete role + delete_data = { + "guild_id": str(guild_id), + "role_id": str(role_id), + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_ROLE_DELETE", delete_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildRoleDelete) + + event = capture.get_last_event() + assert event is not None + assert event.id == role_id + + +@pytest.mark.asyncio +async def test_guild_update(): + """Test that GUILD_UPDATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id, "Original Name") + await populate_guild_cache(state, guild_id, guild_data) + + # Update guild + updated_data = create_guild_payload(guild_id, "Updated Name") + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_UPDATE", updated_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildUpdate) + + event = capture.get_last_event() + assert event is not None + assert event.id == guild_id + assert event.name == "Updated Name" + assert event.old.name == "Original Name" + + +@pytest.mark.asyncio +async def test_guild_delete(): + """Test that GUILD_DELETE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Delete guild + delete_data = { + "id": str(guild_id), + "unavailable": False, + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_DELETE", delete_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildDelete) + + event = capture.get_last_event() + assert event is not None + assert event.id == guild_id + + +@pytest.mark.asyncio +async def test_guild_ban_add(): + """Test that GUILD_BAN_ADD event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + user_id = 123456789 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Create ban payload + ban_data = { + "guild_id": str(guild_id), + "user": create_user_payload(user_id, "BannedUser"), + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_BAN_ADD", ban_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildBanAdd) + + event = capture.get_last_event() + assert event is not None + assert event.id == user_id + + +@pytest.mark.asyncio +async def test_guild_ban_remove(): + """Test that GUILD_BAN_REMOVE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + user_id = 123456789 + + # Populate cache with guild + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + # Create unban payload + unban_data = { + "guild_id": str(guild_id), + "user": create_user_payload(user_id, "UnbannedUser"), + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_BAN_REMOVE", unban_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildBanRemove) + + event = capture.get_last_event() + assert event is not None + assert event.id == user_id diff --git a/tests/events/test_events_soundboard.py b/tests/events/test_events_soundboard.py new file mode 100644 index 0000000000..e72d151a2f --- /dev/null +++ b/tests/events/test_events_soundboard.py @@ -0,0 +1,226 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import pytest + +from discord.events.soundboard import ( + GuildSoundboardSoundCreate, + GuildSoundboardSoundDelete, + GuildSoundboardSoundUpdate, + SoundboardSounds, +) +from discord.soundboard import SoundboardSound +from tests.event_helpers import emit_and_capture +from tests.fixtures import create_mock_state, create_soundboard_sound_payload + + +@pytest.mark.asyncio +async def test_soundboard_sounds(): + """Test that SOUNDBOARD_SOUNDS event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + + # Create soundboard sounds payload + sounds_data = { + "guild_id": str(guild_id), + "soundboard_sounds": [ + create_soundboard_sound_payload(444444444, guild_id, "sound1"), + create_soundboard_sound_payload(444444445, guild_id, "sound2"), + ], + } + + # Emit event and capture + capture = await emit_and_capture(state, "SOUNDBOARD_SOUNDS", sounds_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(SoundboardSounds) + + event = capture.get_last_event() + assert event is not None + assert event.guild_id == guild_id + assert len(event.sounds) == 2 + assert event.sounds[0].name == "sound1" + assert event.sounds[1].name == "sound2" + + # Verify sounds are cached + sound1 = await state.cache.get_sound(444444444) + assert sound1 is not None + assert sound1.name == "sound1" + + +@pytest.mark.asyncio +async def test_guild_soundboard_sound_create(): + """Test that GUILD_SOUNDBOARD_SOUND_CREATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + sound_id = 444444444 + + # Create sound payload + sound_data = create_soundboard_sound_payload(sound_id, guild_id, "new-sound", emoji_name="🎵") + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildSoundboardSoundCreate) + + event = capture.get_last_event() + assert event is not None + assert event.sound.id == sound_id + assert event.sound.name == "new-sound" + + # Verify sound is cached + cached_sound = await state.cache.get_sound(sound_id) + assert cached_sound is not None + assert cached_sound.name == "new-sound" + + +@pytest.mark.asyncio +async def test_guild_soundboard_sound_update(): + """Test that GUILD_SOUNDBOARD_SOUND_UPDATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + sound_id = 444444444 + + # Create and cache original sound + original_sound = SoundboardSound( + state=state, + http=state.http, + data=create_soundboard_sound_payload(sound_id, guild_id, "original-name"), + ) + await state.cache.store_sound(original_sound) + + # Create updated sound payload + updated_data = create_soundboard_sound_payload(sound_id, guild_id, "updated-name") + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_SOUNDBOARD_SOUND_UPDATE", updated_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildSoundboardSoundUpdate) + + event = capture.get_last_event() + assert event is not None + assert event.before.name == "original-name" + assert event.after.name == "updated-name" + assert event.before.id == sound_id + assert event.after.id == sound_id + + +@pytest.mark.asyncio +async def test_guild_soundboard_sound_update_without_cache(): + """Test that GUILD_SOUNDBOARD_SOUND_UPDATE returns None when sound is not cached.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + sound_id = 444444444 + + # Don't cache the sound + + # Create sound payload + sound_data = create_soundboard_sound_payload(sound_id, guild_id, "new-sound") + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_SOUNDBOARD_SOUND_UPDATE", sound_data) + + # Assertions - should not emit event if sound not found + capture.assert_not_called() + + +@pytest.mark.asyncio +async def test_guild_soundboard_sound_delete(): + """Test that GUILD_SOUNDBOARD_SOUND_DELETE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + sound_id = 444444444 + + # Create and cache sound + sound = SoundboardSound( + state=state, + http=state.http, + data=create_soundboard_sound_payload(sound_id, guild_id, "test-sound"), + ) + await state.cache.store_sound(sound) + + # Create delete payload + delete_data = { + "guild_id": str(guild_id), + "sound_id": str(sound_id), + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_SOUNDBOARD_SOUND_DELETE", delete_data) + + # Assertions + capture.assert_called_once() + capture.assert_called_with_event_type(GuildSoundboardSoundDelete) + + event = capture.get_last_event() + assert event is not None + assert event.sound is not None + assert event.sound.id == sound_id + assert event.sound.name == "test-sound" + assert event.raw.sound_id == sound_id + assert event.raw.guild_id == guild_id + + # Verify sound is removed from cache + cached_sound = await state.cache.get_sound(sound_id) + assert cached_sound is None + + +@pytest.mark.asyncio +async def test_guild_soundboard_sound_delete_without_cache(): + """Test that GUILD_SOUNDBOARD_SOUND_DELETE handles missing sound gracefully.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + sound_id = 444444444 + + # Don't cache the sound + + # Create delete payload + delete_data = { + "guild_id": str(guild_id), + "sound_id": str(sound_id), + } + + # Emit event and capture + capture = await emit_and_capture(state, "GUILD_SOUNDBOARD_SOUND_DELETE", delete_data) + + # Assertions - should still emit event with None sound + capture.assert_called_once() + capture.assert_called_with_event_type(GuildSoundboardSoundDelete) + + event = capture.get_last_event() + assert event is not None + assert event.sound is None + assert event.raw.sound_id == sound_id + assert event.raw.guild_id == guild_id diff --git a/tests/events/test_events_thread.py b/tests/events/test_events_thread.py new file mode 100644 index 0000000000..77f752357a --- /dev/null +++ b/tests/events/test_events_thread.py @@ -0,0 +1,156 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import pytest + +from discord.events.thread import ThreadCreate, ThreadDelete, ThreadJoin, ThreadUpdate +from tests.event_helpers import emit_and_capture, populate_guild_cache +from tests.fixtures import ( + create_channel_payload, + create_guild_payload, + create_mock_state, + create_thread_payload, +) + + +@pytest.mark.asyncio +async def test_thread_create(): + """Test that THREAD_CREATE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + thread_id = 333333333 + + # Populate cache with guild and parent channel + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="test-channel") + await state.emitter.emit("CHANNEL_CREATE", channel_data) + + # Create thread payload + thread_data = create_thread_payload( + thread_id=thread_id, guild_id=guild_id, parent_id=channel_id, name="test-thread" + ) + + # Emit event and capture + capture = await emit_and_capture(state, "THREAD_CREATE", thread_data) + + # Assertions + # ThreadCreate may emit THREAD_JOIN or return the thread itself + assert capture.call_count >= 0 # May or may not emit depending on just_joined + + +@pytest.mark.asyncio +async def test_thread_create_newly_created(): + """Test that THREAD_CREATE event with newly_created flag.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + thread_id = 333333333 + + # Populate cache with guild and parent channel + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="test-channel") + await state.emitter.emit("CHANNEL_CREATE", channel_data) + + # Create thread payload with newly_created flag + thread_data = create_thread_payload( + thread_id=thread_id, guild_id=guild_id, parent_id=channel_id, name="test-thread" + ) + thread_data["newly_created"] = True + + # Emit event and capture + capture = await emit_and_capture(state, "THREAD_CREATE", thread_data) + + # Assertions - newly created threads emit ThreadCreate, not ThreadJoin + if capture.call_count > 0: + event = capture.get_last_event() + assert event is not None + + +@pytest.mark.asyncio +async def test_thread_delete(): + """Test that THREAD_DELETE event is emitted correctly.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + thread_id = 333333333 + + # Populate cache with guild and parent channel + guild_data = create_guild_payload(guild_id) + await populate_guild_cache(state, guild_id, guild_data) + + channel_data = create_channel_payload(channel_id=channel_id, guild_id=guild_id, name="test-channel") + await state.emitter.emit("CHANNEL_CREATE", channel_data) + + # Create thread first + thread_data = create_thread_payload( + thread_id=thread_id, guild_id=guild_id, parent_id=channel_id, name="test-thread" + ) + thread_data["newly_created"] = True + await state.emitter.emit("THREAD_CREATE", thread_data) + + # Create delete payload + delete_data = { + "id": str(thread_id), + "guild_id": str(guild_id), + "parent_id": str(channel_id), + "type": 11, # PUBLIC_THREAD + } + + # Emit event and capture + capture = await emit_and_capture(state, "THREAD_DELETE", delete_data) + + # Assertions + # The event may or may not be emitted depending on whether thread exists + assert capture.call_count >= 0 + + +@pytest.mark.asyncio +async def test_thread_create_without_guild(): + """Test that THREAD_CREATE returns None when guild is not found.""" + # Setup + state = create_mock_state() + guild_id = 111111111 + channel_id = 222222222 + thread_id = 333333333 + + # Don't populate cache with guild + + # Create thread payload + thread_data = create_thread_payload( + thread_id=thread_id, guild_id=guild_id, parent_id=channel_id, name="test-thread" + ) + + # Emit event and capture + capture = await emit_and_capture(state, "THREAD_CREATE", thread_data) + + # Assertions - should not emit event if guild not found + capture.assert_not_called() diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 0000000000..69dfab86c7 --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,408 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from collections import defaultdict +from datetime import datetime, timezone +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +from discord.app.cache import Cache +from discord.app.state import ConnectionState +from discord.bot import Bot +from discord.channel import DMChannel, TextChannel +from discord.enums import ChannelType +from discord.flags import Intents +from discord.guild import Guild +from discord.http import HTTPClient +from discord.member import Member +from discord.soundboard import SoundboardSound +from discord.channel.thread import Thread +from discord.user import ClientUser, User + + +class MockCache: + """Mock implementation of the Cache protocol for testing.""" + + def __init__(self): + self._users: dict[int, User] = {} + self._guilds: dict[int, Guild] = {} + self._sounds: dict[int, SoundboardSound] = {} + self._guild_members: dict[int, dict[int, Member]] = defaultdict(dict) + self.__state: ConnectionState | None = None + + @property + def _state(self) -> ConnectionState: + if self.__state is None: + raise RuntimeError("Cache state has not been initialized.") + return self.__state + + @_state.setter + def _state(self, state: ConnectionState) -> None: + self.__state = state + + # Users + async def get_all_users(self) -> list[User]: + return list(self._users.values()) + + async def store_user(self, payload: dict[str, Any]) -> User: + user = User(state=self._state, data=payload) + self._users[user.id] = user + return user + + async def delete_user(self, user_id: int) -> None: + self._users.pop(user_id, None) + + async def get_user(self, user_id: int) -> User | None: + return self._users.get(user_id) + + # Guilds + async def get_all_guilds(self) -> list[Guild]: + return list(self._guilds.values()) + + async def get_guild(self, id: int) -> Guild | None: + return self._guilds.get(id) + + async def add_guild(self, guild: Guild) -> None: + self._guilds[guild.id] = guild + + async def delete_guild(self, guild: Guild) -> None: + self._guilds.pop(guild.id, None) + + # Soundboard sounds + async def get_sound(self, sound_id: int) -> SoundboardSound | None: + return self._sounds.get(sound_id) + + async def store_sound(self, sound: SoundboardSound) -> None: + self._sounds[sound.id] = sound + + async def delete_sound(self, sound_id: int) -> None: + self._sounds.pop(sound_id, None) + + # Guild members + async def store_member(self, member: Member) -> None: + self._guild_members[member.guild.id][member.id] = member + + async def get_member(self, guild_id: int, user_id: int) -> Member | None: + return self._guild_members[guild_id].get(user_id) + + async def delete_member(self, guild_id: int, user_id: int) -> None: + self._guild_members[guild_id].pop(user_id, None) + + async def delete_guild_members(self, guild_id: int) -> None: + self._guild_members.pop(guild_id, None) + + async def get_guild_members(self, guild_id: int) -> list[Member]: + return list(self._guild_members.get(guild_id, {}).values()) + + async def get_all_members(self) -> list[Member]: + members = [] + for guild_members in self._guild_members.values(): + members.extend(guild_members.values()) + return members + + # Stubs for other required methods + async def get_all_stickers(self) -> list: + return [] + + async def get_sticker(self, sticker_id: int): + return None + + async def store_sticker(self, guild, data): + return None + + async def delete_sticker(self, sticker_id: int) -> None: + pass + + async def store_view(self, view, message_id: int | None) -> None: + pass + + async def delete_view_on(self, message_id: int) -> None: + pass + + async def get_all_views(self) -> list: + return [] + + async def store_modal(self, modal, user_id: int) -> None: + pass + + async def delete_modal(self, custom_id: str) -> None: + pass + + async def get_all_modals(self) -> list: + return [] + + async def store_guild_emoji(self, guild, data): + return None + + async def store_app_emoji(self, application_id: int, data): + return None + + async def get_all_emojis(self) -> list: + return [] + + async def get_emoji(self, emoji_id: int | None): + return None + + async def delete_emoji(self, emoji) -> None: + pass + + async def get_all_polls(self) -> list: + return [] + + async def get_poll(self, message_id: int): + return None + + async def store_poll(self, poll, message_id: int) -> None: + pass + + async def get_private_channels(self) -> list: + return [] + + async def get_private_channel(self, channel_id: int): + return None + + async def get_private_channel_by_user(self, user_id: int): + return None + + async def store_private_channel(self, channel) -> None: + pass + + async def store_message(self, message, channel): + return None + + async def store_built_message(self, message) -> None: + pass + + async def upsert_message(self, message) -> None: + pass + + async def delete_message(self, message_id: int) -> None: + pass + + async def get_message(self, message_id: int): + return None + + async def get_all_messages(self) -> list: + return [] + + +def create_mock_http() -> HTTPClient: + """Create a mock HTTP client.""" + http = MagicMock(spec=HTTPClient) + http.get_all_application_emojis = AsyncMock(return_value={"items": []}) + return http + + +def create_mock_state(*, intents: Intents | None = None, cache: Cache | None = None) -> ConnectionState: + """Create a mock ConnectionState for testing.""" + from discord.app.event_emitter import EventEmitter + from discord.flags import MemberCacheFlags + + if cache is None: + cache = MockCache() + + http = create_mock_http() + + state = MagicMock(spec=ConnectionState) + state.http = http + state.cache = cache + state.cache._state = state + state.intents = intents or Intents.default() + state.application_id = 123456789 + state.self_id = 987654321 + state.cache_app_emojis = False + state._guilds = {} + state._private_channels = {} + state.member_cache_flags = MemberCacheFlags.from_intents(state.intents) + + # Create real EventEmitter + state.emitter = EventEmitter(state) + + # Make _get_guild async + async def _get_guild(guild_id: int) -> Guild | None: + return await state.cache.get_guild(guild_id) + + state._get_guild = _get_guild + + # Make _add_guild async + async def _add_guild(guild: Guild) -> None: + await state.cache.add_guild(guild) + + state._add_guild = _add_guild + + # Make _remove_guild async + async def _remove_guild(guild: Guild) -> None: + await state.cache.delete_guild(guild) + + state._remove_guild = _remove_guild + + # Make store_user async + async def store_user(payload: dict[str, Any]) -> User: + return await state.cache.store_user(payload) + + state.store_user = store_user + + # Make _get_private_channel async + async def _get_private_channel(channel_id: int): + return await state.cache.get_private_channel(channel_id) + + state._get_private_channel = _get_private_channel + + return state + + +def create_mock_bot(*, intents: Intents | None = None, cache: Cache | None = None) -> Bot: + """Create a mock ClientUser for testing.""" + state = create_mock_state(intents=intents, cache=cache) + bot = Bot() + state.emitter = bot._connection.emitter + bot._connection = state + return bot + + +def create_user_payload(user_id: int = 123456789, username: str = "TestUser") -> dict[str, Any]: + """Create a mock user payload.""" + return { + "id": str(user_id), + "username": username, + "discriminator": "0001", + "global_name": username, + "avatar": "abc123", + "bot": False, + } + + +def create_guild_payload(guild_id: int = 111111111, name: str = "Test Guild") -> dict[str, Any]: + """Create a mock guild payload.""" + return { + "id": str(guild_id), + "name": name, + "icon": None, + "splash": None, + "discovery_splash": None, + "owner_id": "123456789", + "afk_channel_id": None, + "afk_timeout": 300, + "verification_level": 0, + "default_message_notifications": 0, + "explicit_content_filter": 0, + "roles": [], + "emojis": [], + "features": [], + "mfa_level": 0, + "system_channel_id": None, + "system_channel_flags": 0, + "rules_channel_id": None, + "vanity_url_code": None, + "description": None, + "banner": None, + "premium_tier": 0, + "preferred_locale": "en-US", + "public_updates_channel_id": None, + "nsfw_level": 0, + "premium_progress_bar_enabled": False, + } + + +def create_channel_payload( + channel_id: int = 222222222, + guild_id: int = 111111111, + name: str = "test-channel", + channel_type: int = 0, +) -> dict[str, Any]: + """Create a mock channel payload.""" + return { + "id": str(channel_id), + "type": channel_type, + "guild_id": str(guild_id), + "name": name, + "position": 0, + "permission_overwrites": [], + "nsfw": False, + "parent_id": None, + } + + +def create_thread_payload( + thread_id: int = 333333333, + guild_id: int = 111111111, + parent_id: int = 222222222, + name: str = "test-thread", + owner_id: int = 123456789, +) -> dict[str, Any]: + """Create a mock thread payload.""" + return { + "id": str(thread_id), + "type": ChannelType.public_thread.value, + "guild_id": str(guild_id), + "name": name, + "parent_id": str(parent_id), + "owner_id": str(owner_id), + "thread_metadata": { + "archived": False, + "auto_archive_duration": 1440, + "archive_timestamp": datetime.now(timezone.utc).isoformat(), + "locked": False, + "create_timestamp": datetime.now(timezone.utc).isoformat(), + }, + "message_count": 0, + "member_count": 1, + } + + +def create_soundboard_sound_payload( + sound_id: int = 444444444, + guild_id: int = 111111111, + name: str = "test-sound", + emoji_name: str | None = None, +) -> dict[str, Any]: + """Create a mock soundboard sound payload.""" + payload = { + "sound_id": str(sound_id), + "name": name, + "volume": 1.0, + "guild_id": str(guild_id), + "available": True, + } + if emoji_name: + payload["emoji_name"] = emoji_name + return payload + + +def create_member_payload( + user_id: int = 123456789, + guild_id: int = 111111111, + username: str = "TestUser", + roles: list[str] | None = None, +) -> dict[str, Any]: + """Create a mock member payload.""" + return { + "user": create_user_payload(user_id, username), + "nick": None, + "roles": roles or [], + "joined_at": datetime.now(timezone.utc).isoformat(), + "premium_since": None, + "deaf": False, + "mute": False, + } diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000000..f18bb78f4c --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for py-cord.""" diff --git a/tests/integration/test_event_listeners.py b/tests/integration/test_event_listeners.py new file mode 100644 index 0000000000..d9c466023f --- /dev/null +++ b/tests/integration/test_event_listeners.py @@ -0,0 +1,327 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import pytest + +from discord.bot import Bot +from discord.events.soundboard import GuildSoundboardSoundCreate +from discord.gears import Gear +from tests.fixtures import create_mock_bot, create_mock_state, create_soundboard_sound_payload + + +@pytest.mark.asyncio +async def test_add_listener(): + """Test adding a listener using add_listener method.""" + # Setup + bot = create_mock_bot() + + # Track if listener was called + called = [] + + async def on_sound_create(event: GuildSoundboardSoundCreate): + called.append(event) + + # Add listener + bot.add_listener(on_sound_create, event=GuildSoundboardSoundCreate) + + # Create sound payload and emit event + sound_data = create_soundboard_sound_payload(444444444, 111111111, "test-sound") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Wait a bit for event processing + import asyncio + + await asyncio.sleep(0.1) + + # Assertions + assert len(called) == 1 + assert isinstance(called[0], GuildSoundboardSoundCreate) + assert called[0].sound.name == "test-sound" + + +@pytest.mark.asyncio +async def test_listen_decorator_on_bot_instance(): + """Test using @bot.listen decorator on a bot instance.""" + # Setup + bot = create_mock_bot() + + # Track if listener was called + called = [] + + @bot.listen(GuildSoundboardSoundCreate) + async def on_sound_create(event: GuildSoundboardSoundCreate): + called.append(event) + + # Create sound payload and emit event + sound_data = create_soundboard_sound_payload(444444444, 111111111, "test-sound") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Wait a bit for event processing + import asyncio + + await asyncio.sleep(0.1) + + # Assertions + assert len(called) == 1 + assert isinstance(called[0], GuildSoundboardSoundCreate) + assert called[0].sound.name == "test-sound" + + +@pytest.mark.asyncio +async def test_gear_with_class_decorator(): + """Test using @Gear.listen decorator on a class method.""" + + # Create a custom gear with class decorator + class MyGear(Gear): + def __init__(self): + super().__init__() + self.called = [] + + @Gear.listen(GuildSoundboardSoundCreate) + async def on_sound_create(self, event: GuildSoundboardSoundCreate): + self.called.append(event) + + # Setup + bot = create_mock_bot() + + # Add gear to bot + my_gear = MyGear() + bot.attach_gear(my_gear) + + # Create sound payload and emit event + sound_data = create_soundboard_sound_payload(444444444, 111111111, "test-sound") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Wait a bit for event processing + import asyncio + + await asyncio.sleep(0.1) + + # Assertions + assert len(my_gear.called) == 1 + assert isinstance(my_gear.called[0], GuildSoundboardSoundCreate) + assert my_gear.called[0].sound.name == "test-sound" + + +@pytest.mark.asyncio +async def test_gear_instance_decorator(): + """Test using @gear.listen decorator on a gear instance.""" + # Setup + bot = create_mock_bot() + + # Create gear instance + my_gear = Gear() + + # Track if listener was called + called = [] + + @my_gear.listen(GuildSoundboardSoundCreate) + async def on_sound_create(event: GuildSoundboardSoundCreate): + called.append(event) + + # Add gear to bot + bot.attach_gear(my_gear) + + # Create sound payload and emit event + sound_data = create_soundboard_sound_payload(444444444, 111111111, "test-sound") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Wait a bit for event processing + import asyncio + + await asyncio.sleep(0.1) + + # Assertions + assert len(called) == 1 + assert isinstance(called[0], GuildSoundboardSoundCreate) + assert called[0].sound.name == "test-sound" + + +@pytest.mark.asyncio +async def test_gear_add_listener(): + """Test using gear.add_listener method.""" + # Setup + bot = create_mock_bot() + + # Create gear instance + my_gear = Gear() + + # Track if listener was called + called = [] + + async def on_sound_create(event: GuildSoundboardSoundCreate): + called.append(event) + + # Add listener to gear + my_gear.add_listener(on_sound_create, event=GuildSoundboardSoundCreate) + + # Add gear to bot + bot.attach_gear(my_gear) + + # Create sound payload and emit event + sound_data = create_soundboard_sound_payload(444444444, 111111111, "test-sound") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Wait a bit for event processing + import asyncio + + await asyncio.sleep(0.1) + + # Assertions + assert len(called) == 1 + assert isinstance(called[0], GuildSoundboardSoundCreate) + assert called[0].sound.name == "test-sound" + + +@pytest.mark.asyncio +async def test_nested_gears(): + """Test that nested gears work correctly.""" + + class ParentGear(Gear): + def __init__(self): + super().__init__() + self.called = [] + + @Gear.listen(GuildSoundboardSoundCreate) + async def on_sound_create(self, event: GuildSoundboardSoundCreate): + self.called.append(("parent", event)) + + class ChildGear(Gear): + def __init__(self): + super().__init__() + self.called = [] + + @Gear.listen(GuildSoundboardSoundCreate) + async def on_sound_create(self, event: GuildSoundboardSoundCreate): + self.called.append(("child", event)) + + # Setup + bot = create_mock_bot() + + # Create gears + parent_gear = ParentGear() + child_gear = ChildGear() + + # Add child to parent + parent_gear.attach_gear(child_gear) + + # Add parent to bot + bot.attach_gear(parent_gear) + + # Create sound payload and emit event + sound_data = create_soundboard_sound_payload(444444444, 111111111, "test-sound") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Wait a bit for event processing + import asyncio + + await asyncio.sleep(0.1) + + # Assertions + assert len(parent_gear.called) == 1 + assert parent_gear.called[0][0] == "parent" + assert parent_gear.called[0][1].sound.name == "test-sound" + + assert len(child_gear.called) == 1 + assert child_gear.called[0][0] == "child" + assert child_gear.called[0][1].sound.name == "test-sound" + + +@pytest.mark.asyncio +async def test_remove_listener(): + """Test removing a listener.""" + # Setup + bot = create_mock_bot() + + # Track if listener was called + called = [] + + async def on_sound_create(event: GuildSoundboardSoundCreate): + called.append(event) + + # Add listener + bot.add_listener(on_sound_create, event=GuildSoundboardSoundCreate) + + # Create sound payload and emit event + sound_data = create_soundboard_sound_payload(444444444, 111111111, "test-sound-1") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Wait a bit for event processing + import asyncio + + await asyncio.sleep(0.1) + + # Should be called once + assert len(called) == 1 + + # Remove listener + bot.remove_listener(on_sound_create) + + # Emit another event + sound_data = create_soundboard_sound_payload(444444445, 111111111, "test-sound-2") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + await asyncio.sleep(0.1) + + # Should still be 1 (not called again) + assert len(called) == 1 + + +@pytest.mark.asyncio +async def test_multiple_listeners_same_event(): + """Test that multiple listeners for the same event all get called.""" + # Setup + bot = create_mock_bot() + + # Track calls + calls = [] + + async def listener1(event: GuildSoundboardSoundCreate): + calls.append("listener1") + + async def listener2(event: GuildSoundboardSoundCreate): + calls.append("listener2") + + @bot.listen(GuildSoundboardSoundCreate) + async def listener3(event: GuildSoundboardSoundCreate): + calls.append("listener3") + + # Add listeners + bot.add_listener(listener1, event=GuildSoundboardSoundCreate) + bot.add_listener(listener2, event=GuildSoundboardSoundCreate) + + # Create sound payload and emit event + sound_data = create_soundboard_sound_payload(444444444, 111111111, "test-sound") + await bot._connection.emitter.emit("GUILD_SOUNDBOARD_SOUND_CREATE", sound_data) + + # Wait a bit for event processing + import asyncio + + await asyncio.sleep(0.1) + + # Assertions - all three should be called + assert len(calls) == 3 + assert "listener1" in calls + assert "listener2" in calls + assert "listener3" in calls