Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions discord/ui/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,17 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, TypeVar
import inspect
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Generic,
TypeVar,
)

from typing_extensions import Self

from ..interactions import Interaction

Expand All @@ -42,11 +52,39 @@
from .modal import BaseModal
from .view import BaseView

I = TypeVar("I", bound="Item")
I = TypeVar("I", bound="Item", covariant=True)

T = TypeVar("T", bound="ItemInterface", covariant=True)
V = TypeVar("V", bound="BaseView", covariant=True)
M = TypeVar("M", bound="BaseModal", covariant=True)

V_co = TypeVar("V_co", bound="BaseView", covariant=True)

ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]]
SetItemCallbackType = (
Callable[[Interaction], Coroutine[object, Any, Any]]
| Callable[[Interaction, I], Coroutine[object, Any, Any]]
| Callable[[Interaction, I, V], Coroutine[object, Any, Any]]
)


class _ProxyItemCallback:
def __init__(
self, func: SetItemCallbackType, item: ViewItem, parameters_amount: int
) -> None:
self.func: SetItemCallbackType = func
self.item: ViewItem = item
self.parameters_amount: int = parameters_amount

def __call__(self, interaction: Interaction) -> Coroutine[Any, Any, Any]:
if self.parameters_amount == 1:
return self.func(interaction) # type: ignore # type checker doesn't like optional params
elif self.parameters_amount == 2:
return self.func(interaction, self.item) # type: ignore # type checker doesn't like optional params
elif self.parameters_amount == 3:
return self.func(interaction, self.item, self.item.view) # type: ignore # type checker doesn't like optional params
else:
raise TypeError("callback must accept 1 to 3 parameters")


class Item(Generic[T]):
Expand Down Expand Up @@ -124,7 +162,7 @@ def id(self, value) -> None:
self._underlying.id = value


class ViewItem(Item[V]):
class ViewItem(Item[V_co], Generic[V_co]):
"""Represents an item used in Views.

The following are the original items supported in :class:`discord.ui.View`:
Expand All @@ -149,7 +187,7 @@ class ViewItem(Item[V]):

def __init__(self):
super().__init__()
self._view: V | None = None
self._view: V_co | None = None
self._row: int | None = None
self._rendered_row: int | None = None
self.parent: ViewItem | BaseView | None = self.view
Expand Down Expand Up @@ -197,7 +235,7 @@ def width(self) -> int:
return 1

@property
def view(self) -> V | None:
def view(self) -> V_co | None:
"""Gets the parent view associated with this item.

The view refers to the structure that holds this item. This is typically set
Expand Down Expand Up @@ -227,6 +265,34 @@ async def callback(self, interaction: Interaction):
The interaction that triggered this UI item.
"""

def set_callback(self, func: SetItemCallbackType[Self, V_co], /) -> None:
"""Sets the callback for this item.

Parameters
----------
func
The callback function to set.

This function must be a coroutine that accepts 1 to 3 parameters:

- :class:`.Interaction`, this will always be passed.
- :class:`.BaseView`, this will be passed if the function accepts 2 or 3 parameters.
- :class:`.Item`, this will be passed if the function accepts 3 parameters.

Raises
------
TypeError
If the provided function is not a coroutine or does not have the correct number of parameters.
"""
if not inspect.iscoroutinefunction(func):
raise TypeError("callback must be a coroutine function")

params = inspect.signature(func).parameters
if len(params) > 3:
raise TypeError("callback must accept 1 to 3 parameters")

self.callback = _ProxyItemCallback(func, self, len(params)) # type: ignore


class ModalItem(Item[M]):
"""Represents an item used in Modals.
Expand Down