Skip to content

Commit e78b605

Browse files
feat: add regex support for modal callback (#1388)
* feat: add modal regex callback * ci: correct from checks. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5ed2b31 commit e78b605

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

interactions/client/client.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def __init__(
386386
self._component_callbacks: Dict[str, Callable[..., Coroutine]] = {}
387387
self._regex_component_callbacks: Dict[re.Pattern, Callable[..., Coroutine]] = {}
388388
self._modal_callbacks: Dict[str, Callable[..., Coroutine]] = {}
389+
self._regex_modal_callbacks: Dict[re.Pattern, Callable[..., Coroutine]] = {}
389390
self._global_autocompletes: Dict[str, GlobalAutoComplete] = {}
390391
self.processors: Dict[str, Callable[..., Coroutine]] = {}
391392
self.__modules = {}
@@ -1305,9 +1306,14 @@ def add_modal_callback(self, command: ModalCommand) -> None:
13051306
command: The command to add
13061307
"""
13071308
for listener in command.listeners:
1308-
if listener in self._modal_callbacks.keys():
1309-
raise ValueError(f"Duplicate Component! Multiple modal callbacks for `{listener}`")
1310-
self._modal_callbacks[listener] = command
1309+
if isinstance(listener, re.Pattern):
1310+
if listener in self._regex_component_callbacks.keys():
1311+
raise ValueError(f"Duplicate Component! Multiple modal callbacks for `{listener}`")
1312+
self._regex_modal_callbacks[listener] = command
1313+
else:
1314+
if listener in self._modal_callbacks.keys():
1315+
raise ValueError(f"Duplicate Component! Multiple modal callbacks for `{listener}`")
1316+
self._modal_callbacks[listener] = command
13111317
continue
13121318

13131319
def add_global_autocomplete(self, callback: GlobalAutoComplete) -> None:
@@ -1791,8 +1797,18 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: # noqa:
17911797
ctx = await self.get_context(interaction_data)
17921798
self.dispatch(events.ModalCompletion(ctx=ctx))
17931799

1794-
if callback := self._modal_callbacks.get(ctx.custom_id):
1795-
await self.__dispatch_interaction(ctx=ctx, callback=callback(ctx), error_callback=events.ModalError)
1800+
modal_callback = self._modal_callbacks.get(ctx.custom_id)
1801+
if not modal_callback:
1802+
# evaluate regex component callbacks
1803+
for regex, callback in self._regex_modal_callbacks.items():
1804+
if regex.match(ctx.custom_id):
1805+
modal_callback = callback
1806+
break
1807+
1808+
if modal_callback:
1809+
await self.__dispatch_interaction(
1810+
ctx=ctx, callback=modal_callback(ctx), error_callback=events.ModalError
1811+
)
17961812

17971813
else:
17981814
raise NotImplementedError(f"Unknown Interaction Received: {interaction_data['type']}")

interactions/models/internal/application_commands.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ def desc_validator(_: Any, attr: Attribute, value: str) -> None:
9696
raise ValueError(f"Description must be between 1 and {SLASH_CMD_MAX_DESC_LENGTH} characters long")
9797

9898

99+
def custom_ids_validator(*custom_id: str | re.Pattern) -> None:
100+
if not (all(isinstance(i, re.Pattern) for i in custom_id) or all(isinstance(i, str) for i in custom_id)):
101+
raise ValueError("All custom IDs be either a string or a regex pattern, not a mix of both.")
102+
103+
99104
@attrs.define(
100105
eq=False,
101106
order=False,
@@ -1145,18 +1150,21 @@ def wrapper(func: AsyncCallable) -> ComponentCommand:
11451150
return ComponentCommand(name=f"ComponentCallback::{custom_id}", callback=func, listeners=custom_id)
11461151

11471152
custom_id = _unpack_helper(custom_id)
1148-
if not (all(isinstance(i, re.Pattern) for i in custom_id) or all(isinstance(i, str) for i in custom_id)):
1149-
raise ValueError("All custom IDs be either a string or a regex pattern, not a mix of both.")
1153+
custom_ids_validator(*custom_id)
11501154
return wrapper
11511155

11521156

1153-
def modal_callback(*custom_id: str) -> Callable[[AsyncCallable], ModalCommand]:
1157+
def modal_callback(*custom_id: str | re.Pattern) -> Callable[[AsyncCallable], ModalCommand]:
11541158
"""
11551159
Register a coroutine as a modal callback.
11561160
11571161
Modal callbacks work the same way as commands, just using modals as a way of invoking, instead of messages.
11581162
Your callback will be given a single argument, `ModalContext`
11591163
1164+
Note:
1165+
This can optionally take a regex pattern, which will be used to match against the custom ID of the modal
1166+
1167+
11601168
Args:
11611169
*custom_id: The custom ID of the modal to wait for
11621170
"""
@@ -1168,6 +1176,7 @@ def wrapper(func: AsyncCallable) -> ModalCommand:
11681176
return ModalCommand(name=f"ModalCallback::{custom_id}", callback=func, listeners=custom_id)
11691177

11701178
custom_id = _unpack_helper(custom_id)
1179+
custom_ids_validator(*custom_id)
11711180
return wrapper
11721181

11731182

0 commit comments

Comments
 (0)