Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion discord/app/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from discord import utils
from discord.member import Member
from discord.message import Message
from discord.soundboard import SoundboardSound

from ..channel import DMChannel
from ..emoji import AppEmoji, GuildEmoji
Expand Down Expand Up @@ -142,6 +143,8 @@ async def store_private_channel(self, channel: "PrivateChannel") -> None: ...

async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message: ...

async def store_built_message(self, message: Message) -> None: ...

async def upsert_message(self, message: Message) -> None: ...

async def delete_message(self, message_id: int) -> None: ...
Expand All @@ -166,6 +169,14 @@ async def get_all_members(self) -> list[Member]: ...

async def clear(self, views: bool = True) -> None: ...

async def store_sound(self, sound: SoundboardSound) -> None: ...

async def get_sound(self, sound_id: int) -> SoundboardSound | None: ...

async def get_all_sounds(self) -> list[SoundboardSound]: ...

async def delete_sound(self, sound_id: int) -> None: ...


class MemoryCache(Cache):
def __init__(self, max_messages: int | None = None) -> None:
Expand All @@ -177,6 +188,7 @@ def __init__(self, max_messages: int | None = None) -> None:
self._stickers: dict[int, list[GuildSticker]] = {}
self._views: dict[str, View] = {}
self._modals: dict[str, Modal] = {}
self._sounds: dict[int, SoundboardSound] = {}
self._messages: Deque[Message] = deque(maxlen=self.max_messages)

self._emojis: dict[int, list[GuildEmoji | AppEmoji]] = {}
Expand Down Expand Up @@ -362,9 +374,15 @@ async def upsert_message(self, message: Message) -> None:

async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message:
msg = await Message._from_data(state=self._state, channel=channel, data=message)
self._messages.append(msg)
self.store_built_message(msg)
return msg

async def store_built_message(self, message: Message) -> None:
self._messages.append(message)

async def delete_message(self, message_id: int) -> None:
self._messages.remove(utils.find(lambda m: m.id == message_id, reversed(self._messages)))

async def get_message(self, message_id: int) -> Message | None:
return utils.find(lambda m: m.id == message_id, reversed(self._messages))

Expand Down Expand Up @@ -393,3 +411,15 @@ async def get_guild_members(self, guild_id: int) -> list[Member]:

async def get_all_members(self) -> list[Member]:
return self._flatten([list(members.values()) for members in self._guild_members.values()])

async def store_sound(self, sound: SoundboardSound) -> None:
self._sounds[sound.id] = sound

async def get_sound(self, sound_id: int) -> SoundboardSound | None:
return self._sounds.get(sound_id)

async def get_all_sounds(self) -> list[SoundboardSound]:
return list(self._sounds.values())

async def delete_sound(self, sound_id: int) -> None:
self._sounds.pop(sound_id, None)
103 changes: 62 additions & 41 deletions discord/app/event_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

import asyncio
from abc import ABC, abstractmethod
from asyncio import Future
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from collections.abc import Awaitable, Coroutine
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeAlias, TypeVar

from typing_extensions import Self

Expand All @@ -43,60 +43,81 @@ class Event(ABC):
@abstractmethod
async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: ...

def _populate_from_slots(self, obj: Any) -> None:
"""
Populate this event instance with attributes from another object.

Handles both __slots__ and __dict__ based objects.

Parameters
----------
obj: Any
The object to copy attributes from.
"""
# Collect all slots from the object's class hierarchy
slots = set()
for klass in type(obj).__mro__:
if hasattr(klass, "__slots__"):
slots.update(klass.__slots__)

# Copy slot attributes
for slot in slots:
if hasattr(obj, slot):
try:
setattr(self, slot, getattr(obj, slot))
except AttributeError:
# Some slots might be read-only or not settable
pass

# Also copy __dict__ if it exists
if hasattr(obj, "__dict__"):
for key, value in obj.__dict__.items():
try:
setattr(self, key, value)
except AttributeError:
pass


ListenerCallback: TypeAlias = Callable[[Event], Any]


class EventReciever(Protocol):
def __call__(self, event: Event) -> Awaitable[Any]: ...


class EventEmitter:
def __init__(self, state: "ConnectionState") -> None:
self._listeners: dict[type[Event], list[Callable]] = {}
self._events: dict[str, list[type[Event]]]
self._wait_fors: dict[type[Event], list[Future]] = defaultdict(list)
self._state = state
self._receivers: list[EventReciever] = []
self._events: dict[str, list[type[Event]]] = defaultdict(list)
self._state: ConnectionState = state

from ..events import ALL_EVENTS # noqa: PLC0415

for event_cls in ALL_EVENTS:
self.add_event(event_cls)

def add_event(self, event: type[Event]) -> None:
try:
self._events[event.__event_name__].append(event)
except KeyError:
self._events[event.__event_name__] = [event]
self._events[event.__event_name__].append(event)

def remove_event(self, event: type[Event]) -> list[type[Event]] | None:
return self._events.pop(event.__event_name__, None)

def add_listener(self, event: type[Event], listener: Callable) -> None:
try:
self._listeners[event].append(listener)
except KeyError:
self.add_event(event)
self._listeners[event] = [listener]

def remove_listener(self, event: type[Event], listener: Callable) -> None:
self._listeners[event].remove(listener)

def add_wait_for(self, event: type[T]) -> Future[T]:
fut = Future()
def add_receiver(self, receiver: EventReciever) -> None:
self._receivers.append(receiver)

self._wait_fors[event].append(fut)

return fut

def remove_wait_for(self, event: type[Event], fut: Future) -> None:
self._wait_fors[event].remove(fut)
def remove_receiver(self, receiver: EventReciever) -> None:
self._receivers.remove(receiver)

async def emit(self, event_str: str, data: Any) -> None:
events = self._events.get(event_str, [])

for event in events:
eve = await event.__load__(data=data, state=self._state)
coros: list[Awaitable[None]] = []
for event_cls in events:
event = await event_cls.__load__(data=data, state=self._state)

if eve is None:
if event is None:
continue

funcs = self._listeners.get(event, [])

for func in funcs:
asyncio.create_task(func(eve))

wait_fors = self._wait_fors.get(event)
coros.extend(receiver(event) for receiver in self._receivers)

if wait_fors is not None:
for wait_for in wait_fors:
wait_for.set_result(eve)
self._wait_fors.pop(event)
await asyncio.gather(*coros)
33 changes: 27 additions & 6 deletions discord/app/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
cast,
)

from discord.soundboard import SoundboardSound

from .. import utils
from ..activity import BaseActivity
from ..automod import AutoModRule
Expand Down Expand Up @@ -237,12 +239,12 @@ def __init__(
self._voice_clients: dict[int, VoiceClient] = {}

if not intents.members or cache_flags._empty:
self.store_user = self.create_user # type: ignore
self.store_user = self.create_user_async # type: ignore
self.deref_user = self.deref_user_no_intents # type: ignore

self.cache_app_emojis: bool = options.get("cache_app_emojis", False)

self.emitter = EventEmitter(self)
self.emitter: EventEmitter = EventEmitter(self)

self.cache: Cache = cache
self.cache._state = self
Expand Down Expand Up @@ -320,6 +322,9 @@ async def deref_user(self, user_id: int) -> None:
def create_user(self, data: UserPayload) -> User:
return User(state=self, data=data)

async def create_user_async(self, data: UserPayload) -> User:
return User(state=self, data=data)

def deref_user_no_intents(self, user_id: int) -> None:
return

Expand Down Expand Up @@ -373,6 +378,21 @@ async def _remove_guild(self, guild: Guild) -> None:

del guild

async def _add_default_sounds(self) -> None:
default_sounds = await self.http.get_default_sounds()
for default_sound in default_sounds:
sound = SoundboardSound(state=self, http=self.http, data=default_sound)
await self._add_sound(sound)

async def _add_sound(self, sound: SoundboardSound) -> None:
await self.cache.store_sound(sound)

async def _remove_sound(self, sound: SoundboardSound) -> None:
await self.cache.delete_sound(sound.id)

async def get_sounds(self) -> list[SoundboardSound]:
return list(await self.cache.get_all_sounds())

async def get_emojis(self) -> list[GuildEmoji | AppEmoji]:
return await self.cache.get_all_emojis()

Expand Down Expand Up @@ -431,7 +451,7 @@ async def _get_guild_channel(
# guild_id is in data
guild = await self._get_guild(int(guild_id or data["guild_id"])) # type: ignore
except KeyError:
channel = DMChannel._from_message(self, channel_id)
channel = DMChannel(id=channel_id, state=self)
guild = None
else:
channel = guild and guild._resolve_channel(channel_id)
Expand Down Expand Up @@ -487,15 +507,15 @@ async def query_members(
)
raise

def _get_create_guild(self, data):
async def _get_create_guild(self, data):
if data.get("unavailable") is False:
# GUILD_CREATE with unavailable in the response
# usually means that the guild has become available
# and is therefore in the cache
guild = self._get_guild(int(data["id"]))
guild = await self._get_guild(int(data["id"]))
if guild is not None:
guild.unavailable = False
guild._from_data(data)
await guild._from_data(data, self)
return guild

return self._add_guild_from_data(data)
Expand Down Expand Up @@ -660,6 +680,7 @@ async def _delay_ready(self) -> None:
future = asyncio.ensure_future(self.chunk_guild(guild))
current_bucket.append(future)
else:
await self._add_default_sounds()
future = self.loop.create_future()
future.set_result([])

Expand Down
Loading
Loading