Skip to content

Commit 6a5cb29

Browse files
authored
Futher improvements in components implementation (#204)
* Added support for actionrows in wait_for_component (cherry picked from commit 4e2c310) * Applied black formatting (cherry picked from commit f4e34ff) * Added message kwarg to wait_for_component, removed wait_for_any_component (cherry picked from commit 06f3437) * Exception for hidden+edit_origin (cherry picked from commit 296792f) * Warning on edit_origin when deferred with different state (cherry picked from commit 1d33e7f) * Changed exception types in get_components_ids and _get_messages_ids (cherry picked from commit 3eb8c27) * Added warning for send when deffered with different state (cherry picked from commit 9e242ef) * Tweaked docstrings (cherry picked from commit 0ecd922) * Moved component enums to model.py (cherry picked from commit a664f4a) * Fix ComponentContext.send() (cherry picked from commit 3a6dee2) * Applied pre_push * Fixed bad merge result (bare except) * Added component attribute to ComponentContext * Applied changes from review
1 parent f19eb92 commit 6a5cb29

File tree

5 files changed

+161
-64
lines changed

5 files changed

+161
-64
lines changed

discord_slash/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010

1111
from .client import SlashCommand # noqa: F401
1212
from .const import __version__ # noqa: F401
13-
from .context import ComponentContext # noqa: F401
14-
from .context import SlashContext # noqa: F401
13+
from .context import ComponentContext, SlashContext # noqa: F401
1514
from .dpy_overrides import ComponentMessage # noqa: F401
16-
from .model import SlashCommandOptionType # noqa: F401
15+
from .model import ButtonStyle, ComponentType, SlashCommandOptionType # noqa: F401
1716
from .utils import manage_commands # noqa: F401
1817
from .utils import manage_components # noqa: F401

discord_slash/context.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class InteractionContext:
1414
"""
1515
Base context for interactions.\n
16-
Kinda similar with discord.ext.commands.Context.
16+
In some ways similar with discord.ext.commands.Context.
1717
1818
.. warning::
1919
Do not manually init this model.
@@ -139,7 +139,7 @@ async def send(
139139
components: typing.List[dict] = None,
140140
) -> model.SlashMessage:
141141
"""
142-
Sends response of the slash command.
142+
Sends response of the interaction.
143143
144144
.. warning::
145145
- Since Release 1.0.9, this is completely changed. If you are migrating from older version, please make sure to fix the usage.
@@ -278,10 +278,12 @@ class ComponentContext(InteractionContext):
278278
"""
279279
Context of a component interaction. Has all attributes from :class:`InteractionContext`, plus the component-specific ones below.
280280
281-
:ivar custom_id: The custom ID of the component.
281+
:ivar custom_id: The custom ID of the component (has alias component_id).
282282
:ivar component_type: The type of the component.
283+
:ivar component: Component data retrieved from the message. Not available if the origin message was ephemeral.
283284
:ivar origin_message: The origin message of the component. Not available if the origin message was ephemeral.
284285
:ivar origin_message_id: The ID of the origin message.
286+
285287
"""
286288

287289
def __init__(
@@ -297,27 +299,73 @@ def __init__(
297299
self.origin_message = None
298300
self.origin_message_id = int(_json["message"]["id"]) if "message" in _json.keys() else None
299301

302+
self.component = None
303+
304+
self._deferred_edit_origin = False
305+
300306
if self.origin_message_id and (_json["message"]["flags"] & 64) != 64:
301307
self.origin_message = ComponentMessage(
302308
state=self.bot._connection, channel=self.channel, data=_json["message"]
303309
)
310+
self.component = self.origin_message.get_component(self.custom_id)
304311

305312
async def defer(self, hidden: bool = False, edit_origin: bool = False):
306313
"""
307314
'Defers' the response, showing a loading state to the user
308315
309316
:param hidden: Whether the deferred response should be ephemeral . Default ``False``.
310-
:param edit_origin: Whether the response is editing the origin message. If ``False``, the deferred response will be for a follow up message. Defaults ``False``.
317+
:param edit_origin: Whether the type is editing the origin message. If ``False``, the deferred response will be for a follow up message. Defaults ``False``.
311318
"""
312319
if self.deferred or self.responded:
313320
raise error.AlreadyResponded("You have already responded to this command!")
321+
314322
base = {"type": 6 if edit_origin else 5}
315-
if hidden and not edit_origin:
323+
324+
if hidden:
325+
if edit_origin:
326+
raise error.IncorrectFormat(
327+
"'hidden' and 'edit_origin' flags are mutually exclusive"
328+
)
316329
base["data"] = {"flags": 64}
317330
self._deferred_hidden = True
331+
332+
self._deferred_edit_origin = edit_origin
333+
318334
await self._http.post_initial_response(base, self.interaction_id, self._token)
319335
self.deferred = True
320336

337+
async def send(
338+
self,
339+
content: str = "",
340+
*,
341+
embed: discord.Embed = None,
342+
embeds: typing.List[discord.Embed] = None,
343+
tts: bool = False,
344+
file: discord.File = None,
345+
files: typing.List[discord.File] = None,
346+
allowed_mentions: discord.AllowedMentions = None,
347+
hidden: bool = False,
348+
delete_after: float = None,
349+
components: typing.List[dict] = None,
350+
) -> model.SlashMessage:
351+
if self.deferred and self._deferred_edit_origin:
352+
self._logger.warning(
353+
"Deferred response might not be what you set it to! (edit origin / send response message) "
354+
"This is because it was deferred with different response type."
355+
)
356+
return await super().send(
357+
content,
358+
embed=embed,
359+
embeds=embeds,
360+
tts=tts,
361+
file=file,
362+
files=files,
363+
allowed_mentions=allowed_mentions,
364+
hidden=hidden,
365+
delete_after=delete_after,
366+
components=components,
367+
)
368+
321369
async def edit_origin(self, **fields):
322370
"""
323371
Edits the origin message of the component.
@@ -366,6 +414,11 @@ async def edit_origin(self, **fields):
366414
if files and not self.deferred:
367415
await self.defer(edit_origin=True)
368416
if self.deferred:
417+
if not self._deferred_edit_origin:
418+
self._logger.warning(
419+
"Deferred response might not be what you set it to! (edit origin / send response message) "
420+
"This is because it was deferred with different response type."
421+
)
369422
_json = await self._http.edit(_resp, self._token, files=files)
370423
self.deferred = False
371424
else: # noqa: F841

discord_slash/dpy_overrides.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import typing
2+
13
import discord
24
from discord import AllowedMentions, File, InvalidArgument, abc, http, utils
35
from discord.ext import commands
@@ -11,6 +13,19 @@ def __init__(self, *, state, channel, data):
1113
super().__init__(state=state, channel=channel, data=data)
1214
self.components = data["components"]
1315

16+
def get_component(self, custom_id: int) -> typing.Optional[dict]:
17+
"""
18+
Returns first component with matching custom_id
19+
20+
:param custom_id: custom_id of component to get from message components
21+
:return: Optional[dict]
22+
23+
"""
24+
for row in self.components:
25+
for component in row["components"]:
26+
if component["custom_id"] is custom_id:
27+
return component
28+
1429

1530
def new_override(cls, *args, **kwargs):
1631
if cls is discord.Message:

discord_slash/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,24 @@ def from_type(cls, t: type):
563563
return cls.ROLE
564564
if issubclass(t, discord.abc.User):
565565
return cls.USER
566+
567+
568+
class ComponentType(IntEnum):
569+
actionrow = 1
570+
button = 2
571+
select = 3
572+
573+
574+
class ButtonStyle(IntEnum):
575+
blue = 1
576+
blurple = 1
577+
gray = 2
578+
grey = 2
579+
green = 3
580+
red = 4
581+
URL = 5
582+
583+
primary = 1
584+
secondary = 2
585+
success = 3
586+
danger = 4

discord_slash/utils/manage_components.py

Lines changed: 65 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
1-
import enum
21
import typing
32
import uuid
43

54
import discord
65

76
from ..context import ComponentContext
8-
from ..error import IncorrectFormat
9-
10-
11-
class ComponentsType(enum.IntEnum):
12-
actionrow = 1
13-
button = 2
14-
select = 3
7+
from ..error import IncorrectFormat, IncorrectType
8+
from ..model import ButtonStyle, ComponentType
159

1610

1711
def create_actionrow(*components: dict) -> dict:
@@ -24,27 +18,12 @@ def create_actionrow(*components: dict) -> dict:
2418
if not components or len(components) > 5:
2519
raise IncorrectFormat("Number of components in one row should be between 1 and 5.")
2620
if (
27-
ComponentsType.select in [component["type"] for component in components]
21+
ComponentType.select in [component["type"] for component in components]
2822
and len(components) > 1
2923
):
3024
raise IncorrectFormat("Action row must have only one select component and nothing else")
3125

32-
return {"type": ComponentsType.actionrow, "components": components}
33-
34-
35-
class ButtonStyle(enum.IntEnum):
36-
blue = 1
37-
blurple = 1
38-
gray = 2
39-
grey = 2
40-
green = 3
41-
red = 4
42-
URL = 5
43-
44-
primary = 1
45-
secondary = 2
46-
success = 3
47-
danger = 4
26+
return {"type": ComponentType.actionrow, "components": components}
4827

4928

5029
def emoji_to_dict(emoji: typing.Union[discord.Emoji, discord.PartialEmoji, str]) -> dict:
@@ -103,7 +82,7 @@ def create_button(
10382
emoji = emoji_to_dict(emoji)
10483

10584
data = {
106-
"type": ComponentsType.button,
85+
"type": ComponentType.button,
10786
"style": style,
10887
}
10988

@@ -146,7 +125,11 @@ def create_select_option(
146125

147126

148127
def create_select(
149-
options: typing.List[dict], custom_id=None, placeholder=None, min_values=None, max_values=None
128+
options: typing.List[dict],
129+
custom_id=None,
130+
placeholder=None,
131+
min_values=None,
132+
max_values=None,
150133
):
151134
"""
152135
Creates a select (dropdown) component for use with the ``components`` field. Must be inside an ActionRow to be used (see :meth:`create_actionrow`).
@@ -158,7 +141,7 @@ def create_select(
158141
raise IncorrectFormat("Options length should be between 1 and 25.")
159142

160143
return {
161-
"type": ComponentsType.select,
144+
"type": ComponentType.select,
162145
"options": options,
163146
"custom_id": custom_id or str(uuid.uuid4()),
164147
"placeholder": placeholder or "",
@@ -167,51 +150,77 @@ def create_select(
167150
}
168151

169152

170-
async def wait_for_component(
171-
client: discord.Client, component: typing.Union[dict, str], check=None, timeout=None
172-
) -> ComponentContext:
153+
def get_components_ids(component: typing.Union[str, dict, list]) -> typing.Iterator[str]:
173154
"""
174-
Waits for a component interaction. Only accepts interactions based on the custom ID of the component, and optionally a check function.
155+
Returns generator with 'custom_id' of component or components.
175156
176-
:param client: The client/bot object.
177-
:type client: :class:`discord.Client`
178-
:param component: The component dict or custom ID.
179-
:type component: Union[dict, str]
180-
:param check: Optional check function. Must take a `ComponentContext` as the first parameter.
181-
:param timeout: The number of seconds to wait before timing out and raising :exc:`asyncio.TimeoutError`.
182-
:raises: :exc:`asyncio.TimeoutError`
157+
:param component: Custom ID or component dict (actionrow or button) or list of previous two.
183158
"""
184159

185-
def _check(ctx):
186-
if check and not check(ctx):
187-
return False
188-
return (
189-
component["custom_id"] if isinstance(component, dict) else component
190-
) == ctx.custom_id
191-
192-
return await client.wait_for("component", check=_check, timeout=timeout)
160+
if isinstance(component, str):
161+
yield component
162+
elif isinstance(component, dict):
163+
if component["type"] == ComponentType.actionrow:
164+
yield from (comp["custom_id"] for comp in component["components"])
165+
else:
166+
yield component["custom_id"]
167+
elif isinstance(component, list):
168+
# Either list of components (actionrows or buttons) or list of ids
169+
yield from (comp_id for comp in component for comp_id in get_components_ids(comp))
170+
else:
171+
raise IncorrectType(
172+
f"Unknown component type of {component} ({type(component)}). "
173+
f"Expected str, dict or list"
174+
)
175+
176+
177+
def _get_messages_ids(message: typing.Union[discord.Message, int, list]) -> typing.Iterator[int]:
178+
if isinstance(message, int):
179+
yield message
180+
elif isinstance(message, discord.Message):
181+
yield message.id
182+
elif isinstance(message, list):
183+
yield from (msg_id for msg in message for msg_id in _get_messages_ids(msg))
184+
else:
185+
raise IncorrectType(
186+
f"Unknown component type of {message} ({type(message)}). "
187+
f"Expected discord.Message, int or list"
188+
)
193189

194190

195-
async def wait_for_any_component(
196-
client: discord.Client, message: typing.Union[discord.Message, int], check=None, timeout=None
191+
async def wait_for_component(
192+
client: discord.Client,
193+
component: typing.Union[str, dict, list] = None,
194+
message: typing.Union[discord.Message, int, list] = None,
195+
check=None,
196+
timeout=None,
197197
) -> ComponentContext:
198198
"""
199-
Waits for any component interaction. Only accepts interactions based on the message ID given and optionally a check function.
199+
Helper function - wrapper around 'client.wait_for("component", ...)'
200+
Waits for a component interaction. Only accepts interactions based on the custom ID of the component or/and message ID, and optionally a check function.
200201
201202
:param client: The client/bot object.
202203
:type client: :class:`discord.Client`
203-
:param message: The message object to check for, or the message ID.
204-
:type message: Union[discord.Message, int]
204+
:param component: Custom ID or component dict (actionrow or button) or list of previous two.
205+
:param message: The message object to check for, or the message ID or list of previous two.
206+
:type component: Union[dict, str]
205207
:param check: Optional check function. Must take a `ComponentContext` as the first parameter.
206208
:param timeout: The number of seconds to wait before timing out and raising :exc:`asyncio.TimeoutError`.
207209
:raises: :exc:`asyncio.TimeoutError`
208210
"""
209211

210-
def _check(ctx):
212+
if not (component or message):
213+
raise IncorrectFormat("You must specify component or message (or both)")
214+
215+
components_ids = list(get_components_ids(component)) if component else None
216+
message_ids = list(_get_messages_ids(message)) if message else None
217+
218+
def _check(ctx: ComponentContext):
211219
if check and not check(ctx):
212220
return False
213-
return (
214-
message.id if isinstance(message, discord.Message) else message
215-
) == ctx.origin_message_id
221+
# if components_ids is empty or there is a match
222+
wanted_component = not components_ids or ctx.custom_id in components_ids
223+
wanted_message = not message_ids or ctx.origin_message_id in message_ids
224+
return wanted_component and wanted_message
216225

217226
return await client.wait_for("component", check=_check, timeout=timeout)

0 commit comments

Comments
 (0)