Skip to content

Commit 668b7c0

Browse files
authored
feat!: refactor cache and dispatch cached models in delete/remove/update events
Merge pull request #909 from Catalyst4222/unstable
2 parents 09e8699 + 04dce8f commit 668b7c0

File tree

10 files changed

+121
-120
lines changed

10 files changed

+121
-120
lines changed

interactions/api/cache.py

Lines changed: 50 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,73 @@
1-
from collections import OrderedDict
2-
from typing import Any, List, Optional
1+
from collections import defaultdict
2+
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
3+
4+
if TYPE_CHECKING:
5+
from .models import Snowflake
6+
7+
Key = TypeVar("Key", Snowflake, Tuple[Snowflake, Snowflake])
38

49
__all__ = (
5-
"Item",
610
"Storage",
711
"Cache",
812
)
913

14+
_T = TypeVar("_T")
15+
_P = TypeVar("_P")
1016

11-
class Item:
12-
"""
13-
A class representing the defined item in a stored dataset.
1417

15-
:ivar str id: The ID of the item.
16-
:ivar Any value: The item itself.
17-
:ivar Type type: The ID type representation.
18-
"""
19-
20-
__slots__ = ("id", "value", "type")
21-
22-
def __init__(self, id: str, value: Any) -> None:
23-
"""
24-
:param id: The item's ID.
25-
:type id: str
26-
:param value: The item itself.
27-
:type value: Any
28-
"""
29-
self.id = id
30-
self.value = value
31-
self.type = type(value)
32-
33-
34-
class Storage:
18+
class Storage(Generic[_T]):
3519
"""
3620
A class representing a set of items stored as a cache state.
3721
38-
:ivar List[Item] values: The list of items stored.
22+
:ivar Dict[Union[Snowflake, Tuple[Snowflake, Snowflake]], Any] values: The list of items stored.
3923
"""
4024

41-
__slots__ = "values"
25+
__slots__ = ("values",)
4226

4327
def __repr__(self) -> str:
4428
return f"<{self.__class__.__name__} object containing {len(self.values)} items.>"
4529

4630
def __init__(self) -> None:
47-
self.values = OrderedDict()
31+
self.values: Dict["Key", _T] = {}
4832

49-
def add(self, item: Item) -> OrderedDict:
33+
def add(self, item: _T, id: Optional["Key"] = None) -> None:
5034
"""
5135
Adds a new item to the storage.
5236
5337
:param item: The item to add.
54-
:type item: Item
55-
:return: The new storage.
56-
:rtype: OrderedDict
38+
:type item: Any
39+
:param id: The unique id of the item.
40+
:type id: Optional[Union[Snowflake, Tuple[Snowflake, Snowflake]]]
5741
"""
58-
self.values.update({item.id: item.value})
59-
return self.values
42+
self.values[id or item.id] = item
6043

61-
def get(self, id: str) -> Optional[Item]:
44+
def get(self, id: "Key", default: Optional[_P] = None) -> Union[_T, _P]:
6245
"""
6346
Gets an item from the storage.
6447
6548
:param id: The ID of the item.
66-
:type id: str
49+
:type id: Union[Snowflake, Tuple[Snowflake, Snowflake]]
50+
:param default: The default value to return if the item is not found.
51+
:type default: Optional[Any]
6752
:return: The item from the storage if any.
68-
:rtype: Optional[Item]
53+
:rtype: Optional[Any]
6954
"""
70-
if id in self.values.keys():
71-
return self.values[id]
55+
return self.values.get(id, default)
7256

73-
def update(self, item: Item) -> Optional[Item]:
57+
def update(self, data: Dict["Key", _T]):
7458
"""
75-
Updates an item from the storage.
59+
Updates multiple items from the storage.
7660
77-
:param item: The item to update.
78-
:return: The updated item, if stored.
79-
:rtype: Optional[Item]
61+
:param data: The data to update with.
62+
:type data: dict
8063
"""
81-
if item.id in self.values.keys():
82-
self.values[item.id] = item.value
83-
return self.values[
84-
id
85-
] # fetches from cache to see if its saved properly, instead of returning input.
64+
self.values.update(data)
65+
66+
def pop(self, key: "Key", default: Optional[_P] = None) -> Union[_T, _P]:
67+
try:
68+
return self.values.pop(key)
69+
except KeyError:
70+
return default
8671

8772
@property
8873
def view(self) -> List[dict]:
@@ -93,45 +78,32 @@ def view(self) -> List[dict]:
9378
"""
9479
return [v._json for v in self.values.values()]
9580

81+
def __getitem__(self, item: "Key") -> _T:
82+
return self.values.__getitem__(item)
83+
84+
def __setitem__(self, key: "Key", value: _T) -> None:
85+
return self.values.__setitem__(key, value)
86+
87+
def __delitem__(self, key: "Key") -> None:
88+
return self.values.__delitem__(key)
89+
9690

9791
class Cache:
9892
"""
9993
A class representing the cache.
10094
This cache collects all of the HTTP requests made for
10195
the represented instances of the class.
10296
103-
:ivar Cache dms: The cached Direct Messages.
104-
:ivar Cache self_guilds: The cached guilds upon gateway connection.
105-
:ivar Cache guilds: The cached guilds after ready.
106-
:ivar Cache channels: The cached channels of guilds.
107-
:ivar Cache roles: The cached roles of guilds.
108-
:ivar Cache members: The cached members of guilds and threads.
109-
:ivar Cache messages: The cached messages of DMs and channels.
110-
:ivar Cache interactions: The cached interactions upon interaction.
97+
:ivar defaultdict[Type, Storage] storages:
11198
"""
11299

113-
__slots__ = (
114-
"dms",
115-
"self_guilds",
116-
"guilds",
117-
"channels",
118-
"roles",
119-
"members",
120-
"messages",
121-
"users",
122-
"interactions",
123-
)
100+
__slots__ = "storages"
124101

125102
def __init__(self) -> None:
126-
self.dms = Storage()
127-
self.self_guilds = Storage()
128-
self.guilds = Storage()
129-
self.channels = Storage()
130-
self.roles = Storage()
131-
self.members = Storage()
132-
self.messages = Storage()
133-
self.users = Storage()
134-
self.interactions = Storage()
103+
self.storages: defaultdict[Type[_T], Storage[_T]] = defaultdict(Storage)
104+
105+
def __getitem__(self, item: Type[_T]) -> Storage[_T]:
106+
return self.storages[item]
135107

136108

137109
ref_cache = Cache() # noqa

interactions/api/gateway/client.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@
2828
from ..http.client import HTTPClient
2929
from ..models.attrs_utils import MISSING
3030
from ..models.flags import Intents
31+
from ..models.member import Member
32+
from ..models.misc import Snowflake
3133
from ..models.presence import ClientPresence
3234
from .heartbeat import _Heartbeat
3335

3436
if TYPE_CHECKING:
3537
from ...client.context import _Context
38+
from ..cache import Storage
3639

3740
log = get_logger("gateway")
3841

@@ -273,12 +276,11 @@ def _dispatch_event(self, event: str, data: dict) -> None:
273276
:param data: The data for the event.
274277
:type data: dict
275278
"""
276-
# sourcery no-metrics
279+
self._dispatch.dispatch("raw_socket_create", data)
277280
path: str = "interactions"
278281
path += ".models" if event == "INTERACTION_CREATE" else ".api.models"
279282
if event == "INTERACTION_CREATE":
280283
if data.get("type"):
281-
# sourcery skip: dict-assign-update-to-union, extract-method, low-code-quality
282284
_context = self.__contextualize(data)
283285
_name: str = ""
284286
__args: list = [_context]
@@ -386,17 +388,42 @@ def _dispatch_event(self, event: str, data: dict) -> None:
386388
try:
387389
_event_path: list = [section.capitalize() for section in name.split("_")]
388390
_name: str = _event_path[0] if len(_event_path) < 3 else "".join(_event_path[:-1])
389-
__obj: object = getattr(__import__(path), _name)
391+
model = getattr(__import__(path), _name)
392+
393+
data["_client"] = self._http
394+
obj = model(**data)
395+
_cache: "Storage" = self._http.cache[model]
396+
397+
if isinstance(obj, Member):
398+
id = (Snowflake(data["guild_id"]), obj.id)
399+
else:
400+
id = getattr(obj, "id", None)
401+
402+
if "_create" in name or "_add" in name:
403+
_cache.add(obj, id)
404+
self._dispatch.dispatch(f"on_{name}", obj)
390405

391-
# name in {"_create", "_add"} returns False (tested w message_create)
392-
if any(_ in name for _ in {"_create", "_update", "_add", "_remove", "_delete"}):
393-
data["_client"] = self._http
406+
elif "_update" in name and hasattr(obj, "id"):
407+
old_obj = self._http.cache[model].get(id)
408+
_cache.add(obj, id)
409+
copy = model(**old_obj._json)
410+
old_obj.update(**obj._json)
411+
self._dispatch.dispatch(
412+
f"on_{name}", copy, old_obj
413+
) # give previously stored and new one
414+
return
394415

395-
self._dispatch.dispatch(f"on_{name}", __obj(**data)) # noqa
416+
elif "_remove" in name or "_delete" in name:
417+
self._dispatch.dispatch(f"on_raw_{name}", obj)
418+
419+
old_obj = _cache.pop(id)
420+
self._dispatch.dispatch(f"on_{name}", old_obj)
421+
422+
else:
423+
self._dispatch.dispatch(f"on_{name}", obj)
396424

397425
except AttributeError as error:
398426
log.fatal(f"An error occured dispatching {name}: {error}")
399-
self._dispatch.dispatch("raw_socket_create", data)
400427

401428
def __contextualize(self, data: dict) -> "_Context":
402429
"""

interactions/api/http/channel.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List, Optional, Union
22

3-
from ...api.cache import Cache, Item
3+
from ...api.cache import Cache
44
from ..error import LibraryException
55
from ..models.channel import Channel
66
from ..models.message import Message
@@ -11,7 +11,6 @@
1111

1212

1313
class ChannelRequest:
14-
1514
_req: _Request
1615
cache: Cache
1716

@@ -26,7 +25,7 @@ async def get_channel(self, channel_id: int) -> dict:
2625
:return: Dictionary of the channel object.
2726
"""
2827
request = await self._req.request(Route("GET", f"/channels/{channel_id}"))
29-
self.cache.channels.add(Item(id=str(channel_id), value=Channel(**request, _client=self)))
28+
self.cache[Channel].add(Channel(**request, _client=self))
3029

3130
return request
3231

@@ -88,7 +87,7 @@ async def get_channel_messages(
8887
if isinstance(request, list):
8988
for message in request:
9089
if message.get("id"):
91-
self.cache.messages.add(Item(id=message["id"], value=Message(**message)))
90+
self.cache[Message].add(Message(**message))
9291

9392
return request
9493

@@ -110,7 +109,7 @@ async def create_channel(
110109
Route("POST", f"/guilds/{guild_id}/channels"), json=payload, reason=reason
111110
)
112111
if request.get("id"):
113-
self.cache.channels.add(Item(id=request["id"], value=Channel(**request)))
112+
self.cache[Channel].add(Channel(**request))
114113

115114
return request
116115

interactions/api/http/guild.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, List, Optional
22
from urllib.parse import quote
33

4-
from ...api.cache import Cache, Item
4+
from ...api.cache import Cache
55
from ..models.channel import Channel
66
from ..models.guild import Guild
77
from ..models.member import Member
@@ -30,7 +30,7 @@ async def get_self_guilds(self) -> List[dict]:
3030

3131
for guild in request:
3232
if guild.get("id"):
33-
self.cache.self_guilds.add(Item(id=guild["id"], value=Guild(**guild, _client=self)))
33+
self.cache[Guild].add(Guild(**guild, _client=self))
3434

3535
return request
3636

@@ -45,7 +45,7 @@ async def get_guild(self, guild_id: int, with_counts: bool = False) -> dict:
4545
request = await self._req.request(
4646
Route("GET", f"/guilds/{guild_id}{f'?{with_counts=}' if with_counts else ''}")
4747
)
48-
self.cache.guilds.add(Item(id=str(guild_id), value=Guild(**request, _client=self)))
48+
self.cache[Guild].add(Guild(**request, _client=self))
4949

5050
return request
5151

@@ -369,9 +369,7 @@ async def get_all_channels(self, guild_id: int) -> List[dict]:
369369

370370
for channel in request:
371371
if channel.get("id"):
372-
self.cache.channels.add(
373-
Item(id=channel["id"], value=Channel(**channel, _client=self))
374-
)
372+
self.cache[Channel].add(Channel(**channel, _client=self))
375373

376374
return request
377375

@@ -388,7 +386,7 @@ async def get_all_roles(self, guild_id: int) -> List[dict]:
388386

389387
for role in request:
390388
if role.get("id"):
391-
self.cache.roles.add(Item(id=role["id"], value=Role(**role)))
389+
self.cache[Role].add(Role(**role))
392390

393391
return request
394392

@@ -407,7 +405,7 @@ async def create_guild_role(
407405
Route("POST", f"/guilds/{guild_id}/roles"), json=payload, reason=reason
408406
)
409407
if request.get("id"):
410-
self.cache.roles.add(Item(id=request["id"], value=Role(**request)))
408+
self.cache[Role].add(Role(**request))
411409

412410
return request
413411

@@ -590,7 +588,7 @@ async def add_guild_member(
590588
},
591589
)
592590

593-
self.cache.members.add(Item(id=str(user_id), value=Member(**request)))
591+
self.cache[Member].add(Member(**request))
594592

595593
return request
596594

interactions/api/http/message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from aiohttp import MultipartWriter
44

5-
from ...api.cache import Cache, Item
5+
from ...api.cache import Cache
66
from ..models.attrs_utils import MISSING
77
from ..models.message import Embed, Message, MessageInteraction, Sticker
88
from ..models.misc import File, Snowflake
@@ -98,7 +98,7 @@ async def create_message(
9898
data=data,
9999
)
100100
if request.get("id"):
101-
self.cache.messages.add(Item(id=request["id"], value=Message(**request, _client=self)))
101+
self.cache[Message].add(Message(**request, _client=self))
102102

103103
return request
104104

0 commit comments

Comments
 (0)