diff --git a/discord/ui/item.py b/discord/ui/item.py index 958f60a7a7..491a8df344 100644 --- a/discord/ui/item.py +++ b/discord/ui/item.py @@ -25,7 +25,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, TypeVar +import inspect +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Generic, + TypeVar, +) + +from typing_extensions import Self from ..interactions import Interaction @@ -42,11 +52,39 @@ from .modal import BaseModal from .view import BaseView -I = TypeVar("I", bound="Item") +I = TypeVar("I", bound="Item", covariant=True) + T = TypeVar("T", bound="ItemInterface", covariant=True) V = TypeVar("V", bound="BaseView", covariant=True) M = TypeVar("M", bound="BaseModal", covariant=True) + +V_co = TypeVar("V_co", bound="BaseView", covariant=True) + ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]] +SetItemCallbackType = ( + Callable[[Interaction], Coroutine[object, Any, Any]] + | Callable[[Interaction, I], Coroutine[object, Any, Any]] + | Callable[[Interaction, I, V], Coroutine[object, Any, Any]] +) + + +class _ProxyItemCallback: + def __init__( + self, func: SetItemCallbackType, item: ViewItem, parameters_amount: int + ) -> None: + self.func: SetItemCallbackType = func + self.item: ViewItem = item + self.parameters_amount: int = parameters_amount + + def __call__(self, interaction: Interaction) -> Coroutine[Any, Any, Any]: + if self.parameters_amount == 1: + return self.func(interaction) # type: ignore # type checker doesn't like optional params + elif self.parameters_amount == 2: + return self.func(interaction, self.item) # type: ignore # type checker doesn't like optional params + elif self.parameters_amount == 3: + return self.func(interaction, self.item, self.item.view) # type: ignore # type checker doesn't like optional params + else: + raise TypeError("callback must accept 1 to 3 parameters") class Item(Generic[T]): @@ -124,7 +162,7 @@ def id(self, value) -> None: self._underlying.id = value -class ViewItem(Item[V]): +class ViewItem(Item[V_co], Generic[V_co]): """Represents an item used in Views. The following are the original items supported in :class:`discord.ui.View`: @@ -149,7 +187,7 @@ class ViewItem(Item[V]): def __init__(self): super().__init__() - self._view: V | None = None + self._view: V_co | None = None self._row: int | None = None self._rendered_row: int | None = None self.parent: ViewItem | BaseView | None = self.view @@ -197,7 +235,7 @@ def width(self) -> int: return 1 @property - def view(self) -> V | None: + def view(self) -> V_co | None: """Gets the parent view associated with this item. The view refers to the structure that holds this item. This is typically set @@ -227,6 +265,34 @@ async def callback(self, interaction: Interaction): The interaction that triggered this UI item. """ + def set_callback(self, func: SetItemCallbackType[Self, V_co], /) -> None: + """Sets the callback for this item. + + Parameters + ---------- + func + The callback function to set. + + This function must be a coroutine that accepts 1 to 3 parameters: + + - :class:`.Interaction`, this will always be passed. + - :class:`.BaseView`, this will be passed if the function accepts 2 or 3 parameters. + - :class:`.Item`, this will be passed if the function accepts 3 parameters. + + Raises + ------ + TypeError + If the provided function is not a coroutine or does not have the correct number of parameters. + """ + if not inspect.iscoroutinefunction(func): + raise TypeError("callback must be a coroutine function") + + params = inspect.signature(func).parameters + if len(params) > 3: + raise TypeError("callback must accept 1 to 3 parameters") + + self.callback = _ProxyItemCallback(func, self, len(params)) # type: ignore + class ModalItem(Item[M]): """Represents an item used in Modals.