Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion discord/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from .sticker import *
from .team import *
from .template import *
from .threads import *
from .channel.thread import *
from .user import *
from .voice_client import *
from .webhook import *
Expand Down
19 changes: 2 additions & 17 deletions discord/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@
from __future__ import annotations

import asyncio
import copy
import time
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Protocol,
Expand All @@ -43,17 +41,11 @@

from . import utils
from .context_managers import Typing
from .enums import ChannelType
from .errors import ClientException, InvalidArgument
from .file import File, VoiceMessage
from .flags import ChannelFlags, MessageFlags
from .invite import Invite
from .flags import MessageFlags
from .iterators import HistoryIterator, MessagePinIterator
from .mentions import AllowedMentions
from .partial_emoji import PartialEmoji, _EmojiTag
from .permissions import PermissionOverwrite, Permissions
from .role import Role
from .scheduled_events import ScheduledEvent
from .sticker import GuildSticker, StickerItem
from .utils.private import warn_deprecated
from .voice_client import VoiceClient, VoiceProtocol
Expand All @@ -62,7 +54,6 @@
"Snowflake",
"User",
"PrivateChannel",
"GuildChannel",
"Messageable",
"Connectable",
"Mentionable",
Expand All @@ -76,7 +67,6 @@
from .app.state import ConnectionState
from .asset import Asset
from .channel import (
CategoryChannel,
DMChannel,
GroupChannel,
PartialMessageable,
Expand All @@ -86,14 +76,9 @@
)
from .client import Client
from .embeds import Embed
from .enums import InviteTarget
from .guild import Guild
from .member import Member
from .message import Message, MessageReference, PartialMessage
from .poll import Poll
from .threads import Thread
from .types.channel import Channel as ChannelPayload
from .types.channel import GuildChannel as GuildChannelPayload
from .channel.thread import Thread
from .types.channel import OverwriteType
from .types.channel import PermissionOverwrite as PermissionOverwritePayload
from .ui.view import View
Expand Down
32 changes: 31 additions & 1 deletion discord/app/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from discord import utils
from discord.member import Member
from discord.message import Message
from discord.soundboard import SoundboardSound

from ..channel import DMChannel
from ..emoji import AppEmoji, GuildEmoji
Expand Down Expand Up @@ -142,6 +143,8 @@ async def store_private_channel(self, channel: "PrivateChannel") -> None: ...

async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message: ...

async def store_built_message(self, message: Message) -> None: ...

async def upsert_message(self, message: Message) -> None: ...

async def delete_message(self, message_id: int) -> None: ...
Expand All @@ -166,6 +169,14 @@ async def get_all_members(self) -> list[Member]: ...

async def clear(self, views: bool = True) -> None: ...

async def store_sound(self, sound: SoundboardSound) -> None: ...

async def get_sound(self, sound_id: int) -> SoundboardSound | None: ...

async def get_all_sounds(self) -> list[SoundboardSound]: ...

async def delete_sound(self, sound_id: int) -> None: ...


class MemoryCache(Cache):
def __init__(self, max_messages: int | None = None) -> None:
Expand All @@ -177,6 +188,7 @@ def __init__(self, max_messages: int | None = None) -> None:
self._stickers: dict[int, list[GuildSticker]] = {}
self._views: dict[str, View] = {}
self._modals: dict[str, Modal] = {}
self._sounds: dict[int, SoundboardSound] = {}
self._messages: Deque[Message] = deque(maxlen=self.max_messages)

self._emojis: dict[int, list[GuildEmoji | AppEmoji]] = {}
Expand Down Expand Up @@ -362,9 +374,15 @@ async def upsert_message(self, message: Message) -> None:

async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message:
msg = await Message._from_data(state=self._state, channel=channel, data=message)
self._messages.append(msg)
self.store_built_message(msg)
return msg

async def store_built_message(self, message: Message) -> None:
self._messages.append(message)

async def delete_message(self, message_id: int) -> None:
self._messages.remove(utils.find(lambda m: m.id == message_id, reversed(self._messages)))

async def get_message(self, message_id: int) -> Message | None:
return utils.find(lambda m: m.id == message_id, reversed(self._messages))

Expand Down Expand Up @@ -393,3 +411,15 @@ async def get_guild_members(self, guild_id: int) -> list[Member]:

async def get_all_members(self) -> list[Member]:
return self._flatten([list(members.values()) for members in self._guild_members.values()])

async def store_sound(self, sound: SoundboardSound) -> None:
self._sounds[sound.id] = sound

async def get_sound(self, sound_id: int) -> SoundboardSound | None:
return self._sounds.get(sound_id)

async def get_all_sounds(self) -> list[SoundboardSound]:
return list(self._sounds.values())

async def delete_sound(self, sound_id: int) -> None:
self._sounds.pop(sound_id, None)
103 changes: 62 additions & 41 deletions discord/app/event_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

import asyncio
from abc import ABC, abstractmethod
from asyncio import Future
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from collections.abc import Awaitable, Coroutine
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeAlias, TypeVar

from typing_extensions import Self

Expand All @@ -43,60 +43,81 @@ class Event(ABC):
@abstractmethod
async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: ...

def _populate_from_slots(self, obj: Any) -> None:
"""
Populate this event instance with attributes from another object.

Handles both __slots__ and __dict__ based objects.

Parameters
----------
obj: Any
The object to copy attributes from.
"""
# Collect all slots from the object's class hierarchy
slots = set()
for klass in type(obj).__mro__:
if hasattr(klass, "__slots__"):
slots.update(klass.__slots__)

# Copy slot attributes
for slot in slots:
if hasattr(obj, slot):
try:
setattr(self, slot, getattr(obj, slot))
except AttributeError:
# Some slots might be read-only or not settable
pass

# Also copy __dict__ if it exists
if hasattr(obj, "__dict__"):
for key, value in obj.__dict__.items():
try:
setattr(self, key, value)
except AttributeError:
pass


ListenerCallback: TypeAlias = Callable[[Event], Any]


class EventReciever(Protocol):
def __call__(self, event: Event) -> Awaitable[Any]: ...


class EventEmitter:
def __init__(self, state: "ConnectionState") -> None:
self._listeners: dict[type[Event], list[Callable]] = {}
self._events: dict[str, list[type[Event]]]
self._wait_fors: dict[type[Event], list[Future]] = defaultdict(list)
self._state = state
self._receivers: list[EventReciever] = []
self._events: dict[str, list[type[Event]]] = defaultdict(list)
self._state: ConnectionState = state

from ..events import ALL_EVENTS # noqa: PLC0415

for event_cls in ALL_EVENTS:
self.add_event(event_cls)

def add_event(self, event: type[Event]) -> None:
try:
self._events[event.__event_name__].append(event)
except KeyError:
self._events[event.__event_name__] = [event]
self._events[event.__event_name__].append(event)

def remove_event(self, event: type[Event]) -> list[type[Event]] | None:
return self._events.pop(event.__event_name__, None)

def add_listener(self, event: type[Event], listener: Callable) -> None:
try:
self._listeners[event].append(listener)
except KeyError:
self.add_event(event)
self._listeners[event] = [listener]

def remove_listener(self, event: type[Event], listener: Callable) -> None:
self._listeners[event].remove(listener)

def add_wait_for(self, event: type[T]) -> Future[T]:
fut = Future()
def add_receiver(self, receiver: EventReciever) -> None:
self._receivers.append(receiver)

self._wait_fors[event].append(fut)

return fut

def remove_wait_for(self, event: type[Event], fut: Future) -> None:
self._wait_fors[event].remove(fut)
def remove_receiver(self, receiver: EventReciever) -> None:
self._receivers.remove(receiver)

async def emit(self, event_str: str, data: Any) -> None:
events = self._events.get(event_str, [])

for event in events:
eve = await event.__load__(data=data, state=self._state)
coros: list[Awaitable[None]] = []
for event_cls in events:
event = await event_cls.__load__(data=data, state=self._state)

if eve is None:
if event is None:
continue

funcs = self._listeners.get(event, [])

for func in funcs:
asyncio.create_task(func(eve))

wait_fors = self._wait_fors.get(event)
coros.extend(receiver(event) for receiver in self._receivers)

if wait_fors is not None:
for wait_for in wait_fors:
wait_for.set_result(eve)
self._wait_fors.pop(event)
await asyncio.gather(*coros)
35 changes: 28 additions & 7 deletions discord/app/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
cast,
)

from discord.soundboard import SoundboardSound

from .. import utils
from ..activity import BaseActivity
from ..automod import AutoModRule
Expand All @@ -65,7 +67,7 @@
from ..raw_models import *
from ..role import Role
from ..sticker import GuildSticker
from ..threads import Thread, ThreadMember
from ..channel.thread import Thread, ThreadMember
from ..ui.modal import Modal
from ..ui.view import View
from ..user import ClientUser, User
Expand Down Expand Up @@ -237,12 +239,12 @@ def __init__(
self._voice_clients: dict[int, VoiceClient] = {}

if not intents.members or cache_flags._empty:
self.store_user = self.create_user # type: ignore
self.store_user = self.create_user_async # type: ignore
self.deref_user = self.deref_user_no_intents # type: ignore

self.cache_app_emojis: bool = options.get("cache_app_emojis", False)

self.emitter = EventEmitter(self)
self.emitter: EventEmitter = EventEmitter(self)

self.cache: Cache = cache
self.cache._state = self
Expand Down Expand Up @@ -320,6 +322,9 @@ async def deref_user(self, user_id: int) -> None:
def create_user(self, data: UserPayload) -> User:
return User(state=self, data=data)

async def create_user_async(self, data: UserPayload) -> User:
return User(state=self, data=data)

def deref_user_no_intents(self, user_id: int) -> None:
return

Expand Down Expand Up @@ -373,6 +378,21 @@ async def _remove_guild(self, guild: Guild) -> None:

del guild

async def _add_default_sounds(self) -> None:
default_sounds = await self.http.get_default_sounds()
for default_sound in default_sounds:
sound = SoundboardSound(state=self, http=self.http, data=default_sound)
await self._add_sound(sound)

async def _add_sound(self, sound: SoundboardSound) -> None:
await self.cache.store_sound(sound)

async def _remove_sound(self, sound: SoundboardSound) -> None:
await self.cache.delete_sound(sound.id)

async def get_sounds(self) -> list[SoundboardSound]:
return list(await self.cache.get_all_sounds())

async def get_emojis(self) -> list[GuildEmoji | AppEmoji]:
return await self.cache.get_all_emojis()

Expand Down Expand Up @@ -431,7 +451,7 @@ async def _get_guild_channel(
# guild_id is in data
guild = await self._get_guild(int(guild_id or data["guild_id"])) # type: ignore
except KeyError:
channel = DMChannel._from_message(self, channel_id)
channel = DMChannel(id=channel_id, state=self)
guild = None
else:
channel = guild and guild._resolve_channel(channel_id)
Expand Down Expand Up @@ -487,15 +507,15 @@ async def query_members(
)
raise

def _get_create_guild(self, data):
async def _get_create_guild(self, data):
if data.get("unavailable") is False:
# GUILD_CREATE with unavailable in the response
# usually means that the guild has become available
# and is therefore in the cache
guild = self._get_guild(int(data["id"]))
guild = await self._get_guild(int(data["id"]))
if guild is not None:
guild.unavailable = False
guild._from_data(data)
await guild._from_data(data, self)
return guild

return self._add_guild_from_data(data)
Expand Down Expand Up @@ -660,6 +680,7 @@ async def _delay_ready(self) -> None:
future = asyncio.ensure_future(self.chunk_guild(guild))
current_bucket.append(future)
else:
await self._add_default_sounds()
future = self.loop.create_future()
future.set_result([])

Expand Down
Loading
Loading