diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000000..16761a1db9c8 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Replace Black with Ruff, then format whole project. +44a44e938fb2bd0bb085d8aa4577abeb01653ad3 diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000000..e2cfe594e4b2 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +open_collective: discordpy diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 1f3084ba0edd..000000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,34 +0,0 @@ ---- -name: Bug Report -about: Report broken or incorrect behaviour ---- - -### Summary - - - -### Reproduction Steps - - - -### Expected Results - - - -### Actual Results - - - - -### Checklist - - - -- [ ] I have searched the open issues for duplicates. -- [ ] I have shown the entire traceback, if possible. -- [ ] I have removed my token from display, if visible. - -### System Information - - - diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 000000000000..671f440fc2d6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,78 @@ +name: Bug Report +description: Report broken or incorrect behaviour +labels: unconfirmed bug +body: + - type: markdown + attributes: + value: > + Thanks for taking the time to fill out a bug. + If you want real-time support, consider joining our Discord at https://discord.gg/r3sSKJJ instead. + + Please note that this form is for bugs only! + - type: input + attributes: + label: Summary + description: A simple summary of your bug report + validations: + required: true + - type: textarea + attributes: + label: Reproduction Steps + description: > + What you did to make it happen. + validations: + required: true + - type: textarea + attributes: + label: Minimal Reproducible Code + description: > + A short snippet of code that showcases the bug. + render: python + - type: textarea + attributes: + label: Expected Results + description: > + What did you expect to happen? + validations: + required: true + - type: textarea + attributes: + label: Actual Results + description: > + What actually happened? + validations: + required: true + - type: input + attributes: + label: Intents + description: > + What intents are you using for your bot? + This is the `discord.Intents` class you pass to the client. + validations: + required: true + - type: textarea + attributes: + label: System Information + description: > + Run `python -m discord -v` and paste this information below. + + This command required v1.1.0 or higher of the library. If this errors out then show some basic + information involving your system such as operating system and Python version. + validations: + required: true + - type: checkboxes + attributes: + label: Checklist + description: > + Let's make sure you've properly done due diligence when reporting this issue! + options: + - label: I have searched the open issues for duplicates. + required: true + - label: I have shown the entire traceback, if possible. + required: true + - label: I have removed my token from display, if visible. + required: true + - type: textarea + attributes: + label: Additional Context + description: If there is anything else to say, please do so here. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000000..7934e4a850e4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Ask a question + about: Ask questions and discuss with other users of the library. + url: https://github.com/Rapptz/discord.py/discussions + - name: Discord Server + about: Use our official Discord server to ask for help and questions as well. + url: https://discord.gg/r3sSKJJ diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index 8bb8edeb54ef..000000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -name: Feature Request -about: Suggest a feature for this library ---- - -### The Problem - - - -### The Ideal Solution - - - -### The Current Solution - - - -### Summary - - diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 000000000000..cf5f57e03abd --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,49 @@ +name: Feature Request +description: Suggest a feature for this library +labels: feature request +body: + - type: input + attributes: + label: Summary + description: > + A short summary of what your feature request is. + validations: + required: true + - type: dropdown + attributes: + multiple: false + label: What is the feature request for? + options: + - The core library + - discord.ext.commands + - discord.ext.tasks + - The documentation + validations: + required: true + - type: textarea + attributes: + label: The Problem + description: > + What problem is your feature trying to solve? + What becomes easier or possible when this feature is implemented? + validations: + required: true + - type: textarea + attributes: + label: The Ideal Solution + description: > + What is your ideal solution to the problem? + What would you like this feature to do? + validations: + required: true + - type: textarea + attributes: + label: The Current Solution + description: > + What is the current solution to the problem, if any? + validations: + required: false + - type: textarea + attributes: + label: Additional Context + description: If there is anything else to say, please do so here. diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md deleted file mode 100644 index 2f0a6a7fdaaa..000000000000 --- a/.github/ISSUE_TEMPLATE/question.md +++ /dev/null @@ -1,11 +0,0 @@ ---- -name: Question about the library -about: Please ask for help in Discord instead - https://discord.gg/r3sSKJJ ---- - -Generally speaking support questions are better answered in our Discord server. The response rate is faster and many people are willing to help. If you **really** feel like the question belongs here then feel free to delete this text and continue on. **Please do not open issues about asking how to implement a feature in your bot, these will be instantly closed.** - -Our support servers can be found here: - -Official server: https://discord.gg/r3sSKJJ -Discord API: https://discord.gg/discord-api diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index e39c651da219..55941f4e1c7a 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,8 +1,8 @@ -### Summary +## Summary -### Checklist +## Checklist diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 000000000000..d82ae9a9f0f7 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,60 @@ +name: build + +on: + push: + pull_request: + types: [ opened, reopened, synchronize ] + +jobs: + dists-and-docs: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ '3.8', '3.x' ] + language: [ 'en', 'ja' ] + + name: dists & docs (${{ matrix.python-version }}/${{ matrix.language }}) + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up CPython ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + pip install -U -r requirements.txt + + - name: Build distributions + run: | + python ./setup.py sdist bdist_wheel + + # - name: Upload artifacts + # uses: actions/upload-artifact@v2 + # with: + # name: distributions + # path: dist/* + + - name: Install package + run: | + pip install -e .[docs,speed,voice] + + - name: Build docs + shell: bash + run: | + cd docs + sphinx-build -b html -D language=${DOCS_LANGUAGE} -j auto -a -n -T -W --keep-going . _build/html + env: + DOCS_LANGUAGE: ${{ matrix.language }} + + # - name: Upload docs + # uses: actions/upload-artifact@v2 + # if: always() + # with: + # name: docs-${{ matrix.language }} + # path: docs/_build/html/* diff --git a/.github/workflows/crowdin_download.yml b/.github/workflows/crowdin_download.yml new file mode 100644 index 000000000000..2e02428b768b --- /dev/null +++ b/.github/workflows/crowdin_download.yml @@ -0,0 +1,73 @@ +name: crowdin download + +on: + schedule: + - cron: '0 18 * * 1' + workflow_dispatch: + +jobs: + check-environment: + runs-on: ubuntu-latest + environment: Crowdin + outputs: + available: ${{ steps.check.outputs.available }} + steps: + - id: check + if: env.CROWDIN_API_KEY != null + run: | + echo "available=true" >> $GITHUB_OUTPUT + env: + CROWDIN_API_KEY: ${{ secrets.CROWDIN_API_KEY }} + + download: + runs-on: ubuntu-latest + needs: [ check-environment ] + # secrets cannot be accessed inside an `if` so this needs to be checked in separate job + if: needs.check-environment.outputs.available == 'true' + environment: Crowdin + name: download + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + ref: master + + - name: Install system dependencies + run: | + wget -qO - https://artifacts.crowdin.com/repo/GPG-KEY-crowdin | sudo apt-key add - + echo "deb https://artifacts.crowdin.com/repo/deb/ /" | sudo tee -a /etc/apt/sources.list.d/crowdin.list + sudo apt-get update -qq + sudo apt-get install -y crowdin3 + + - name: Download translations + shell: bash + run: | + cd docs + crowdin download --all + env: + CROWDIN_API_KEY: ${{ secrets.CROWDIN_API_KEY }} + + - name: Create pull request + id: cpr_crowdin + uses: peter-evans/create-pull-request@v3 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Crowdin translations download + title: "[Crowdin] Updated translation files" + body: | + Created by the [Crowdin download workflow](.github/workflows/crowdin_download.yml). + branch: "auto/crowdin" + author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> + + - name: Close and reopen the PR with different token to trigger CI + uses: actions/github-script@v3 + env: + PR_NUMBER: ${{ steps.cpr_crowdin.outputs.pull-request-number }} + PR_OPERATION: ${{ steps.cpr_crowdin.outputs.pull-request-operation }} + with: + github-token: ${{ secrets.GH_REPO_SCOPED_TOKEN }} + script: | + const script = require( + `${process.env.GITHUB_WORKSPACE}/.github/workflows/scripts/close_and_reopen_pr.js` + ); + console.log(script({github, context})); diff --git a/.github/workflows/crowdin_upload.yml b/.github/workflows/crowdin_upload.yml new file mode 100644 index 000000000000..949528b8e7de --- /dev/null +++ b/.github/workflows/crowdin_upload.yml @@ -0,0 +1,44 @@ +name: crowdin upload + +on: + workflow_dispatch: + +jobs: + upload: + runs-on: ubuntu-latest + environment: Crowdin + name: upload + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up CPython 3.x + uses: actions/setup-python@v4 + with: + python-version: 3.x + + - name: Install system dependencies + run: | + wget -qO - https://artifacts.crowdin.com/repo/GPG-KEY-crowdin | sudo apt-key add - + echo "deb https://artifacts.crowdin.com/repo/deb/ /" | sudo tee -a /etc/apt/sources.list.d/crowdin.list + sudo apt-get update -qq + sudo apt-get install -y crowdin3 + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + pip install -e .[docs,speed,voice] + + - name: Build gettext + run: | + cd docs + make gettext + + - name: Upload sources + shell: bash + run: | + cd docs + crowdin upload + env: + CROWDIN_API_KEY: ${{ secrets.CROWDIN_API_KEY }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000000..73992a155241 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,48 @@ +name: lint + +on: + push: + pull_request: + types: [ opened, reopened, synchronize ] + +jobs: + check: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ '3.8', '3.x' ] + + name: check ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up CPython ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + id: install-deps + run: | + python -m pip install --upgrade pip setuptools wheel ruff==0.12 requests "typing_extensions>=4.3,<5" + pip install -U -r requirements.txt + + - name: Setup node.js + uses: actions/setup-node@v3 + with: + node-version: '16' + + - name: Run Pyright + uses: jakebailey/pyright-action@v1 + with: + version: '1.1.394' + warnings: false + no-comments: ${{ matrix.python-version != '3.x' }} + + - name: Run ruff + if: ${{ always() && steps.install-deps.outcome == 'success' }} + run: | + ruff format --check discord examples diff --git a/.github/workflows/scripts/close_and_reopen_pr.js b/.github/workflows/scripts/close_and_reopen_pr.js new file mode 100644 index 000000000000..dc1214301930 --- /dev/null +++ b/.github/workflows/scripts/close_and_reopen_pr.js @@ -0,0 +1,22 @@ +module.exports = (async function ({github, context}) { + const pr_number = process.env.PR_NUMBER; + const pr_operation = process.env.PR_OPERATION; + + if (!['created', 'updated'].includes(pr_operation)) { + console.log('PR was not created as there were no changes.') + return; + } + + for (const state of ['closed', 'open']) { + // Wait a moment for GitHub to process the previous action.. + await new Promise(r => setTimeout(r, 5000)); + + // Close the PR + github.issues.update({ + issue_number: pr_number, + owner: context.repo.owner, + repo: context.repo.repo, + state + }); + } +}) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000000..f81a384a1323 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,34 @@ +name: test + +on: + push: + pull_request: + types: [ opened, reopened, synchronize ] + +jobs: + pytest: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ '3.8', '3.x' ] + + name: pytest ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up CPython ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install -e .[test] + + - name: Run tests + shell: bash + run: | + PYTHONPATH="$(pwd)" pytest -vs --cov=discord --cov-report term-missing:skip-covered diff --git a/.gitignore b/.gitignore index b556ebbb9146..62782dbcf1d6 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,7 @@ docs/crowdin.py *.jpg *.flac *.mo +/.coverage +build/* +uv.lock* +pylock*.toml \ No newline at end of file diff --git a/.readthedocs.yml b/.readthedocs.yml index ab9f4daf382a..68c792379acc 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -2,7 +2,9 @@ version: 2 formats: [] build: - image: latest + os: "ubuntu-22.04" + tools: + python: "3.8" sphinx: configuration: docs/conf.py @@ -10,7 +12,6 @@ sphinx: builder: html python: - version: 3.7 install: - method: pip path: . diff --git a/LICENSE b/LICENSE index 4003396c4f67..700c21b65a04 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), diff --git a/MANIFEST.in b/MANIFEST.in index e0a5ef4b083b..8e93fd092a5d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include README.rst include LICENSE include requirements.txt -include discord/bin/*.dll +include discord/bin/* +include discord/py.typed diff --git a/README.ja.rst b/README.ja.rst index aa8db00d4eb8..979d82247616 100644 --- a/README.ja.rst +++ b/README.ja.rst @@ -1,7 +1,7 @@ discord.py ========== -.. image:: https://discordapp.com/api/guilds/336642139381301249/embed.png +.. image:: https://discord.com/api/guilds/336642139381301249/embed.png :target: https://discord.gg/nXzj3dg :alt: Discordサーバーの招待 .. image:: https://img.shields.io/pypi/v/discord.py.svg @@ -18,19 +18,18 @@ discord.py は機能豊富かつモダンで使いやすい、非同期処理に - ``async`` と ``await`` を使ったモダンなPythonらしいAPI。 - 適切なレート制限処理 -- Discord APIによってサポートされているものを100%カバー。 - メモリと速度の両方を最適化。 インストール ------------- -**Python 3.5.3 以降のバージョンが必須です** +**Python 3.8 以降のバージョンが必須です** 完全な音声サポートなしでライブラリをインストールする場合は次のコマンドを実行してください: .. code:: sh - # Linux/OS X + # Linux/macOS python3 -m pip install -U discord.py # Windows @@ -40,7 +39,7 @@ discord.py は機能豊富かつモダンで使いやすい、非同期処理に .. code:: sh - # Linux/OS X + # Linux/macOS python3 -m pip install -U discord.py[voice] # Windows @@ -61,10 +60,10 @@ discord.py は機能豊富かつモダンで使いやすい、非同期処理に * PyNaCl (音声サポート用) -Linuxで音声サポートを導入するには、前述のコマンドを実行する前にお気に入りのパッケージマネージャー(例えば ``apt`` や ``yum`` など)を使って以下のパッケージをインストールする必要があります: +Linuxで音声サポートを導入するには、前述のコマンドを実行する前にお気に入りのパッケージマネージャー(例えば ``apt`` や ``dnf`` など)を使って以下のパッケージをインストールする必要があります: * libffi-dev (システムによっては ``libffi-devel``) -* python-dev (例えばPython 3.6用の ``python3.6-dev``) +* python-dev (例えばPython 3.8用の ``python3.8-dev``) 簡単な例 -------------- @@ -85,7 +84,9 @@ Linuxで音声サポートを導入するには、前述のコマンドを実行 if message.content == 'ping': await message.channel.send('pong') - client = MyClient() + intents = discord.Intents.default() + intents.message_content = True + client = MyClient(intents=intents) client.run('token') Botの例 @@ -96,7 +97,9 @@ Botの例 import discord from discord.ext import commands - bot = commands.Bot(command_prefix='>') + intents = discord.Intents.default() + intents.message_content = True + bot = commands.Bot(command_prefix='>', intents=intents) @bot.command() async def ping(ctx): diff --git a/README.rst b/README.rst index b8cbd2819976..b2112f9d6b81 100644 --- a/README.rst +++ b/README.rst @@ -1,7 +1,7 @@ discord.py ========== -.. image:: https://discordapp.com/api/guilds/336642139381301249/embed.png +.. image:: https://discord.com/api/guilds/336642139381301249/embed.png :target: https://discord.gg/r3sSKJJ :alt: Discord server invite .. image:: https://img.shields.io/pypi/v/discord.py.svg @@ -18,19 +18,25 @@ Key Features - Modern Pythonic API using ``async`` and ``await``. - Proper rate limit handling. -- 100% coverage of the supported Discord API. - Optimised in both speed and memory. Installing ---------- -**Python 3.5.3 or higher is required** +**Python 3.8 or higher is required** To install the library without full voice support, you can just run the following command: +.. note:: + + A `Virtual Environment `__ is recommended to install + the library, especially on Linux where the system Python is externally managed and restricts which + packages you can install on it. + + .. code:: sh - # Linux/OS X + # Linux/macOS python3 -m pip install -U discord.py # Windows @@ -40,8 +46,8 @@ Otherwise to get voice support you should run the following command: .. code:: sh - # Linux/OS X - python3 -m pip install -U discord.py[voice] + # Linux/macOS + python3 -m pip install -U "discord.py[voice]" # Windows py -3 -m pip install -U discord.py[voice] @@ -59,12 +65,12 @@ To install the development version, do the following: Optional Packages ~~~~~~~~~~~~~~~~~~ -* PyNaCl (for voice support) +* `PyNaCl `__ (for voice support) -Please note that on Linux installing voice you must install the following packages via your favourite package manager (e.g. ``apt``, ``yum``, etc) before running the above commands: +Please note that when installing voice support on Linux, you must install the following packages via your favourite package manager (e.g. ``apt``, ``dnf``, etc) before running the above commands: * libffi-dev (or ``libffi-devel`` on some systems) -* python-dev (e.g. ``python3.6-dev`` for Python 3.6) +* python-dev (e.g. ``python3.8-dev`` for Python 3.8) Quick Example -------------- @@ -85,7 +91,9 @@ Quick Example if message.content == 'ping': await message.channel.send('pong') - client = MyClient() + intents = discord.Intents.default() + intents.message_content = True + client = MyClient(intents=intents) client.run('token') Bot Example @@ -96,7 +104,9 @@ Bot Example import discord from discord.ext import commands - bot = commands.Bot(command_prefix='>') + intents = discord.Intents.default() + intents.message_content = True + bot = commands.Bot(command_prefix='>', intents=intents) @bot.command() async def ping(ctx): diff --git a/discord/__init__.py b/discord/__init__.py index 3623a5283ca8..3279f8b8c048 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -1,12 +1,10 @@ -# -*- coding: utf-8 -*- - """ Discord API Wrapper ~~~~~~~~~~~~~~~~~~~ A basic wrapper for the Discord API. -:copyright: (c) 2015-2019 Rapptz +:copyright: (c) 2015-present Rapptz :license: MIT, see LICENSE for more details. """ @@ -14,52 +12,88 @@ __title__ = 'discord' __author__ = 'Rapptz' __license__ = 'MIT' -__copyright__ = 'Copyright 2015-2019 Rapptz' -__version__ = '1.2.3' +__copyright__ = 'Copyright 2015-present Rapptz' +__version__ = '2.7.0a' + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) -from collections import namedtuple import logging +from typing import NamedTuple, Literal -from .client import Client -from .appinfo import AppInfo -from .user import User, ClientUser, Profile -from .emoji import Emoji, PartialEmoji +from .client import * +from .appinfo import * +from .user import * +from .emoji import * +from .partial_emoji import * from .activity import * from .channel import * -from .guild import Guild, SystemChannelFlags -from .relationship import Relationship -from .member import Member, VoiceState -from .message import Message, Attachment -from .asset import Asset +from .guild import * +from .flags import * +from .member import * +from .message import * +from .asset import * from .errors import * -from .calls import CallMessage, GroupCall -from .permissions import Permissions, PermissionOverwrite -from .role import Role -from .file import File -from .colour import Color, Colour -from .invite import Invite, PartialInviteChannel, PartialInviteGuild -from .widget import Widget, WidgetMember, WidgetChannel -from .object import Object -from .reaction import Reaction -from . import utils, opus, abc +from .permissions import * +from .role import * +from .file import * +from .colour import * +from .integrations import * +from .invite import * +from .template import * +from .welcome_screen import * +from .sku import * +from .widget import * +from .object import * +from .reaction import * +from . import ( + utils as utils, + opus as opus, + abc as abc, + ui as ui, + app_commands as app_commands, +) from .enums import * -from .embeds import Embed -from .shard import AutoShardedClient +from .embeds import * +from .mentions import * +from .shard import * from .player import * from .webhook import * -from .voice_client import VoiceClient -from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff +from .voice_client import * +from .audit_logs import * from .raw_models import * +from .team import * +from .sticker import * +from .stage_instance import * +from .scheduled_event import * +from .interactions import * +from .components import * +from .threads import * +from .automod import * +from .poll import * +from .soundboard import * +from .subscription import * +from .presences import * +from .primary_guild import * +from .onboarding import * +from .collectible import * + + +class VersionInfo(NamedTuple): + major: int + minor: int + micro: int + releaselevel: Literal['alpha', 'beta', 'candidate', 'final'] + serial: int + -VersionInfo = namedtuple('VersionInfo', 'major minor micro releaselevel serial') +version_info: VersionInfo = VersionInfo(major=2, minor=7, micro=0, releaselevel='alpha', serial=0) -version_info = VersionInfo(major=1, minor=2, micro=3, releaselevel='final', serial=0) +logging.getLogger(__name__).addHandler(logging.NullHandler()) -try: - from logging import NullHandler -except ImportError: - class NullHandler(logging.Handler): - def emit(self, record): - pass +# This is a backwards compatibility hack and should be removed in v3 +# Essentially forcing the exception to have different base classes +# In the future, this should only inherit from ClientException +if len(MissingApplicationID.__bases__) == 1: + MissingApplicationID.__bases__ = (app_commands.AppCommandError, ClientException) -logging.getLogger(__name__).addHandler(NullHandler()) +del logging, NamedTuple, Literal, VersionInfo diff --git a/discord/__main__.py b/discord/__main__.py index 3831938e5df1..455c5e8ed119 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,65 +22,75 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from typing import Optional, Tuple, Dict + import argparse import sys -from pathlib import Path +from pathlib import Path, PurePath, PureWindowsPath import discord -import pkg_resources +import importlib.metadata import aiohttp -import websockets import platform -def show_version(): + +def show_version() -> None: entries = [] entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info)) version_info = discord.version_info entries.append('- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(version_info)) if version_info.releaselevel != 'final': - pkg = pkg_resources.get_distribution('discord.py') - if pkg: - entries.append(' - discord.py pkg_resources: v{0}'.format(pkg.version)) + version = importlib.metadata.version('discord.py') + if version: + entries.append(f' - discord.py metadata: v{version}') - entries.append('- aiohttp v{0.__version__}'.format(aiohttp)) - entries.append('- websockets v{0.__version__}'.format(websockets)) + entries.append(f'- aiohttp v{aiohttp.__version__}') uname = platform.uname() entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) print('\n'.join(entries)) -def core(parser, args): + +def core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None: if args.version: show_version() + else: + parser.print_help() + -bot_template = """#!/usr/bin/env python3 -# -*- coding: utf-8 -*- +_bot_template = """#!/usr/bin/env python3 from discord.ext import commands import discord import config class Bot(commands.{base}): - def __init__(self, **kwargs): - super().__init__(command_prefix=commands.when_mentioned_or('{prefix}'), **kwargs) + def __init__(self, intents: discord.Intents, **kwargs): + super().__init__(command_prefix=commands.when_mentioned_or('{prefix}'), intents=intents, **kwargs) + + async def setup_hook(self): for cog in config.cogs: try: - self.load_extension(cog) + await self.load_extension(cog) except Exception as exc: - print('Could not load extension {{0}} due to {{1.__class__.__name__}}: {{1}}'.format(cog, exc)) + print(f'Could not load extension {{cog}} due to {{exc.__class__.__name__}}: {{exc}}') async def on_ready(self): - print('Logged on as {{0}} (ID: {{0.id}})'.format(self.user)) + print(f'Logged on as {{self.user}} (ID: {{self.user.id}})') -bot = Bot() +intents = discord.Intents.default() +intents.message_content = True +bot = Bot(intents=intents) # write general commands here bot.run(config.token) """ -gitignore_template = """# Byte-compiled / optimized / DLL files +_gitignore_template = """# Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class @@ -112,9 +120,7 @@ async def on_ready(self): config.py """ -cog_template = '''# -*- coding: utf-8 -*- - -from discord.ext import commands +_cog_template = '''from discord.ext import commands import discord class {name}(commands.Cog{attrs}): @@ -123,12 +129,16 @@ class {name}(commands.Cog{attrs}): def __init__(self, bot): self.bot = bot {extra} -def setup(bot): - bot.add_cog({name}(bot)) +async def setup(bot): + await bot.add_cog({name}(bot)) ''' -cog_extras = ''' - def cog_unload(self): +_cog_extras = """ + async def cog_load(self): + # loading logic goes here + pass + + async def cog_unload(self): # clean up logic goes here pass @@ -148,6 +158,10 @@ async def cog_command_error(self, ctx, error): # error handling to every command in here pass + async def cog_app_command_error(self, interaction, error): + # error handling to every application command in here + pass + async def cog_before_invoke(self, ctx): # called before a command is called here pass @@ -156,13 +170,13 @@ async def cog_after_invoke(self, ctx): # called after a command is called here pass -''' +""" # certain file names and directory names are forbidden # see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx # although some of this doesn't apply to Linux, we might as well be consistent -_base_table = { +_base_table: Dict[str, Optional[str]] = { '<': '-', '>': '-', ':': '-', @@ -177,24 +191,54 @@ async def cog_after_invoke(self, ctx): # NUL (0) and 1-31 are disallowed _base_table.update((chr(i), None) for i in range(32)) -translation_table = str.maketrans(_base_table) +_translation_table = str.maketrans(_base_table) -def to_path(parser, name, *, replace_spaces=False): + +def to_path(parser: argparse.ArgumentParser, name: str, *, replace_spaces: bool = False) -> Path: if isinstance(name, Path): return name if sys.platform == 'win32': - forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \ - 'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9') + forbidden = ( + 'CON', + 'PRN', + 'AUX', + 'NUL', + 'COM1', + 'COM2', + 'COM3', + 'COM4', + 'COM5', + 'COM6', + 'COM7', + 'COM8', + 'COM9', + 'LPT1', + 'LPT2', + 'LPT3', + 'LPT4', + 'LPT5', + 'LPT6', + 'LPT7', + 'LPT8', + 'LPT9', + ) if len(name) <= 4 and name.upper() in forbidden: parser.error('invalid directory name given, use a different one') + path = PurePath(name) + if isinstance(path, PureWindowsPath) and path.drive: + drive, rest = path.parts[0], path.parts[1:] + transformed = tuple(map(lambda p: p.translate(_translation_table), rest)) + name = drive + '\\'.join(transformed) - name = name.translate(translation_table) + else: + name = name.translate(_translation_table) if replace_spaces: name = name.replace(' ', '-') return Path(name) -def newbot(parser, args): + +def newbot(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None: new_directory = to_path(parser, args.directory) / to_path(parser, args.name) # as a note exist_ok for Path is a 3.5+ only feature @@ -202,7 +246,7 @@ def newbot(parser, args): try: new_directory.mkdir(exist_ok=True, parents=True) except OSError as exc: - parser.error('could not create our bot directory ({})'.format(exc)) + parser.error(f'could not create our bot directory ({exc})') cogs = new_directory / 'cogs' @@ -211,63 +255,66 @@ def newbot(parser, args): init = cogs / '__init__.py' init.touch() except OSError as exc: - print('warning: could not create cogs directory ({})'.format(exc)) + print(f'warning: could not create cogs directory ({exc})') try: with open(str(new_directory / 'config.py'), 'w', encoding='utf-8') as fp: fp.write('token = "place your token here"\ncogs = []\n') except OSError as exc: - parser.error('could not create config file ({})'.format(exc)) + parser.error(f'could not create config file ({exc})') try: with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp: base = 'Bot' if not args.sharded else 'AutoShardedBot' - fp.write(bot_template.format(base=base, prefix=args.prefix)) + fp.write(_bot_template.format(base=base, prefix=args.prefix)) except OSError as exc: - parser.error('could not create bot file ({})'.format(exc)) + parser.error(f'could not create bot file ({exc})') if not args.no_git: try: with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp: - fp.write(gitignore_template) + fp.write(_gitignore_template) except OSError as exc: - print('warning: could not create .gitignore file ({})'.format(exc)) + print(f'warning: could not create .gitignore file ({exc})') print('successfully made bot at', new_directory) -def newcog(parser, args): + +def newcog(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None: cog_dir = to_path(parser, args.directory) try: cog_dir.mkdir(exist_ok=True) except OSError as exc: - print('warning: could not create cogs directory ({})'.format(exc)) + print(f'warning: could not create cogs directory ({exc})') directory = cog_dir / to_path(parser, args.name) directory = directory.with_suffix('.py') try: with open(str(directory), 'w', encoding='utf-8') as fp: attrs = '' - extra = cog_extras if args.full else '' + extra = _cog_extras if args.full else '' if args.class_name: name = args.class_name else: name = str(directory.stem) - if '-' in name: - name = name.replace('-', ' ').title().replace(' ', '') + if '-' in name or '_' in name: + translation = str.maketrans('-_', ' ') + name = name.translate(translation).title().replace(' ', '') else: name = name.title() if args.display_name: - attrs += ', name="{}"'.format(args.display_name) + attrs += f', name="{args.display_name}"' if args.hide_commands: attrs += ', command_attrs=dict(hidden=True)' - fp.write(cog_template.format(name=name, extra=extra, attrs=attrs)) + fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs)) except OSError as exc: - parser.error('could not create cog file ({})'.format(exc)) + parser.error(f'could not create cog file ({exc})') else: print('successfully made cog at', directory) -def add_newbot_args(subparser): + +def add_newbot_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None: parser = subparser.add_parser('newbot', help='creates a command bot project quickly') parser.set_defaults(func=newbot) @@ -277,7 +324,8 @@ def add_newbot_args(subparser): parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true') parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') -def add_newcog_args(subparser): + +def add_newcog_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None: parser = subparser.add_parser('newcog', help='creates a new cog template quickly') parser.set_defaults(func=newcog) @@ -288,7 +336,8 @@ def add_newcog_args(subparser): parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true') parser.add_argument('--full', help='add all special methods as well', action='store_true') -def parse_args(): + +def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]: parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') parser.add_argument('-v', '--version', action='store_true', help='shows the library version') parser.set_defaults(func=core) @@ -298,8 +347,11 @@ def parse_args(): add_newcog_args(subparser) return parser, parser.parse_args() -def main(): + +def main() -> None: parser, args = parse_args() args.func(parser, args) -main() + +if __name__ == '__main__': + main() diff --git a/discord/_types.py b/discord/_types.py new file mode 100644 index 000000000000..331063544a4e --- /dev/null +++ b/discord/_types.py @@ -0,0 +1,34 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +from typing import TypeVar, TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import TypeVar + from .client import Client + + ClientT = TypeVar('ClientT', bound=Client, covariant=True, default=Client) +else: + ClientT = TypeVar('ClientT', bound='Client', covariant=True) diff --git a/discord/abc.py b/discord/abc.py index 1d490cd4610b..95ccfd67b690 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,60 +22,232 @@ DEALINGS IN THE SOFTWARE. """ -import abc +from __future__ import annotations + import copy +import time +import secrets import asyncio -from collections import namedtuple - -from .iterators import HistoryIterator +from datetime import datetime +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Generator, + Iterable, + List, + Literal, + Optional, + TYPE_CHECKING, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, + overload, + runtime_checkable, +) + +from .object import OLDEST_OBJECT, Object from .context_managers import Typing -from .errors import InvalidArgument, ClientException, HTTPException +from .enums import ChannelType, InviteTarget +from .errors import ClientException, NotFound +from .mentions import AllowedMentions from .permissions import PermissionOverwrite, Permissions from .role import Role from .invite import Invite from .file import File -from .voice_client import VoiceClient +from .http import handle_message_parameters +from .voice_client import VoiceClient, VoiceProtocol +from .sticker import GuildSticker, StickerItem from . import utils +from .flags import InviteFlags +import warnings + +__all__ = ( + 'Snowflake', + 'User', + 'PrivateChannel', + 'GuildChannel', + 'Messageable', + 'Connectable', +) + +T = TypeVar('T', bound=VoiceProtocol) + +if TYPE_CHECKING: + from typing_extensions import Self, Unpack + + from .client import Client + from .user import ClientUser + from .asset import Asset + from .state import ConnectionState + from .guild import Guild + from .member import Member + from .channel import CategoryChannel + from .embeds import Embed + from .message import Message, MessageReference, PartialMessage + from .channel import ( + TextChannel, + DMChannel, + GroupChannel, + PartialMessageable, + VocalGuildChannel, + VoiceChannel, + StageChannel, + ) + from .poll import Poll + from .threads import Thread + from .ui.view import BaseView, View, LayoutView + from .types.channel import ( + PermissionOverwrite as PermissionOverwritePayload, + Channel as ChannelPayload, + GuildChannel as GuildChannelPayload, + OverwriteType, + ) + from .types.guild import ( + ChannelPositionUpdate, + ) + from .types.snowflake import ( + SnowflakeList, + ) + from .permissions import _PermissionOverwriteKwargs + + PartialMessageableChannel = Union[TextChannel, VoiceChannel, StageChannel, Thread, DMChannel, PartialMessageable] + MessageableChannel = Union[PartialMessageableChannel, GroupChannel] + SnowflakeTime = Union['Snowflake', datetime] + + class PinnedMessage(Message): + pinned_at: datetime + pinned: Literal[True] + + +MISSING = utils.MISSING + class _Undefined: - def __repr__(self): + def __repr__(self) -> str: return 'see-below' -_undefined = _Undefined() -class Snowflake(metaclass=abc.ABCMeta): +_undefined: Any = _Undefined() + + +class _PinsIterator: + def __init__(self, iterator: AsyncIterator[PinnedMessage]) -> None: + self.__iterator: AsyncIterator[PinnedMessage] = iterator + + def __await__(self) -> Generator[Any, None, List[PinnedMessage]]: + warnings.warn( + '`await .pins()` is deprecated; use `async for message in .pins()` instead.', + DeprecationWarning, + stacklevel=2, + ) + + async def gather() -> List[PinnedMessage]: + return [msg async for msg in self.__iterator] + + return gather().__await__() + + def __aiter__(self) -> AsyncIterator[PinnedMessage]: + return self.__iterator + + +async def _single_delete_strategy(messages: Iterable[Message], *, reason: Optional[str] = None): + for m in messages: + try: + await m.delete() + except NotFound as exc: + if exc.code == 10008: + continue # bulk deletion ignores not found messages, single deletion does not. + # several other race conditions with deletion should fail without continuing, + # such as the channel being deleted and not found. + raise + + +async def _purge_helper( + channel: Union[Thread, TextChannel, VocalGuildChannel], + *, + limit: Optional[int] = 100, + check: Callable[[Message], bool] = MISSING, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + around: Optional[SnowflakeTime] = None, + oldest_first: Optional[bool] = None, + bulk: bool = True, + reason: Optional[str] = None, +) -> List[Message]: + if check is MISSING: + check = lambda m: True + + iterator = channel.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around) + ret: List[Message] = [] + count = 0 + + minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 + strategy = channel.delete_messages if bulk else _single_delete_strategy + + async for message in iterator: + if count == 100: + to_delete = ret[-100:] + await strategy(to_delete, reason=reason) + count = 0 + await asyncio.sleep(1) + + if not message.type.is_deletable(): + continue + + if not check(message): + continue + + if message.id < minimum_time: + # older than 14 days old + if count == 1: + await ret[-1].delete() + elif count >= 2: + to_delete = ret[-count:] + await strategy(to_delete, reason=reason) + + count = 0 + strategy = _single_delete_strategy + + count += 1 + ret.append(message) + + # Some messages remaining to poll + if count >= 2: + # more than 2 messages -> bulk delete + to_delete = ret[-count:] + await strategy(to_delete, reason=reason) + elif count == 1: + # delete a single message + await ret[-1].delete() + + return ret + + +@runtime_checkable +class Snowflake(Protocol): """An ABC that details the common operations on a Discord model. Almost all :ref:`Discord models ` meet this abstract base class. + If you want to create a snowflake on your own, consider using + :class:`.Object`. + Attributes ----------- id: :class:`int` The model's unique ID. """ - __slots__ = () - @property - @abc.abstractmethod - def created_at(self): - """:class:`datetime.datetime`: Returns the model's creation time as a naive datetime in UTC.""" - raise NotImplementedError + id: int - @classmethod - def __subclasshook__(cls, C): - if cls is Snowflake: - mro = C.__mro__ - for attr in ('created_at', 'id'): - for base in mro: - if attr in base.__dict__: - break - else: - return NotImplemented - return True - return NotImplemented -class User(metaclass=abc.ABCMeta): +@runtime_checkable +class User(Snowflake, Protocol): """An ABC that details the common operations on a Discord user. The following implement this ABC: @@ -93,43 +263,84 @@ class User(metaclass=abc.ABCMeta): name: :class:`str` The user's username. discriminator: :class:`str` - The user's discriminator. - avatar: Optional[:class:`str`] - The avatar hash the user has. + The user's discriminator. This is a legacy concept that is no longer used. + global_name: Optional[:class:`str`] + The user's global nickname. bot: :class:`bool` If the user is a bot account. + system: :class:`bool` + If the user is a system account. """ - __slots__ = () + + name: str + discriminator: str + global_name: Optional[str] + bot: bool + system: bool @property - @abc.abstractmethod - def display_name(self): + def display_name(self) -> str: """:class:`str`: Returns the user's display name.""" raise NotImplementedError @property - @abc.abstractmethod - def mention(self): + def mention(self) -> str: """:class:`str`: Returns a string that allows you to mention the given user.""" raise NotImplementedError - @classmethod - def __subclasshook__(cls, C): - if cls is User: - if Snowflake.__subclasshook__(C) is NotImplemented: - return NotImplemented - - mro = C.__mro__ - for attr in ('display_name', 'mention', 'name', 'avatar', 'discriminator', 'bot'): - for base in mro: - if attr in base.__dict__: - break - else: - return NotImplemented - return True - return NotImplemented + @property + def avatar(self) -> Optional[Asset]: + """Optional[:class:`~discord.Asset`]: Returns an Asset that represents the user's avatar, if present.""" + raise NotImplementedError + + @property + def avatar_decoration(self) -> Optional[Asset]: + """Optional[:class:`~discord.Asset`]: Returns an Asset that represents the user's avatar decoration, if present. + + .. versionadded:: 2.4 + """ + raise NotImplementedError + + @property + def avatar_decoration_sku_id(self) -> Optional[int]: + """Optional[:class:`int`]: Returns an integer that represents the user's avatar decoration SKU ID, if present. -class PrivateChannel(metaclass=abc.ABCMeta): + .. versionadded:: 2.4 + """ + raise NotImplementedError + + @property + def default_avatar(self) -> Asset: + """:class:`~discord.Asset`: Returns the default avatar for a given user.""" + raise NotImplementedError + + @property + def display_avatar(self) -> Asset: + """:class:`~discord.Asset`: Returns the user's display avatar. + + For regular users this is just their default avatar or uploaded avatar. + + .. versionadded:: 2.0 + """ + raise NotImplementedError + + def mentioned_in(self, message: Message) -> bool: + """Checks if the user is mentioned in the specified message. + + Parameters + ----------- + message: :class:`~discord.Message` + The message to check if you're mentioned in. + + Returns + ------- + :class:`bool` + Indicates if the user is mentioned in the message. + """ + raise NotImplementedError + + +class PrivateChannel: """An ABC that details the common operations on a private Discord channel. The following implement this ABC: @@ -144,22 +355,39 @@ class PrivateChannel(metaclass=abc.ABCMeta): me: :class:`~discord.ClientUser` The user presenting yourself. """ + __slots__ = () - @classmethod - def __subclasshook__(cls, C): - if cls is PrivateChannel: - if Snowflake.__subclasshook__(C) is NotImplemented: - return NotImplemented + id: int + me: ClientUser + + +class _Overwrites: + __slots__ = ('id', 'allow', 'deny', 'type') - mro = C.__mro__ - for base in mro: - if 'me' in base.__dict__: - return True - return NotImplemented - return NotImplemented + ROLE = 0 + MEMBER = 1 + + def __init__(self, data: PermissionOverwritePayload) -> None: + self.id: int = int(data['id']) + self.allow: int = int(data.get('allow', 0)) + self.deny: int = int(data.get('deny', 0)) + self.type: OverwriteType = data['type'] + + def _asdict(self) -> PermissionOverwritePayload: + return { + 'id': self.id, + 'allow': str(self.allow), + 'deny': str(self.deny), + 'type': self.type, + } + + def is_role(self) -> bool: + return self.type == 0 + + def is_member(self) -> bool: + return self.type == 1 -_Overwrites = namedtuple('_Overwrites', 'id allow deny type') class GuildChannel: """An ABC that details the common operations on a Discord guild channel. @@ -169,6 +397,8 @@ class GuildChannel: - :class:`~discord.TextChannel` - :class:`~discord.VoiceChannel` - :class:`~discord.CategoryChannel` + - :class:`~discord.StageChannel` + - :class:`~discord.ForumChannel` This ABC must also implement :class:`~discord.abc.Snowflake`. @@ -182,25 +412,46 @@ class GuildChannel: The position in the channel list. This is a number that starts at 0. e.g. the top channel is position 0. """ + __slots__ = () - def __str__(self): + id: int + name: str + guild: Guild + type: ChannelType + position: int + category_id: Optional[int] + _state: ConnectionState + _overwrites: List[_Overwrites] + + if TYPE_CHECKING: + + def __init__(self, *, state: ConnectionState, guild: Guild, data: GuildChannelPayload): ... + + def __str__(self) -> str: return self.name @property - def _sorting_bucket(self): + def _sorting_bucket(self) -> int: + raise NotImplementedError + + def _update(self, guild: Guild, data: Dict[str, Any]) -> None: raise NotImplementedError - async def _move(self, position, parent_id=None, lock_permissions=False, *, reason): + async def _move( + self, + position: int, + parent_id: Optional[Any] = None, + lock_permissions: bool = False, + *, + reason: Optional[str], + ) -> None: if position < 0: - raise InvalidArgument('Channel position cannot be less than 0.') + raise ValueError('Channel position cannot be less than 0.') http = self._state.http bucket = self._sorting_bucket - channels = [c for c in self.guild.channels if c._sorting_bucket == bucket] - - if position >= len(channels): - raise InvalidArgument('Channel position cannot be greater than {}'.format(len(channels) - 1)) + channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket] channels.sort(key=lambda c: c.position) @@ -211,22 +462,20 @@ async def _move(self, position, parent_id=None, lock_permissions=False, *, reaso # not there somehow lol return else: + index = next((i for i, c in enumerate(channels) if c.position >= position), len(channels)) # add ourselves at our designated position - channels.insert(position, self) + channels.insert(index, self) payload = [] for index, c in enumerate(channels): - d = {'id': c.id, 'position': index} + d: Dict[str, Any] = {'id': c.id, 'position': index} if parent_id is not _undefined and c.id == self.id: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) await http.bulk_channel_update(self.guild.id, payload, reason=reason) - self.position = position - if parent_id is not _undefined: - self.category_id = int(parent_id) if parent_id else None - async def _edit(self, options, reason): + async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]: try: parent = options.pop('category') except KeyError: @@ -239,6 +488,25 @@ async def _edit(self, options, reason): except KeyError: pass + try: + options['default_thread_rate_limit_per_user'] = options.pop('default_thread_slowmode_delay') + except KeyError: + pass + + try: + rtc_region = options.pop('rtc_region') + except KeyError: + pass + else: + options['rtc_region'] = None if rtc_region is None else str(rtc_region) + + try: + video_quality_mode = options.pop('video_quality_mode') + except KeyError: + pass + else: + options['video_quality_mode'] = int(video_quality_mode) + lock_permissions = options.pop('sync_permissions', False) try: @@ -247,33 +515,74 @@ async def _edit(self, options, reason): if parent_id is not _undefined: if lock_permissions: category = self.guild.get_channel(parent_id) - options['permission_overwrites'] = [c._asdict() for c in category._overwrites] + if category: + options['permission_overwrites'] = [c._asdict() for c in category._overwrites] options['parent_id'] = parent_id elif lock_permissions and self.category_id is not None: # if we're syncing permissions on a pre-existing channel category without changing it # we need to update the permissions to point to the pre-existing category category = self.guild.get_channel(self.category_id) - options['permission_overwrites'] = [c._asdict() for c in category._overwrites] + if category: + options['permission_overwrites'] = [c._asdict() for c in category._overwrites] else: await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason) + overwrites = options.get('overwrites', None) + if overwrites is not None: + perms = [] + for target, perm in overwrites.items(): + if not isinstance(perm, PermissionOverwrite): + raise TypeError(f'Expected PermissionOverwrite received {perm.__class__.__name__}') + + allow, deny = perm.pair() + payload = { + 'allow': allow.value, + 'deny': deny.value, + 'id': target.id, + } + + if isinstance(target, Role): + payload['type'] = _Overwrites.ROLE + elif isinstance(target, Object): + payload['type'] = _Overwrites.ROLE if target.type is Role else _Overwrites.MEMBER + else: + payload['type'] = _Overwrites.MEMBER + + perms.append(payload) + options['permission_overwrites'] = perms + + try: + ch_type = options['type'] + except KeyError: + pass + else: + if not isinstance(ch_type, ChannelType): + raise TypeError('type field must be of type ChannelType') + options['type'] = ch_type.value + + try: + status = options.pop('status') + except KeyError: + pass + else: + await self._state.http.edit_voice_channel_status(status, channel_id=self.id, reason=reason) + if options: - data = await self._state.http.edit_channel(self.id, reason=reason, **options) - self._update(self.guild, data) + return await self._state.http.edit_channel(self.id, reason=reason, **options) - def _fill_overwrites(self, data): + def _fill_overwrites(self, data: GuildChannelPayload) -> None: self._overwrites = [] everyone_index = 0 everyone_id = self.guild.id for index, overridden in enumerate(data.get('permission_overwrites', [])): - overridden_id = int(overridden.pop('id')) - self._overwrites.append(_Overwrites(id=overridden_id, **overridden)) + overwrite = _Overwrites(overridden) + self._overwrites.append(overwrite) - if overridden['type'] == 'member': + if overwrite.type == _Overwrites.MEMBER: continue - if overridden_id == everyone_id: + if overwrite.id == everyone_id: # the @everyone role is not guaranteed to be the first one # in the list of permission overwrites, however the permission # resolution code kind of requires that it is the first one in @@ -287,12 +596,12 @@ def _fill_overwrites(self, data): tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] @property - def changed_roles(self): + def changed_roles(self) -> List[Role]: """List[:class:`~discord.Role`]: Returns a list of roles that have been overridden from their default values in the :attr:`~discord.Guild.roles` attribute.""" ret = [] g = self.guild - for overwrite in filter(lambda o: o.type == 'role', self._overwrites): + for overwrite in filter(lambda o: o.is_role(), self._overwrites): role = g.get_role(overwrite.id) if role is None: continue @@ -303,23 +612,30 @@ def changed_roles(self): return ret @property - def mention(self): + def mention(self) -> str: """:class:`str`: The string that allows you to mention the channel.""" - return '<#%s>' % self.id + return f'<#{self.id}>' + + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.0 + """ + return f'https://discord.com/channels/{self.guild.id}/{self.id}' @property - def created_at(self): + def created_at(self) -> datetime: """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" return utils.snowflake_time(self.id) - def overwrites_for(self, obj): + def overwrites_for(self, obj: Union[Role, User, Object]) -> PermissionOverwrite: """Returns the channel-specific overwrites for a member or a role. Parameters ----------- - obj: Union[:class:`~discord.Role`, :class:`~discord.abc.User`] - The role or user denoting - whose overwrite to get. + obj: Union[:class:`~discord.Role`, :class:`~discord.abc.User`, :class:`~discord.Object`] + The role or user denoting whose overwrite to get. Returns --------- @@ -328,9 +644,9 @@ def overwrites_for(self, obj): """ if isinstance(obj, User): - predicate = lambda p: p.type == 'member' + predicate = lambda p: p.is_member() elif isinstance(obj, Role): - predicate = lambda p: p.type == 'role' + predicate = lambda p: p.is_role() else: predicate = lambda p: True @@ -343,16 +659,19 @@ def overwrites_for(self, obj): return PermissionOverwrite() @property - def overwrites(self): + def overwrites(self) -> Dict[Union[Role, Member, Object], PermissionOverwrite]: """Returns all of the channel's overwrites. This is returned as a dictionary where the key contains the target which - can be either a :class:`~discord.Role` or a :class:`~discord.Member` and the key is the + can be either a :class:`~discord.Role` or a :class:`~discord.Member` and the value is the overwrite as a :class:`~discord.PermissionOverwrite`. + .. versionchanged:: 2.0 + Overwrites can now be type-aware :class:`~discord.Object` in case of cache lookup failure + Returns -------- - Mapping[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`] + Dict[Union[:class:`~discord.Role`, :class:`~discord.Member`, :class:`~discord.Object`], :class:`~discord.PermissionOverwrite`] The channel's permission overwrites. """ ret = {} @@ -360,31 +679,60 @@ def overwrites(self): allow = Permissions(ow.allow) deny = Permissions(ow.deny) overwrite = PermissionOverwrite.from_pair(allow, deny) + target = None - if ow.type == 'role': + if ow.is_role(): target = self.guild.get_role(ow.id) - elif ow.type == 'member': + elif ow.is_member(): target = self.guild.get_member(ow.id) - # TODO: There is potential data loss here in the non-chunked - # case, i.e. target is None because get_member returned nothing. - # This can be fixed with a slight breaking change to the return type, - # i.e. adding discord.Object to the list of it - # However, for now this is an acceptable compromise. - if target is not None: - ret[target] = overwrite + if target is None: + target_type = Role if ow.is_role() else User + target = Object(id=ow.id, type=target_type) # type: ignore + + ret[target] = overwrite return ret @property - def category(self): + def category(self) -> Optional[CategoryChannel]: """Optional[:class:`~discord.CategoryChannel`]: The category this channel belongs to. If there is no category then this is ``None``. """ - return self.guild.get_channel(self.category_id) + return self.guild.get_channel(self.category_id) # type: ignore # These are coerced into CategoryChannel + + @property + def permissions_synced(self) -> bool: + """:class:`bool`: Whether or not the permissions for this channel are synced with the + category it belongs to. + + If there is no category then this is ``False``. + + .. versionadded:: 1.3 + """ + if self.category_id is None: + return False + + category = self.guild.get_channel(self.category_id) + return bool(category and category.overwrites == self.overwrites) + + def _apply_implicit_permissions(self, base: Permissions) -> None: + # if you can't send a message in a channel then you can't have certain + # permissions as well + if not base.send_messages: + base.send_tts_messages = False + base.mention_everyone = False + base.embed_links = False + base.attach_files = False - def permissions_for(self, member): - """Handles permission resolution for the current :class:`~discord.Member`. + # if you can't read a channel then you have no permissions there + if not base.read_messages: + denied = Permissions.all_channel() + base.value &= ~denied.value + + def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: + """Handles permission resolution for the :class:`~discord.Member` + or :class:`~discord.Role`. This function takes into consideration the following cases: @@ -392,16 +740,41 @@ def permissions_for(self, member): - Guild roles - Channel overrides - Member overrides + - Implicit permissions + - Member timeout + - User installed app + + If a :class:`~discord.Role` is passed, then it checks the permissions + someone with that role would have, which is essentially: + + - The default role permissions + - The permissions of the role used as a parameter + - The default role permission overwrites + - The permission overwrites of the role used as a parameter + + .. versionchanged:: 2.0 + The object passed in can now be a role object. + + .. versionchanged:: 2.0 + ``obj`` parameter is now positional-only. + + .. versionchanged:: 2.4 + User installed apps are now taken into account. + The permissions returned for a user installed app mirrors the + permissions Discord returns in :attr:`~discord.Interaction.app_permissions`, + though it is recommended to use that attribute instead. Parameters ---------- - member: :class:`~discord.Member` - The member to resolve permissions for. + obj: Union[:class:`~discord.Member`, :class:`~discord.Role`] + The object to resolve permissions for. This could be either + a member or a role. If it's a role then member overwrites + are not computed. Returns ------- :class:`~discord.Permissions` - The resolved permissions for the member. + The resolved permissions for the member or role. """ # The current cases can be explained as: @@ -418,17 +791,50 @@ def permissions_for(self, member): # The operation first takes into consideration the denied # and then the allowed. - o = self.guild.owner - if o is not None and member.id == o.id: + if self.guild.owner_id == obj.id: return Permissions.all() default = self.guild.default_role + if default is None: + if self._state.self_id == obj.id: + return Permissions._user_installed_permissions(in_guild=True) + else: + return Permissions.none() + base = Permissions(default.permissions.value) - roles = member.roles + + # Handle the role case first + if isinstance(obj, Role): + base.value |= obj._permissions + + if base.administrator: + return Permissions.all() + + # Apply @everyone allow/deny first since it's special + try: + maybe_everyone = self._overwrites[0] + if maybe_everyone.id == self.guild.id: + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) + except IndexError: + pass + + if obj.is_default(): + return base + + overwrite = utils.find(lambda ow: ow.type == _Overwrites.ROLE and ow.id == obj.id, self._overwrites) + if overwrite is not None: + base.handle_overwrite(overwrite.allow, overwrite.deny) + + return base + + roles = obj._roles + get_role = self.guild.get_role # Apply guild roles that the member has. - for role in roles: - base.value |= role.permissions.value + for role_id in roles: + role = get_role(role_id) + if role is not None: + base.value |= role._permissions # Guild-wide Administrator -> True for everything # Bypass all channel-specific overrides @@ -446,19 +852,12 @@ def permissions_for(self, member): except IndexError: remaining_overwrites = self._overwrites - # not sure if doing member._roles.get(...) is better than the - # set approach. While this is O(N) to re-create into a set for O(1) - # the direct approach would just be O(log n) for searching with no - # extra memory overhead. For now, I'll keep the set cast - # Note that the member.roles accessor up top also creates a - # temporary list - member_role_ids = {r.id for r in roles} denies = 0 allows = 0 # Apply channel specific role permission overwrites for overwrite in remaining_overwrites: - if overwrite.type == 'role' and overwrite.id in member_role_ids: + if overwrite.is_role() and roles.has(overwrite.id): denies |= overwrite.deny allows |= overwrite.allow @@ -466,31 +865,24 @@ def permissions_for(self, member): # Apply member specific permission overwrites for overwrite in remaining_overwrites: - if overwrite.type == 'member' and overwrite.id == member.id: + if overwrite.is_member() and overwrite.id == obj.id: base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny) break - # if you can't send a message in a channel then you can't have certain - # permissions as well - if not base.send_messages: - base.send_tts_messages = False - base.mention_everyone = False - base.embed_links = False - base.attach_files = False - - # if you can't read a channel then you have no permissions there - if not base.read_messages: - denied = Permissions.all_channel() - base.value &= ~denied.value + if obj.is_timed_out(): + # Timeout leads to every permission except VIEW_CHANNEL and READ_MESSAGE_HISTORY + # being explicitly denied + # N.B.: This *must* come last, because it's a conclusive mask + base.value &= Permissions._timeout_mask() return base - async def delete(self, *, reason=None): + async def delete(self, *, reason: Optional[str] = None) -> None: """|coro| Deletes the channel. - You must have :attr:`~.Permissions.manage_channels` permission to use this. + You must have :attr:`~discord.Permissions.manage_channels` to do this. Parameters ----------- @@ -500,16 +892,41 @@ async def delete(self, *, reason=None): Raises ------- - :exc:`~discord.Forbidden` + ~discord.Forbidden You do not have proper permissions to delete the channel. - :exc:`~discord.NotFound` + ~discord.NotFound The channel was not found or was already deleted. - :exc:`~discord.HTTPException` + ~discord.HTTPException Deleting the channel failed. """ await self._state.http.delete_channel(self.id, reason=reason) - async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): + @overload + async def set_permissions( + self, + target: Union[Member, Role], + *, + overwrite: Optional[Union[PermissionOverwrite, _Undefined]] = ..., + reason: Optional[str] = ..., + ) -> None: ... + + @overload + async def set_permissions( + self, + target: Union[Member, Role], + *, + reason: Optional[str] = ..., + **permissions: Unpack[_PermissionOverwriteKwargs], + ) -> None: ... + + async def set_permissions( + self, + target: Union[Member, Role], + *, + overwrite: Any = _undefined, + reason: Optional[str] = None, + **permissions: Unpack[_PermissionOverwriteKwargs], + ) -> None: r"""|coro| Sets the channel specific permission overwrites for a target in the @@ -527,7 +944,11 @@ async def set_permissions(self, target, *, overwrite=_undefined, reason=None, ** If the ``overwrite`` parameter is ``None``, then the permission overwrites are deleted. - You must have the :attr:`~.Permissions.manage_roles` permission to use this. + You must have :attr:`~discord.Permissions.manage_roles` to do this. + + .. note:: + + This method *replaces* the old overwrites with the ones given. Examples ---------- @@ -548,12 +969,18 @@ async def set_permissions(self, target, *, overwrite=_undefined, reason=None, ** overwrite.read_messages = True await channel.set_permissions(member, overwrite=overwrite) + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` instead of + ``InvalidArgument``. + + Parameters ----------- target: Union[:class:`~discord.Member`, :class:`~discord.Role`] The member or role to overwrite permissions for. - overwrite: :class:`~discord.PermissionOverwrite` - The permissions to allow and deny to the target. + overwrite: Optional[:class:`~discord.PermissionOverwrite`] + The permissions to allow and deny to the target, or ``None`` to + delete the overwrite. \*\*permissions A keyword argument list of permissions to set for ease of use. Cannot be mixed with ``overwrite``. @@ -562,103 +989,315 @@ async def set_permissions(self, target, *, overwrite=_undefined, reason=None, ** Raises ------- - :exc:`~discord.Forbidden` + ~discord.Forbidden You do not have permissions to edit channel specific permissions. - :exc:`~discord.HTTPException` + ~discord.HTTPException Editing channel specific permissions failed. - :exc:`~discord.NotFound` + ~discord.NotFound The role or member being edited is not part of the guild. - :exc:`~discord.InvalidArgument` - The overwrite parameter invalid or the target type was not + TypeError + The ``overwrite`` parameter was invalid or the target type was not :class:`~discord.Role` or :class:`~discord.Member`. + ValueError + The ``overwrite`` parameter and ``positions`` parameters were both + unset. """ http = self._state.http if isinstance(target, User): - perm_type = 'member' + perm_type = _Overwrites.MEMBER elif isinstance(target, Role): - perm_type = 'role' + perm_type = _Overwrites.ROLE else: - raise InvalidArgument('target parameter must be either Member or Role') + raise ValueError('target parameter must be either Member or Role') - if isinstance(overwrite, _Undefined): + if overwrite is _undefined: if len(permissions) == 0: - raise InvalidArgument('No overwrite provided.') + raise ValueError('No overwrite provided.') try: overwrite = PermissionOverwrite(**permissions) except (ValueError, TypeError): - raise InvalidArgument('Invalid permissions given to keyword arguments.') + raise TypeError('Invalid permissions given to keyword arguments.') else: if len(permissions) > 0: - raise InvalidArgument('Cannot mix overwrite and keyword arguments.') - - # TODO: wait for event + raise TypeError('Cannot mix overwrite and keyword arguments.') if overwrite is None: await http.delete_channel_permissions(self.id, target.id, reason=reason) elif isinstance(overwrite, PermissionOverwrite): (allow, deny) = overwrite.pair() - await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) + await http.edit_channel_permissions( + self.id, target.id, str(allow.value), str(deny.value), perm_type, reason=reason + ) else: - raise InvalidArgument('Invalid overwrite type provided.') - - async def _clone_impl(self, base_attrs, *, name=None, reason=None): - base_attrs['permission_overwrites'] = [ - x._asdict() for x in self._overwrites - ] + raise TypeError('Invalid overwrite type provided.') + + async def _clone_impl( + self, + base_attrs: Dict[str, Any], + *, + name: Optional[str] = None, + category: Optional[CategoryChannel] = None, + reason: Optional[str] = None, + ) -> Self: + base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites] base_attrs['parent_id'] = self.category_id base_attrs['name'] = name or self.name + if category is not None: + base_attrs['parent_id'] = category.id + guild_id = self.guild.id cls = self.__class__ data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) obj = cls(state=self._state, guild=self.guild, data=data) # temporarily add it to the cache - self.guild._channels[obj.id] = obj + self.guild._channels[obj.id] = obj # type: ignore # obj is a GuildChannel return obj - async def clone(self, *, name=None, reason=None): + async def clone( + self, + *, + name: Optional[str] = None, + category: Optional[CategoryChannel] = None, + reason: Optional[str] = None, + ) -> Self: """|coro| Clones this channel. This creates a channel with the same properties as this channel. - .. versionadded:: 1.1.0 + You must have :attr:`~discord.Permissions.manage_channels` to do this. + + .. versionadded:: 1.1 Parameters ------------ name: Optional[:class:`str`] The name of the new channel. If not provided, defaults to this channel name. + category: Optional[:class:`~discord.CategoryChannel`] + The category the new channel belongs to. + This parameter is ignored if cloning a category channel. + + .. versionadded:: 2.5 reason: Optional[:class:`str`] The reason for cloning this channel. Shows up on the audit log. Raises ------- - :exc:`~discord.Forbidden` + ~discord.Forbidden You do not have the proper permissions to create this channel. - :exc:`~discord.HTTPException` + ~discord.HTTPException Creating the channel failed. + + Returns + -------- + :class:`.abc.GuildChannel` + The channel that was created. """ raise NotImplementedError - async def create_invite(self, *, reason=None, **fields): + @overload + async def move( + self, + *, + beginning: bool, + offset: int = MISSING, + category: Optional[Snowflake] = MISSING, + sync_permissions: bool = MISSING, + reason: Optional[str] = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + end: bool, + offset: int = MISSING, + category: Optional[Snowflake] = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + before: Snowflake, + offset: int = MISSING, + category: Optional[Snowflake] = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + after: Snowflake, + offset: int = MISSING, + category: Optional[Snowflake] = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: ... + + async def move(self, **kwargs: Any) -> None: + """|coro| + + A rich interface to help move a channel relative to other channels. + + If exact position movement is required, ``edit`` should be used instead. + + You must have :attr:`~discord.Permissions.manage_channels` to do this. + + .. note:: + + Voice channels will always be sorted below text channels. + This is a Discord limitation. + + .. versionadded:: 1.7 + + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` or + :exc:`ValueError` instead of ``InvalidArgument``. + + Parameters + ------------ + beginning: :class:`bool` + Whether to move the channel to the beginning of the + channel list (or category if given). + This is mutually exclusive with ``end``, ``before``, and ``after``. + end: :class:`bool` + Whether to move the channel to the end of the + channel list (or category if given). + This is mutually exclusive with ``beginning``, ``before``, and ``after``. + before: :class:`~discord.abc.Snowflake` + Whether to move the channel before the given channel. + This is mutually exclusive with ``beginning``, ``end``, and ``after``. + after: :class:`~discord.abc.Snowflake` + Whether to move the channel after the given channel. + This is mutually exclusive with ``beginning``, ``end``, and ``before``. + offset: :class:`int` + The number of channels to offset the move by. For example, + an offset of ``2`` with ``beginning=True`` would move + it 2 after the beginning. A positive number moves it below + while a negative number moves it above. Note that this + number is relative and computed after the ``beginning``, + ``end``, ``before``, and ``after`` parameters. + category: Optional[:class:`~discord.abc.Snowflake`] + The category to move this channel under. + If ``None`` is given then it moves it out of the category. + This parameter is ignored if moving a category channel. + sync_permissions: :class:`bool` + Whether to sync the permissions with the category (if given). + reason: :class:`str` + The reason for the move. + + Raises + ------- + ValueError + An invalid position was given. + TypeError + A bad mix of arguments were passed. + Forbidden + You do not have permissions to move the channel. + HTTPException + Moving the channel failed. + """ + + if not kwargs: + return + + beginning, end = kwargs.get('beginning'), kwargs.get('end') + before, after = kwargs.get('before'), kwargs.get('after') + offset = kwargs.get('offset', 0) + if sum(bool(a) for a in (beginning, end, before, after)) > 1: + raise TypeError('Only one of [before, after, end, beginning] can be used.') + + bucket = self._sorting_bucket + parent_id = kwargs.get('category', MISSING) + # fmt: off + channels: List[GuildChannel] + if parent_id not in (MISSING, None): + parent_id = parent_id.id + channels = [ + ch + for ch in self.guild.channels + if ch._sorting_bucket == bucket + and ch.category_id == parent_id + ] + else: + channels = [ + ch + for ch in self.guild.channels + if ch._sorting_bucket == bucket + and ch.category_id == self.category_id + ] + # fmt: on + + channels.sort(key=lambda c: (c.position, c.id)) + + try: + # Try to remove ourselves from the channel list + channels.remove(self) + except ValueError: + # If we're not there then it's probably due to not being in the category + pass + + index = None + if beginning: + index = 0 + elif end: + index = len(channels) + elif before: + index = next((i for i, c in enumerate(channels) if c.id == before.id), None) + elif after: + index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) + + if index is None: + raise ValueError('Could not resolve appropriate move position') + + channels.insert(max((index + offset), 0), self) + payload: List[ChannelPositionUpdate] = [] + lock_permissions = kwargs.get('sync_permissions', False) + reason = kwargs.get('reason') + for index, channel in enumerate(channels): + d: ChannelPositionUpdate = {'id': channel.id, 'position': index} + if parent_id is not MISSING and channel.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) + + async def create_invite( + self, + *, + reason: Optional[str] = None, + max_age: int = 0, + max_uses: int = 0, + temporary: bool = False, + unique: bool = True, + target_type: Optional[InviteTarget] = None, + target_user: Optional[User] = None, + target_application_id: Optional[int] = None, + guest: bool = False, + ) -> Invite: """|coro| - Creates an instant invite. + Creates an instant invite from a text or voice channel. - You must have the :attr:`~.Permissions.create_instant_invite` permission to - do this. + You must have :attr:`~discord.Permissions.create_instant_invite` to do this. Parameters ------------ max_age: :class:`int` - How long the invite should last. If it's 0 then the invite - doesn't expire. Defaults to 0. + How long the invite should last in seconds. If it's 0 then the invite + doesn't expire. Defaults to ``0``. max_uses: :class:`int` How many uses the invite could be used for. If it's 0 then there - are unlimited uses. Defaults to 0. + are unlimited uses. Defaults to ``0``. temporary: :class:`bool` Denotes that the invite grants temporary membership (i.e. they get kicked after they disconnect). Defaults to ``False``. @@ -668,33 +1307,72 @@ async def create_invite(self, *, reason=None, **fields): invite. reason: Optional[:class:`str`] The reason for creating this invite. Shows up on the audit log. + target_type: Optional[:class:`.InviteTarget`] + The type of target for the voice channel invite, if any. + + .. versionadded:: 2.0 + + target_user: Optional[:class:`User`] + The user whose stream to display for this invite, required if ``target_type`` is :attr:`.InviteTarget.stream`. The user must be streaming in the channel. + + .. versionadded:: 2.0 + + target_application_id:: Optional[:class:`int`] + The id of the embedded application for the invite, required if ``target_type`` is :attr:`.InviteTarget.embedded_application`. + + .. versionadded:: 2.0 + guest: :class:`bool` + Whether the invite is a guest invite. + + .. versionadded:: 2.6 Raises ------- - :exc:`~discord.HTTPException` + ~discord.HTTPException Invite creation failed. + ~discord.NotFound + The channel that was passed is a category or an invalid channel. + Returns -------- :class:`~discord.Invite` The invite that was created. """ - - data = await self._state.http.create_invite(self.id, reason=reason, **fields) + if target_type is InviteTarget.unknown: + raise ValueError('Cannot create invite with an unknown target type') + + flags: Optional[InviteFlags] = None + if guest: + flags = InviteFlags._from_value(0) + flags.guest = True + + data = await self._state.http.create_invite( + self.id, + reason=reason, + max_age=max_age, + max_uses=max_uses, + temporary=temporary, + unique=unique, + target_type=target_type.value if target_type else None, + target_user_id=target_user.id if target_user else None, + target_application_id=target_application_id, + flags=flags.value if flags else None, + ) return Invite.from_incomplete(data=data, state=self._state) - async def invites(self): + async def invites(self) -> List[Invite]: """|coro| Returns a list of all active instant invites from this channel. - You must have :attr:`~.Permissions.manage_guild` to get this information. + You must have :attr:`~discord.Permissions.manage_channels` to get this information. Raises ------- - :exc:`~discord.Forbidden` + ~discord.Forbidden You do not have proper permissions to get the information. - :exc:`~discord.HTTPException` + ~discord.HTTPException An error occurred while fetching the information. Returns @@ -705,37 +1383,163 @@ async def invites(self): state = self._state data = await state.http.invites_from_channel(self.id) - result = [] + guild = self.guild + return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data] - for invite in data: - invite['channel'] = self - invite['guild'] = self.guild - result.append(Invite(state=state, data=invite)) - return result - -class Messageable(metaclass=abc.ABCMeta): +class Messageable: """An ABC that details the common operations on a model that can send messages. - The following implement this ABC: + The following classes implement this ABC: - :class:`~discord.TextChannel` + - :class:`~discord.VoiceChannel` + - :class:`~discord.StageChannel` - :class:`~discord.DMChannel` - :class:`~discord.GroupChannel` + - :class:`~discord.PartialMessageable` - :class:`~discord.User` - :class:`~discord.Member` - :class:`~discord.ext.commands.Context` - - This ABC must also implement :class:`~discord.abc.Snowflake`. + - :class:`~discord.Thread` """ __slots__ = () + _state: ConnectionState - @abc.abstractmethod - async def _get_channel(self): + async def _get_channel(self) -> MessageableChannel: raise NotImplementedError - async def send(self, content=None, *, tts=False, embed=None, file=None, files=None, delete_after=None, nonce=None): + @overload + async def send( + self, + *, + file: File = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: LayoutView, + suppress_embeds: bool = ..., + silent: bool = ..., + ) -> Message: ... + + @overload + async def send( + self, + *, + files: Sequence[File] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: LayoutView, + suppress_embeds: bool = ..., + silent: bool = ..., + ) -> Message: ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + async def send( + self, + content: Optional[str] = None, + *, + tts: bool = False, + embed: Optional[Embed] = None, + embeds: Optional[Sequence[Embed]] = None, + file: Optional[File] = None, + files: Optional[Sequence[File]] = None, + stickers: Optional[Sequence[Union[GuildSticker, StickerItem]]] = None, + delete_after: Optional[float] = None, + nonce: Optional[Union[str, int]] = None, + allowed_mentions: Optional[AllowedMentions] = None, + reference: Optional[Union[Message, MessageReference, PartialMessage]] = None, + mention_author: Optional[bool] = None, + view: Optional[BaseView] = None, + suppress_embeds: bool = False, + silent: bool = False, + poll: Optional[Poll] = None, + ) -> Message: """|coro| Sends a message to the destination with the content given. @@ -749,17 +1553,27 @@ async def send(self, content=None, *, tts=False, embed=None, file=None, files=No parameter should be used with a :class:`list` of :class:`~discord.File` objects. **Specifying both parameters will lead to an exception**. - If the ``embed`` parameter is provided, it must be of type :class:`~discord.Embed` and - it must be a rich embed type. + To upload a single embed, the ``embed`` parameter should be used with a + single :class:`~discord.Embed` object. To upload multiple embeds, the ``embeds`` + parameter should be used with a :class:`list` of :class:`~discord.Embed` objects. + **Specifying both parameters will lead to an exception**. + + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` or + :exc:`ValueError` instead of ``InvalidArgument``. Parameters ------------ - content: :class:`str` + content: Optional[:class:`str`] The content of the message to send. tts: :class:`bool` Indicates if the message should be sent using text-to-speech. embed: :class:`~discord.Embed` The rich embed for the content. + embeds: List[:class:`~discord.Embed`] + A list of embeds to upload. Must be a maximum of 10. + + .. versionadded:: 2.0 file: :class:`~discord.File` The file to upload. files: List[:class:`~discord.File`] @@ -771,16 +1585,67 @@ async def send(self, content=None, *, tts=False, embed=None, file=None, files=No If provided, the number of seconds to wait in the background before deleting the message we just sent. If the deletion fails, then it is silently ignored. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. + + .. versionadded:: 1.4 + + reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`, :class:`~discord.PartialMessage`] + A reference to the :class:`~discord.Message` to which you are referencing, this can be created using + :meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. + In the event of a replying reference, you can control whether this mentions the author of the referenced + message using the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions`` or by + setting ``mention_author``. + + .. versionadded:: 1.6 + + mention_author: Optional[:class:`bool`] + If set, overrides the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions``. + + .. versionadded:: 1.6 + view: Union[:class:`discord.ui.View`, :class:`discord.ui.LayoutView`] + A Discord UI View to add to the message. + + .. versionadded:: 2.0 + stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]] + A list of stickers to upload. Must be a maximum of 3. + + .. versionadded:: 2.0 + suppress_embeds: :class:`bool` + Whether to suppress embeds for the message. This sends the message without any embeds if set to ``True``. + + .. versionadded:: 2.0 + silent: :class:`bool` + Whether to suppress push and desktop notifications for the message. This will increment the mention counter + in the UI, but will not actually send a notification. + + .. versionadded:: 2.2 + poll: :class:`~discord.Poll` + The poll to send with this message. + + .. versionadded:: 2.4 Raises -------- - :exc:`~discord.HTTPException` + ~discord.HTTPException Sending the message failed. - :exc:`~discord.Forbidden` + ~discord.Forbidden You do not have the proper permissions to send the message. - :exc:`~discord.InvalidArgument` - The ``files`` list is not of the appropriate size or - you specified both ``file`` and ``files``. + ~discord.NotFound + You sent a message with the same nonce as one that has been explicitly + deleted shortly earlier. + ValueError + The ``files`` or ``embeds`` list is not of the appropriate size. + TypeError + You specified both ``file`` and ``files``, + or you specified both ``embed`` and ``embeds``, + or the ``reference`` object is not a :class:`~discord.Message`, + :class:`~discord.MessageReference` or :class:`~discord.PartialMessage`. Returns --------- @@ -791,79 +1656,98 @@ async def send(self, content=None, *, tts=False, embed=None, file=None, files=No channel = await self._get_channel() state = self._state content = str(content) if content is not None else None - if embed is not None: - embed = embed.to_dict() - - if file is not None and files is not None: - raise InvalidArgument('cannot pass both file and files parameter to send()') + previous_allowed_mention = state.allowed_mentions - if file is not None: - if not isinstance(file, File): - raise InvalidArgument('file parameter must be File') + if stickers is not None: + sticker_ids: SnowflakeList = [sticker.id for sticker in stickers] + else: + sticker_ids = MISSING + if reference is not None: try: - data = await state.http.send_files(channel.id, files=[file], - content=content, tts=tts, embed=embed, nonce=nonce) - finally: - file.close() + reference_dict = reference.to_message_reference_dict() + except AttributeError: + raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None + else: + reference_dict = MISSING - elif files is not None: - if len(files) > 10: - raise InvalidArgument('files parameter must be a list of up to 10 elements') - elif not all(isinstance(file, File) for file in files): - raise InvalidArgument('files parameter must be a list of File') + if view and not hasattr(view, '__discord_ui_view__'): + raise TypeError(f'view parameter must be View not {view.__class__.__name__}') - try: - data = await state.http.send_files(channel.id, files=files, content=content, tts=tts, - embed=embed, nonce=nonce) - finally: - for f in files: - f.close() + if suppress_embeds or silent: + from .message import MessageFlags # circular import + + flags = MessageFlags._from_value(0) + flags.suppress_embeds = suppress_embeds + flags.suppress_notifications = silent else: - data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, nonce=nonce) + flags = MISSING + + if nonce is None: + nonce = secrets.randbits(64) + + with handle_message_parameters( + content=content, + tts=tts, + file=file if file is not None else MISSING, + files=files if files is not None else MISSING, + embed=embed if embed is not None else MISSING, + embeds=embeds if embeds is not None else MISSING, + nonce=nonce, + allowed_mentions=allowed_mentions, + message_reference=reference_dict, + previous_allowed_mentions=previous_allowed_mention, + mention_author=mention_author, + stickers=sticker_ids, + view=view, + flags=flags, + poll=poll, + ) as params: + data = await state.http.send_message(channel.id, params=params) ret = state.create_message(channel=channel, data=data) + if view and not view.is_finished() and view.is_dispatchable(): + state.store_view(view, ret.id) + + if poll: + poll._update(ret) + if delete_after is not None: await ret.delete(delay=delete_after) return ret - async def trigger_typing(self): - """|coro| - - Triggers a *typing* indicator to the destination. + def typing(self) -> Typing: + """Returns an asynchronous context manager that allows you to send a typing indicator to + the destination for an indefinite period of time, or 10 seconds if the context manager + is called using ``await``. - *Typing* indicator will go away after 10 seconds, or after a message is sent. - """ - - channel = await self._get_channel() - await self._state.http.send_typing(channel.id) - - def typing(self): - """Returns a context manager that allows you to type for an indefinite period of time. - - This is useful for denoting long computations in your bot. + Example Usage: :: - .. note:: + async with channel.typing(): + # simulate something heavy + await asyncio.sleep(20) - This is both a regular context manager and an async context manager. - This means that both ``with`` and ``async with`` work with this. + await channel.send('Done!') Example Usage: :: - async with channel.typing(): - # do expensive stuff here - await channel.send('done!') + await channel.typing() + # Do some computational magic for about 10 seconds + await channel.send('Done!') + .. versionchanged:: 2.0 + This no longer works with the ``with`` syntax, ``async with`` must be used instead. + + .. versionchanged:: 2.0 + Added functionality to ``await`` the context manager to send a typing indicator for 10 seconds. """ return Typing(self) - async def fetch_message(self, id): + async def fetch_message(self, id: int, /) -> Message: """|coro| Retrieves a single :class:`~discord.Message` from the destination. - This can only be used by bot accounts. - Parameters ------------ id: :class:`int` @@ -871,11 +1755,11 @@ async def fetch_message(self, id): Raises -------- - :exc:`~discord.NotFound` + ~discord.NotFound The specified message was not found. - :exc:`~discord.Forbidden` + ~discord.Forbidden You do not have the permissions required to get a message. - :exc:`~discord.HTTPException` + ~discord.HTTPException Retrieving the message failed. Returns @@ -888,37 +1772,145 @@ async def fetch_message(self, id): data = await self._state.http.get_message(channel.id, id) return self._state.create_message(channel=channel, data=data) - async def pins(self): - """|coro| + async def __pins( + self, + *, + limit: Optional[int] = 50, + before: Optional[SnowflakeTime] = None, + oldest_first: bool = False, + ) -> AsyncIterator[PinnedMessage]: + channel = await self._get_channel() + state = self._state + max_limit: int = 50 + + time: Optional[str] = ( + (before if isinstance(before, datetime) else utils.snowflake_time(before.id)).isoformat() + if before is not None + else None + ) + + while True: + retrieve = max_limit if limit is None else min(limit, max_limit) + if retrieve < 1: + break + + data = await self._state.http.pins_from( + channel_id=channel.id, + limit=retrieve, + before=time, + ) + + items = data and data['items'] + if items: + if limit is not None: + limit -= len(items) + + time = items[-1]['pinned_at'] + + # Terminate loop on next iteration; there's no data left after this + if len(items) < max_limit or not data['has_more']: + limit = 0 + + if oldest_first: + items = reversed(items) + + count = 0 + for count, m in enumerate(items, start=1): + message: Message = state.create_message(channel=channel, data=m['message']) + message._pinned_at = utils.parse_time(m['pinned_at']) + yield message # pyright: ignore[reportReturnType] + + if count < max_limit: + break + + def pins( + self, + *, + limit: Optional[int] = 50, + before: Optional[SnowflakeTime] = None, + oldest_first: bool = False, + ) -> _PinsIterator: + """Retrieves an :term:`asynchronous iterator` of the pinned messages in the channel. + + You must have :attr:`~discord.Permissions.view_channel` and + :attr:`~discord.Permissions.read_message_history` in order to use this. + + .. versionchanged:: 2.6 + + Due to a change in Discord's API, this now returns a paginated iterator instead of a list. - Retrieves all messages that are currently pinned in the channel. + For backwards compatibility, you can still retrieve a list of pinned messages by + using ``await`` on the returned object. This is however deprecated. .. note:: Due to a limitation with the Discord API, the :class:`.Message` - objects returned by this method do not contain complete + object returned by this method does not contain complete :attr:`.Message.reactions` data. + Examples + --------- + + Usage :: + + counter = 0 + async for message in channel.pins(limit=250): + counter += 1 + + Flattening into a list: :: + + messages = [message async for message in channel.pins(limit=50)] + # messages is now a list of Message... + + All parameters are optional. + + Parameters + ----------- + limit: Optional[int] + The number of pinned messages to retrieve. If ``None``, it retrieves + every pinned message in the channel. Note, however, that this would + make it a slow operation. + Defaults to ``50``. + + .. versionadded:: 2.6 + before: Optional[Union[:class:`datetime.datetime`, :class:`.abc.Snowflake`]] + Retrieve pinned messages before this time or snowflake. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + + .. versionadded:: 2.6 + oldest_first: :class:`bool` + If set to ``True``, return messages in oldest pin->newest pin order. + Defaults to ``False``. + + .. versionadded:: 2.6 + Raises ------- - :exc:`~discord.HTTPException` + ~discord.Forbidden + You do not have the permission to retrieve pinned messages. + ~discord.HTTPException Retrieving the pinned messages failed. - Returns - -------- - List[:class:`~discord.Message`] - The messages that are currently pinned. + Yields + ------- + :class:`~discord.Message` + The pinned message with :attr:`.Message.pinned_at` set. """ + return _PinsIterator(self.__pins(limit=limit, before=before, oldest_first=oldest_first)) - channel = await self._get_channel() - state = self._state - data = await state.http.pins_from(channel.id) - return [state.create_message(channel=channel, data=m) for m in data] - - def history(self, *, limit=100, before=None, after=None, around=None, oldest_first=None): - """Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history. + async def history( + self, + *, + limit: Optional[int] = 100, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + around: Optional[SnowflakeTime] = None, + oldest_first: Optional[bool] = None, + ) -> AsyncIterator[Message]: + """Returns an :term:`asynchronous iterator` that enables receiving the destination's message history. - You must have :attr:`~.Permissions.read_message_history` permissions to use this. + You must have :attr:`~discord.Permissions.read_message_history` to do this. Examples --------- @@ -932,7 +1924,7 @@ def history(self, *, limit=100, before=None, after=None, around=None, oldest_fir Flattening into a list: :: - messages = await channel.history(limit=123).flatten() + messages = [message async for message in channel.history(limit=123)] # messages is now a list of Message... All parameters are optional. @@ -945,13 +1937,16 @@ def history(self, *, limit=100, before=None, after=None, around=None, oldest_fir that this would make it a slow operation. before: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] Retrieve messages before this date or message. - If a date is provided it must be a timezone-naive datetime representing UTC time. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. after: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] Retrieve messages after this date or message. - If a date is provided it must be a timezone-naive datetime representing UTC time. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. around: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] Retrieve messages around this date or message. - If a date is provided it must be a timezone-naive datetime representing UTC time. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. When using this argument, the maximum limit is 101. Note that if the limit is an even number then this will return at most limit + 1 messages. oldest_first: Optional[:class:`bool`] @@ -960,9 +1955,9 @@ def history(self, *, limit=100, before=None, after=None, around=None, oldest_fir Raises ------ - :exc:`~discord.Forbidden` + ~discord.Forbidden You do not have permissions to get channel message history. - :exc:`~discord.HTTPException` + ~discord.HTTPException The request to get message history failed. Yields @@ -970,73 +1965,197 @@ def history(self, *, limit=100, before=None, after=None, around=None, oldest_fir :class:`~discord.Message` The message with the message data parsed. """ - return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first) + async def _around_strategy(retrieve: int, around: Optional[Snowflake], limit: Optional[int]): + if not around: + return [], None, 0 + + around_id = around.id if around else None + data = await self._state.http.logs_from(channel.id, retrieve, around=around_id) + + return data, None, 0 -class Connectable(metaclass=abc.ABCMeta): + async def _after_strategy(retrieve: int, after: Optional[Snowflake], limit: Optional[int]): + after_id = after.id if after else None + data = await self._state.http.logs_from(channel.id, retrieve, after=after_id) + + if data: + if limit is not None: + limit -= len(data) + + after = Object(id=int(data[0]['id'])) + + return data, after, limit + + async def _before_strategy(retrieve: int, before: Optional[Snowflake], limit: Optional[int]): + before_id = before.id if before else None + data = await self._state.http.logs_from(channel.id, retrieve, before=before_id) + + if data: + if limit is not None: + limit -= len(data) + + before = Object(id=int(data[-1]['id'])) + + return data, before, limit + + if isinstance(before, datetime): + before = Object(id=utils.time_snowflake(before, high=False)) + if isinstance(after, datetime): + after = Object(id=utils.time_snowflake(after, high=True)) + if isinstance(around, datetime): + around = Object(id=utils.time_snowflake(around)) + + if oldest_first is None: + reverse = after is not None + else: + reverse = oldest_first + + after = after or OLDEST_OBJECT + predicate = None + + if around: + if limit is None: + raise ValueError('history does not support around with limit=None') + if limit > 101: + raise ValueError('history max limit 101 when specifying around parameter') + + # Strange Discord quirk + limit = 100 if limit == 101 else limit + + strategy, state = _around_strategy, around + + if before and after: + predicate = lambda m: after.id < int(m['id']) < before.id + elif before: + predicate = lambda m: int(m['id']) < before.id + elif after: + predicate = lambda m: after.id < int(m['id']) + elif reverse: + strategy, state = _after_strategy, after + if before: + predicate = lambda m: int(m['id']) < before.id + else: + strategy, state = _before_strategy, before + if after and after != OLDEST_OBJECT: + predicate = lambda m: int(m['id']) > after.id + + channel = await self._get_channel() + + while True: + retrieve = 100 if limit is None else min(limit, 100) + if retrieve < 1: + return + + data, state, limit = await strategy(retrieve, state, limit) + + if reverse: + data = reversed(data) + if predicate: + data = filter(predicate, data) + + count = 0 + + for count, raw_message in enumerate(data, 1): + yield self._state.create_message(channel=channel, data=raw_message) + + if count < 100: + # There's no data left after this + break + + +class Connectable(Protocol): """An ABC that details the common operations on a channel that can connect to a voice server. The following implement this ABC: - :class:`~discord.VoiceChannel` + - :class:`~discord.StageChannel` """ + __slots__ = () + _state: ConnectionState - @abc.abstractmethod - def _get_voice_client_key(self): + def _get_voice_client_key(self) -> Tuple[int, str]: raise NotImplementedError - @abc.abstractmethod - def _get_voice_state_pair(self): + def _get_voice_state_pair(self) -> Tuple[int, int]: raise NotImplementedError - async def connect(self, *, timeout=60.0, reconnect=True): + async def connect( + self, + *, + timeout: float = 30.0, + reconnect: bool = True, + cls: Callable[[Client, Connectable], T] = VoiceClient, + self_deaf: bool = False, + self_mute: bool = False, + ) -> T: """|coro| - Connects to voice and creates a :class:`VoiceClient` to establish + Connects to voice and creates a :class:`~discord.VoiceClient` to establish your connection to the voice server. + This requires :attr:`~discord.Intents.voice_states`. + Parameters ----------- timeout: :class:`float` - The timeout in seconds to wait for the voice endpoint. + The timeout in seconds to wait the connection to complete. reconnect: :class:`bool` Whether the bot should automatically attempt a reconnect if a part of the handshake fails or the gateway goes down. + cls: Type[:class:`~discord.VoiceProtocol`] + A type that subclasses :class:`~discord.VoiceProtocol` to connect with. + Defaults to :class:`~discord.VoiceClient`. + self_mute: :class:`bool` + Indicates if the client should be self-muted. + + .. versionadded:: 2.0 + self_deaf: :class:`bool` + Indicates if the client should be self-deafened. + + .. versionadded:: 2.0 Raises ------- - :exc:`asyncio.TimeoutError` + asyncio.TimeoutError Could not connect to the voice channel in time. - :exc:`~discord.ClientException` + ~discord.ClientException You are already connected to a voice channel. - :exc:`~discord.opus.OpusNotLoaded` + ~discord.opus.OpusNotLoaded The opus library has not been loaded. Returns -------- - :class:`~discord.VoiceClient` + :class:`~discord.VoiceProtocol` A voice client that is fully connected to the voice server. """ + key_id, _ = self._get_voice_client_key() state = self._state if state._get_voice_client(key_id): raise ClientException('Already connected to a voice channel.') - voice = VoiceClient(state=state, timeout=timeout, channel=self) + client = state._get_client() + voice: T = cls(client, self) + + if not isinstance(voice, VoiceProtocol): + raise TypeError('Type must meet VoiceProtocol abstract base class.') + state._add_voice_client(key_id, voice) try: - await voice.connect(reconnect=reconnect) + await voice.connect(timeout=timeout, reconnect=reconnect, self_deaf=self_deaf, self_mute=self_mute) except asyncio.TimeoutError: try: await voice.disconnect(force=True) except Exception: # we don't care if disconnect failed because connection failed pass - raise # re-raise + raise # re-raise return voice diff --git a/discord/activity.py b/discord/activity.py index 3acc046d3d30..d15da49a51a9 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,17 +22,24 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import datetime +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, overload -from .enums import ActivityType, try_enum +from .asset import Asset +from .enums import ActivityType, StatusDisplayType, try_enum from .colour import Colour +from .partial_emoji import PartialEmoji from .utils import _get_as_snowflake __all__ = ( + 'BaseActivity', 'Activity', 'Streaming', 'Game', 'Spotify', + 'CustomActivity', ) """If curious, this is the current schema for an activity. @@ -69,6 +74,7 @@ sync_id: str session_id: str flags: int +buttons: list[str (max: 32)] There are also activity flags which are mostly uninteresting for the library atm. @@ -82,10 +88,55 @@ } """ -class _ActivityTag: - __slots__ = () +if TYPE_CHECKING: + from .types.activity import ( + Activity as ActivityPayload, + ActivityTimestamps, + ActivityParty, + ActivityAssets, + ) + + from .state import ConnectionState + + +class BaseActivity: + """The base activity that all user-settable activities inherit from. + A user-settable activity is one that can be used in :meth:`Client.change_presence`. + + The following types currently count as user-settable: + + - :class:`Activity` + - :class:`Game` + - :class:`Streaming` + - :class:`CustomActivity` + + Note that although these types are considered user-settable by the library, + Discord typically ignores certain combinations of activity depending on + what is currently set. This behaviour may change in the future so there are + no guarantees on whether Discord will actually let you set these types. + + .. versionadded:: 1.3 + """ + + __slots__ = ('_created_at',) + + def __init__(self, **kwargs: Any) -> None: + self._created_at: Optional[float] = kwargs.pop('created_at', None) + + @property + def created_at(self) -> Optional[datetime.datetime]: + """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC. + + .. versionadded:: 1.3 + """ + if self._created_at is not None: + return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) + + def to_dict(self) -> ActivityPayload: + raise NotImplementedError + -class Activity(_ActivityTag): +class Activity(BaseActivity): """Represents an activity in Discord. This could be an activity such as streaming, playing, listening @@ -99,18 +150,22 @@ class Activity(_ActivityTag): Attributes ------------ - application_id: :class:`int` + application_id: Optional[:class:`int`] The application ID of the game. - name: :class:`str` + name: Optional[:class:`str`] The name of the activity. - url: :class:`str` + url: Optional[:class:`str`] A stream URL that the activity could be doing. type: :class:`ActivityType` The type of activity currently being done. - state: :class:`str` + state: Optional[:class:`str`] The user's current state. For example, "In Game". - details: :class:`str` + details: Optional[:class:`str`] The detail of the user's current activity. + platform: Optional[:class:`str`] + The user's current platform. + + .. versionadded:: 2.4 timestamps: :class:`dict` A dictionary of timestamps. It contains the following optional keys: @@ -125,47 +180,111 @@ class Activity(_ActivityTag): - ``large_image``: A string representing the ID for the large image asset. - ``large_text``: A string representing the text when hovering over the large image asset. + - ``large_url``: A string representing the URL of the large image asset. - ``small_image``: A string representing the ID for the small image asset. - ``small_text``: A string representing the text when hovering over the small image asset. + - ``small_url``: A string representing the URL of the small image asset. party: :class:`dict` A dictionary representing the activity party. It contains the following optional keys: - ``id``: A string representing the party ID. - ``size``: A list of up to two integer elements denoting (current_size, maximum_size). + buttons: List[:class:`str`] + A list of strings representing the labels of custom buttons shown in a rich presence. + + .. versionadded:: 2.0 + + emoji: Optional[:class:`PartialEmoji`] + The emoji that belongs to this activity. + details_url: Optional[:class:`str`] + A URL that is linked to when clicking on the details text of the activity. + + .. versionadded:: 2.6 + state_url: Optional[:class:`str`] + A URL that is linked to when clicking on the state text of the activity. + + .. versionadded:: 2.6 + status_display_type: Optional[:class:`StatusDisplayType`] + Determines which field from the user's status text is displayed + in the members list. + + .. versionadded:: 2.6 """ - __slots__ = ('state', 'details', 'timestamps', 'assets', 'party', - 'flags', 'sync_id', 'session_id', 'type', 'name', 'url', 'application_id') - - def __init__(self, **kwargs): - self.state = kwargs.pop('state', None) - self.details = kwargs.pop('details', None) - self.timestamps = kwargs.pop('timestamps', {}) - self.assets = kwargs.pop('assets', {}) - self.party = kwargs.pop('party', {}) - self.application_id = _get_as_snowflake(kwargs, 'application_id') - self.name = kwargs.pop('name', None) - self.url = kwargs.pop('url', None) - self.flags = kwargs.pop('flags', 0) - self.sync_id = kwargs.pop('sync_id', None) - self.session_id = kwargs.pop('session_id', None) - self.type = try_enum(ActivityType, kwargs.pop('type', -1)) - - def __repr__(self): + __slots__ = ( + 'state', + 'details', + 'timestamps', + 'platform', + 'assets', + 'party', + 'flags', + 'sync_id', + 'session_id', + 'type', + 'name', + 'url', + 'application_id', + 'emoji', + 'buttons', + 'state_url', + 'details_url', + 'status_display_type', + ) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.state: Optional[str] = kwargs.pop('state', None) + self.details: Optional[str] = kwargs.pop('details', None) + self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {}) + self.platform: Optional[str] = kwargs.pop('platform', None) + self.assets: ActivityAssets = kwargs.pop('assets', {}) + self.party: ActivityParty = kwargs.pop('party', {}) + self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id') + self.name: Optional[str] = kwargs.pop('name', None) + self.url: Optional[str] = kwargs.pop('url', None) + self.flags: int = kwargs.pop('flags', 0) + self.sync_id: Optional[str] = kwargs.pop('sync_id', None) + self.session_id: Optional[str] = kwargs.pop('session_id', None) + self.buttons: List[str] = kwargs.pop('buttons', []) + + activity_type = kwargs.pop('type', -1) + self.type: ActivityType = ( + activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type) + ) + + emoji = kwargs.pop('emoji', None) + self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None + + self.state_url: Optional[str] = kwargs.pop('state_url', None) + self.details_url: Optional[str] = kwargs.pop('details_url', None) + + status_display_type = kwargs.pop('status_display_type', None) + self.status_display_type: Optional[StatusDisplayType] = ( + status_display_type + if isinstance(status_display_type, StatusDisplayType) + else try_enum(StatusDisplayType, status_display_type) + if status_display_type is not None + else None + ) + + def __repr__(self) -> str: attrs = ( - 'type', - 'name', - 'url', - 'details', - 'application_id', - 'session_id', + ('type', self.type), + ('name', self.name), + ('url', self.url), + ('platform', self.platform), + ('details', self.details), + ('application_id', self.application_id), + ('session_id', self.session_id), + ('emoji', self.emoji), ) - mapped = ' '.join('%s=%r' % (attr, getattr(self, attr)) for attr in attrs) - return '' % mapped + inner = ' '.join('%s=%r' % t for t in attrs) + return f'' - def to_dict(self): - ret = {} + def to_dict(self) -> Dict[str, Any]: + ret: Dict[str, Any] = {} for attr in self.__slots__: value = getattr(self, attr, None) if value is None: @@ -176,61 +295,70 @@ def to_dict(self): ret[attr] = value ret['type'] = int(self.type) + if self.emoji: + ret['emoji'] = self.emoji.to_dict() + if self.status_display_type: + ret['status_display_type'] = int(self.status_display_type.value) return ret @property - def start(self): + def start(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable.""" try: - return datetime.datetime.utcfromtimestamp(self.timestamps['start'] / 1000) + timestamp = self.timestamps['start'] / 1000 # pyright: ignore[reportTypedDictNotRequiredAccess] except KeyError: return None + else: + return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) @property - def end(self): + def end(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable.""" try: - return datetime.datetime.utcfromtimestamp(self.timestamps['end'] / 1000) + timestamp = self.timestamps['end'] / 1000 # pyright: ignore[reportTypedDictNotRequiredAccess] except KeyError: return None + else: + return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) @property - def large_image_url(self): - """Optional[:class:`str`]: Returns a URL pointing to the large image asset of this activity if applicable.""" - if self.application_id is None: - return None - + def large_image_url(self) -> Optional[str]: + """Optional[:class:`str`]: Returns a URL pointing to the large image asset of this activity, if applicable.""" try: - large_image = self.assets['large_image'] + large_image = self.assets['large_image'] # pyright: ignore[reportTypedDictNotRequiredAccess] except KeyError: return None else: - return 'https://cdn.discordapp.com/app-assets/{0}/{1}.png'.format(self.application_id, large_image) + return self._image_url(large_image) @property - def small_image_url(self): - """Optional[:class:`str`]: Returns a URL pointing to the small image asset of this activity if applicable.""" - if self.application_id is None: - return None - + def small_image_url(self) -> Optional[str]: + """Optional[:class:`str`]: Returns a URL pointing to the small image asset of this activity, if applicable.""" try: - small_image = self.assets['small_image'] + small_image = self.assets['small_image'] # pyright: ignore[reportTypedDictNotRequiredAccess] except KeyError: return None else: - return 'https://cdn.discordapp.com/app-assets/{0}/{1}.png'.format(self.application_id, small_image) + return self._image_url(small_image) + + def _image_url(self, image: str) -> Optional[str]: + if image.startswith('mp:'): + return f'https://media.discordapp.net/{image[3:]}' + elif self.application_id is not None: + return Asset.BASE + f'/app-assets/{self.application_id}/{image}.png' + @property - def large_image_text(self): - """Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable.""" + def large_image_text(self) -> Optional[str]: + """Optional[:class:`str`]: Returns the large image asset hover text of this activity, if applicable.""" return self.assets.get('large_text', None) @property - def small_image_text(self): - """Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable.""" + def small_image_text(self) -> Optional[str]: + """Optional[:class:`str`]: Returns the small image asset hover text of this activity, if applicable.""" return self.assets.get('small_text', None) -class Game(_ActivityTag): +class Game(BaseActivity): """A slimmed down version of :class:`Activity` that represents a Discord game. This is typically displayed via **Playing** on the official Discord client. @@ -257,69 +385,75 @@ class Game(_ActivityTag): ----------- name: :class:`str` The game's name. - start: Optional[:class:`datetime.datetime`] - A naive UTC timestamp representing when the game started. Keyword-only parameter. Ignored for bots. - end: Optional[:class:`datetime.datetime`] - A naive UTC timestamp representing when the game ends. Keyword-only parameter. Ignored for bots. Attributes ----------- name: :class:`str` The game's name. + platform: Optional[:class:`str`] + Where the user is playing from (ie. PS5, Xbox). + + .. versionadded:: 2.4 + + assets: :class:`dict` + A dictionary representing the images and their hover text of a game. + It contains the following optional keys: + + - ``large_image``: A string representing the ID for the large image asset. + - ``large_text``: A string representing the text when hovering over the large image asset. + - ``small_image``: A string representing the ID for the small image asset. + - ``small_text``: A string representing the text when hovering over the small image asset. + + .. versionadded:: 2.4 """ - __slots__ = ('name', '_end', '_start') + __slots__ = ('name', '_end', '_start', 'platform', 'assets') - def __init__(self, name, **extra): - self.name = name + def __init__(self, name: str, **extra: Any) -> None: + super().__init__(**extra) + self.name: str = name + self.platform: Optional[str] = extra.get('platform') + self.assets: ActivityAssets = extra.get('assets', {}) or {} try: - timestamps = extra['timestamps'] + timestamps: ActivityTimestamps = extra['timestamps'] except KeyError: - self._extract_timestamp(extra, 'start') - self._extract_timestamp(extra, 'end') + self._start = 0 + self._end = 0 else: self._start = timestamps.get('start', 0) self._end = timestamps.get('end', 0) - def _extract_timestamp(self, data, key): - try: - dt = data[key] - except KeyError: - setattr(self, '_' + key, 0) - else: - setattr(self, '_' + key, dt.timestamp() * 1000.0) - @property - def type(self): - """Returns the game's type. This is for compatibility with :class:`Activity`. + def type(self) -> ActivityType: + """:class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`. It always returns :attr:`ActivityType.playing`. """ return ActivityType.playing @property - def start(self): + def start(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable.""" if self._start: - return datetime.datetime.utcfromtimestamp(self._start / 1000) + return datetime.datetime.fromtimestamp(self._start / 1000, tz=datetime.timezone.utc) return None @property - def end(self): + def end(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable.""" if self._end: - return datetime.datetime.utcfromtimestamp(self._end / 1000) + return datetime.datetime.fromtimestamp(self._end / 1000, tz=datetime.timezone.utc) return None - def __str__(self): + def __str__(self) -> str: return str(self.name) - def __repr__(self): - return ''.format(self) + def __repr__(self) -> str: + return f'' - def to_dict(self): - timestamps = {} + def to_dict(self) -> Dict[str, Any]: + timestamps: Dict[str, Any] = {} if self._start: timestamps['start'] = self._start @@ -329,19 +463,22 @@ def to_dict(self): return { 'type': ActivityType.playing.value, 'name': str(self.name), - 'timestamps': timestamps + 'timestamps': timestamps, + 'platform': str(self.platform) if self.platform else None, + 'assets': self.assets, } - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, Game) and other.name == self.name - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) -class Streaming(_ActivityTag): + +class Streaming(BaseActivity): """A slimmed down version of :class:`Activity` that represents a Discord streaming status. This is typically displayed via **Streaming** on the official Discord client. @@ -366,41 +503,53 @@ class Streaming(_ActivityTag): Attributes ----------- - name: :class:`str` + platform: Optional[:class:`str`] + Where the user is streaming from (ie. YouTube, Twitch). + + .. versionadded:: 1.3 + + name: Optional[:class:`str`] The stream's name. - url: :class:`str` - The stream's URL. Currently only twitch.tv URLs are supported. Anything else is silently - discarded. details: Optional[:class:`str`] - If provided, typically the game the streamer is playing. + An alias for :attr:`name` + game: Optional[:class:`str`] + The game being streamed. + + .. versionadded:: 1.3 + + url: :class:`str` + The stream's URL. assets: :class:`dict` A dictionary comprising of similar keys than those in :attr:`Activity.assets`. """ - __slots__ = ('name', 'url', 'details', 'assets') + __slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets') - def __init__(self, *, name, url, **extra): - self.name = name - self.url = url - self.details = extra.pop('details', None) - self.assets = extra.pop('assets', {}) + def __init__(self, *, name: Optional[str], url: str, **extra: Any) -> None: + super().__init__(**extra) + self.platform: Optional[str] = name + self.name: Optional[str] = extra.pop('details', name) + self.game: Optional[str] = extra.pop('state', None) + self.url: str = url + self.details: Optional[str] = extra.pop('details', self.name) # compatibility + self.assets: ActivityAssets = extra.pop('assets', {}) @property - def type(self): - """Returns the game's type. This is for compatibility with :class:`Activity`. + def type(self) -> ActivityType: + """:class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`. It always returns :attr:`ActivityType.streaming`. """ return ActivityType.streaming - def __str__(self): + def __str__(self) -> str: return str(self.name) - def __repr__(self): - return ''.format(self) + def __repr__(self) -> str: + return f'' @property - def twitch_name(self): + def twitch_name(self) -> Optional[str]: """Optional[:class:`str`]: If provided, the twitch name of the user streaming. This corresponds to the ``large_image`` key of the :attr:`Streaming.assets` @@ -408,32 +557,33 @@ def twitch_name(self): """ try: - name = self.assets['large_image'] + name = self.assets['large_image'] # pyright: ignore[reportTypedDictNotRequiredAccess] except KeyError: return None else: return name[7:] if name[:7] == 'twitch:' else None - def to_dict(self): - ret = { + def to_dict(self) -> Dict[str, Any]: + ret: Dict[str, Any] = { 'type': ActivityType.streaming.value, 'name': str(self.name), 'url': str(self.url), - 'assets': self.assets + 'assets': self.assets, } if self.details: ret['details'] = self.details return ret - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, Streaming) and other.name == self.name and other.url == self.url - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) + class Spotify: """Represents a Spotify listening activity from Discord. This is a special case of :class:`Activity` that makes it easier to work with the Spotify integration. @@ -457,42 +607,52 @@ class Spotify: Returns the string 'Spotify'. """ - __slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id') + __slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at') - def __init__(self, **data): - self._state = data.pop('state', None) - self._details = data.pop('details', None) - self._timestamps = data.pop('timestamps', {}) - self._assets = data.pop('assets', {}) - self._party = data.pop('party', {}) - self._sync_id = data.pop('sync_id') - self._session_id = data.pop('session_id') + def __init__(self, **data: Any) -> None: + self._state: str = data.pop('state', '') + self._details: str = data.pop('details', '') + self._timestamps: ActivityTimestamps = data.pop('timestamps', {}) + self._assets: ActivityAssets = data.pop('assets', {}) + self._party: ActivityParty = data.pop('party', {}) + self._sync_id: str = data.pop('sync_id', '') + self._session_id: Optional[str] = data.pop('session_id') + self._created_at: Optional[float] = data.pop('created_at', None) @property - def type(self): - """Returns the activity's type. This is for compatibility with :class:`Activity`. + def type(self) -> ActivityType: + """:class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`. It always returns :attr:`ActivityType.listening`. """ return ActivityType.listening @property - def colour(self): - """Returns the Spotify integration colour, as a :class:`Colour`. + def created_at(self) -> Optional[datetime.datetime]: + """Optional[:class:`datetime.datetime`]: When the user started listening in UTC. + + .. versionadded:: 1.3 + """ + if self._created_at is not None: + return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) + + @property + def colour(self) -> Colour: + """:class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`. - There is an alias for this named :meth:`color`""" - return Colour(0x1db954) + There is an alias for this named :attr:`color`""" + return Colour(0x1DB954) @property - def color(self): - """Returns the Spotify integration colour, as a :class:`Colour`. + def color(self) -> Colour: + """:class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`. - There is an alias for this named :meth:`colour`""" + There is an alias for this named :attr:`colour`""" return self.colour - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { - 'flags': 48, # SYNC | PLAY + 'flags': 48, # SYNC | PLAY 'name': 'Spotify', 'assets': self._assets, 'party': self._party, @@ -500,42 +660,46 @@ def to_dict(self): 'session_id': self._session_id, 'timestamps': self._timestamps, 'details': self._details, - 'state': self._state + 'state': self._state, } @property - def name(self): + def name(self) -> str: """:class:`str`: The activity's name. This will always return "Spotify".""" return 'Spotify' - def __eq__(self, other): - return (isinstance(other, Spotify) and other._session_id == self._session_id - and other._sync_id == self._sync_id and other.start == self.start) + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, Spotify) + and other._session_id == self._session_id + and other._sync_id == self._sync_id + and other.start == self.start + ) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self._session_id) - def __str__(self): + def __str__(self) -> str: return 'Spotify' - def __repr__(self): - return ''.format(self) + def __repr__(self) -> str: + return f'' @property - def title(self): + def title(self) -> str: """:class:`str`: The title of the song being played.""" return self._details @property - def artists(self): + def artists(self) -> List[str]: """List[:class:`str`]: The artists of the song being played.""" return self._state.split('; ') @property - def artist(self): + def artist(self) -> str: """:class:`str`: The artist of the song being played. This does not attempt to split the artist information into @@ -544,12 +708,12 @@ def artist(self): return self._state @property - def album(self): + def album(self) -> str: """:class:`str`: The album that the song being played belongs to.""" return self._assets.get('large_text', '') @property - def album_cover_url(self): + def album_cover_url(self) -> str: """:class:`str`: The album cover image URL from Spotify's CDN.""" large_image = self._assets.get('large_image', '') if large_image[:8] != 'spotify:': @@ -558,31 +722,153 @@ def album_cover_url(self): return 'https://i.scdn.co/image/' + album_image_id @property - def track_id(self): + def track_id(self) -> str: """:class:`str`: The track ID used by Spotify to identify this song.""" return self._sync_id @property - def start(self): + def track_url(self) -> str: + """:class:`str`: The track URL to listen on Spotify. + + .. versionadded:: 2.0 + """ + return f'https://open.spotify.com/track/{self.track_id}' + + @property + def start(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user started playing this song in UTC.""" - return datetime.datetime.utcfromtimestamp(self._timestamps['start'] / 1000) + # the start key will be present here + return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # type: ignore @property - def end(self): + def end(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user will stop playing this song in UTC.""" - return datetime.datetime.utcfromtimestamp(self._timestamps['end'] / 1000) + # the end key will be present here + return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # type: ignore @property - def duration(self): + def duration(self) -> datetime.timedelta: """:class:`datetime.timedelta`: The duration of the song being played.""" return self.end - self.start @property - def party_id(self): + def party_id(self) -> str: """:class:`str`: The party ID of the listening party.""" return self._party.get('id', '') -def create_activity(data): + +class CustomActivity(BaseActivity): + """Represents a custom activity from Discord. + + .. container:: operations + + .. describe:: x == y + + Checks if two activities are equal. + + .. describe:: x != y + + Checks if two activities are not equal. + + .. describe:: hash(x) + + Returns the activity's hash. + + .. describe:: str(x) + + Returns the custom status text. + + .. versionadded:: 1.3 + + Attributes + ----------- + name: Optional[:class:`str`] + The custom activity's name. + emoji: Optional[:class:`PartialEmoji`] + The emoji to pass to the activity, if any. + """ + + __slots__ = ('name', 'emoji', 'state') + + def __init__( + self, name: Optional[str], *, emoji: Optional[Union[PartialEmoji, Dict[str, Any], str]] = None, **extra: Any + ) -> None: + super().__init__(**extra) + self.name: Optional[str] = name + self.state: Optional[str] = extra.pop('state', name) + if self.name == 'Custom Status': + self.name = self.state + + self.emoji: Optional[PartialEmoji] + if emoji is None: + self.emoji = emoji + elif isinstance(emoji, dict): + self.emoji = PartialEmoji.from_dict(emoji) + elif isinstance(emoji, str): + self.emoji = PartialEmoji(name=emoji) + elif isinstance(emoji, PartialEmoji): + self.emoji = emoji + else: + raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.') + + @property + def type(self) -> ActivityType: + """:class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`. + + It always returns :attr:`ActivityType.custom`. + """ + return ActivityType.custom + + def to_dict(self) -> Dict[str, Any]: + if self.name == self.state: + o = { + 'type': ActivityType.custom.value, + 'state': self.name, + 'name': 'Custom Status', + } + else: + o = { + 'type': ActivityType.custom.value, + 'name': self.name, + } + + if self.emoji: + o['emoji'] = self.emoji.to_dict() + return o + + def __eq__(self, other: object) -> bool: + return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash((self.name, str(self.emoji))) + + def __str__(self) -> str: + if self.emoji: + if self.name: + return f'{self.emoji} {self.name}' + return str(self.emoji) + else: + return str(self.name) + + def __repr__(self) -> str: + return f'' + + +ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] + + +@overload +def create_activity(data: ActivityPayload, state: ConnectionState) -> ActivityTypes: ... + + +@overload +def create_activity(data: None, state: ConnectionState) -> None: ... + + +def create_activity(data: Optional[ActivityPayload], state: ConnectionState) -> Optional[ActivityTypes]: if not data: return None @@ -591,10 +877,24 @@ def create_activity(data): if 'application_id' in data or 'session_id' in data: return Activity(**data) return Game(**data) + elif game_type is ActivityType.custom: + try: + name = data.pop('name') # type: ignore + except KeyError: + ret = Activity(**data) + else: + # we removed the name key from data already + ret = CustomActivity(name=name, **data) # type: ignore elif game_type is ActivityType.streaming: if 'url' in data: - return Streaming(**data) + # the url won't be None here + return Streaming(**data) # type: ignore return Activity(**data) elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data: return Spotify(**data) - return Activity(**data) + else: + ret = Activity(**data) + + if isinstance(ret.emoji, PartialEmoji): + ret.emoji._state = state + return ret diff --git a/discord/app_commands/__init__.py b/discord/app_commands/__init__.py new file mode 100644 index 000000000000..a338cab75dc5 --- /dev/null +++ b/discord/app_commands/__init__.py @@ -0,0 +1,21 @@ +""" +discord.app_commands +~~~~~~~~~~~~~~~~~~~~~ + +Application commands support for the Discord API + +:copyright: (c) 2015-present Rapptz +:license: MIT, see LICENSE for more details. + +""" + +from .commands import * +from .errors import * +from .models import * +from .tree import * +from .namespace import * +from .transformers import * +from .translator import * +from .installs import * +from . import checks as checks +from .checks import Cooldown as Cooldown diff --git a/discord/app_commands/checks.py b/discord/app_commands/checks.py new file mode 100644 index 000000000000..0ee65dea6aa0 --- /dev/null +++ b/discord/app_commands/checks.py @@ -0,0 +1,537 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import ( + Any, + Coroutine, + Dict, + Hashable, + Union, + Callable, + TypeVar, + Optional, + TYPE_CHECKING, +) + +import time + +from .commands import check +from .errors import ( + NoPrivateMessage, + MissingRole, + MissingAnyRole, + MissingPermissions, + BotMissingPermissions, + CommandOnCooldown, +) + +from ..user import User +from ..permissions import Permissions +from ..utils import get as utils_get, MISSING, maybe_coroutine + +T = TypeVar('T') + +if TYPE_CHECKING: + from typing_extensions import Self, Unpack + from ..interactions import Interaction + from ..permissions import _PermissionsKwargs + + CooldownFunction = Union[ + Callable[[Interaction[Any]], Coroutine[Any, Any, T]], + Callable[[Interaction[Any]], T], + ] + +__all__ = ( + 'has_role', + 'has_any_role', + 'has_permissions', + 'bot_has_permissions', + 'cooldown', + 'dynamic_cooldown', +) + + +class Cooldown: + """Represents a cooldown for a command. + + .. versionadded:: 2.0 + + Attributes + ----------- + rate: :class:`float` + The total number of tokens available per :attr:`per` seconds. + per: :class:`float` + The length of the cooldown period in seconds. + """ + + __slots__ = ('rate', 'per', '_window', '_tokens', '_last') + + def __init__(self, rate: float, per: float) -> None: + self.rate: int = int(rate) + self.per: float = float(per) + self._window: float = 0.0 + self._tokens: int = self.rate + self._last: float = 0.0 + + def get_tokens(self, current: Optional[float] = None) -> int: + """Returns the number of available tokens before rate limiting is applied. + + Parameters + ------------ + current: Optional[:class:`float`] + The time in seconds since Unix epoch to calculate tokens at. + If not supplied then :func:`time.time()` is used. + + Returns + -------- + :class:`int` + The number of tokens available before the cooldown is to be applied. + """ + if not current: + current = time.time() + + # the calculated tokens should be non-negative + tokens = max(self._tokens, 0) + + if current > self._window + self.per: + tokens = self.rate + return tokens + + def get_retry_after(self, current: Optional[float] = None) -> float: + """Returns the time in seconds until the cooldown will be reset. + + Parameters + ------------- + current: Optional[:class:`float`] + The current time in seconds since Unix epoch. + If not supplied, then :func:`time.time()` is used. + + Returns + ------- + :class:`float` + The number of seconds to wait before this cooldown will be reset. + """ + current = current or time.time() + tokens = self.get_tokens(current) + + if tokens == 0: + return self.per - (current - self._window) + + return 0.0 + + def update_rate_limit(self, current: Optional[float] = None, *, tokens: int = 1) -> Optional[float]: + """Updates the cooldown rate limit. + + Parameters + ------------- + current: Optional[:class:`float`] + The time in seconds since Unix epoch to update the rate limit at. + If not supplied, then :func:`time.time()` is used. + tokens: :class:`int` + The amount of tokens to deduct from the rate limit. + + Returns + ------- + Optional[:class:`float`] + The retry-after time in seconds if rate limited. + """ + current = current or time.time() + self._last = current + + self._tokens = self.get_tokens(current) + + # first token used means that we start a new rate limit window + if self._tokens == self.rate: + self._window = current + + # decrement tokens by specified number + self._tokens -= tokens + + # check if we are rate limited and return retry-after + if self._tokens < 0: + return self.per - (current - self._window) + + def reset(self) -> None: + """Reset the cooldown to its initial state.""" + self._tokens = self.rate + self._last = 0.0 + + def copy(self) -> Self: + """Creates a copy of this cooldown. + + Returns + -------- + :class:`Cooldown` + A new instance of this cooldown. + """ + return self.__class__(self.rate, self.per) + + def __repr__(self) -> str: + return f'' + + +def has_role(item: Union[int, str], /) -> Callable[[T], T]: + """A :func:`~discord.app_commands.check` that is added that checks if the member invoking the + command has the role specified via the name or ID specified. + + If a string is specified, you must give the exact name of the role, including + caps and spelling. + + If an integer is specified, you must give the exact snowflake ID of the role. + + This check raises one of two special exceptions, :exc:`~discord.app_commands.MissingRole` + if the user is missing a role, or :exc:`~discord.app_commands.NoPrivateMessage` if + it is used in a private message. Both inherit from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + + .. note:: + + This is different from the permission system that Discord provides for application + commands. This is done entirely locally in the program rather than being handled + by Discord. + + Parameters + ----------- + item: Union[:class:`int`, :class:`str`] + The name or ID of the role to check. + """ + + def predicate(interaction: Interaction) -> bool: + if isinstance(interaction.user, User): + raise NoPrivateMessage() + + if isinstance(item, int): + role = interaction.user.get_role(item) + else: + role = utils_get(interaction.user.roles, name=item) + + if role is None: + raise MissingRole(item) + return True + + return check(predicate) + + +def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: + r"""A :func:`~discord.app_commands.check` that is added that checks if the member + invoking the command has **any** of the roles specified. This means that if they have + one out of the three roles specified, then this check will return ``True``. + + Similar to :func:`has_role`\, the names or IDs passed in must be exact. + + This check raises one of two special exceptions, :exc:`~discord.app_commands.MissingAnyRole` + if the user is missing all roles, or :exc:`~discord.app_commands.NoPrivateMessage` if + it is used in a private message. Both inherit from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + + .. note:: + + This is different from the permission system that Discord provides for application + commands. This is done entirely locally in the program rather than being handled + by Discord. + + Parameters + ----------- + items: List[Union[:class:`str`, :class:`int`]] + An argument list of names or IDs to check that the member has roles wise. + + Example + -------- + + .. code-block:: python3 + + @tree.command() + @app_commands.checks.has_any_role('Library Devs', 'Moderators', 492212595072434186) + async def cool(interaction: discord.Interaction): + await interaction.response.send_message('You are cool indeed') + """ + + def predicate(interaction: Interaction) -> bool: + if isinstance(interaction.user, User): + raise NoPrivateMessage() + + if any( + interaction.user.get_role(item) is not None + if isinstance(item, int) + else utils_get(interaction.user.roles, name=item) is not None + for item in items + ): + return True + raise MissingAnyRole(list(items)) + + return check(predicate) + + +def has_permissions(**perms: Unpack[_PermissionsKwargs]) -> Callable[[T], T]: + r"""A :func:`~discord.app_commands.check` that is added that checks if the member + has all of the permissions necessary. + + Note that this check operates on the permissions given by + :attr:`discord.Interaction.permissions`. + + The permissions passed in must be exactly like the properties shown under + :class:`discord.Permissions`. + + This check raises a special exception, :exc:`~discord.app_commands.MissingPermissions` + that is inherited from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + + .. note:: + + This is different from the permission system that Discord provides for application + commands. This is done entirely locally in the program rather than being handled + by Discord. + + Parameters + ------------ + \*\*perms: :class:`bool` + Keyword arguments denoting the permissions to check for. + + Example + --------- + + .. code-block:: python3 + + @tree.command() + @app_commands.checks.has_permissions(manage_messages=True) + async def test(interaction: discord.Interaction): + await interaction.response.send_message('You can manage messages.') + + """ + + invalid = perms.keys() - Permissions.VALID_FLAGS.keys() + if invalid: + raise TypeError(f'Invalid permission(s): {", ".join(invalid)}') + + def predicate(interaction: Interaction) -> bool: + permissions = interaction.permissions + + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + + if not missing: + return True + + raise MissingPermissions(missing) + + return check(predicate) + + +def bot_has_permissions(**perms: Unpack[_PermissionsKwargs]) -> Callable[[T], T]: + """Similar to :func:`has_permissions` except checks if the bot itself has + the permissions listed. This relies on :attr:`discord.Interaction.app_permissions`. + + This check raises a special exception, :exc:`~discord.app_commands.BotMissingPermissions` + that is inherited from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + """ + + invalid = set(perms) - set(Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f'Invalid permission(s): {", ".join(invalid)}') + + def predicate(interaction: Interaction) -> bool: + permissions = interaction.app_permissions + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + + if not missing: + return True + + raise BotMissingPermissions(missing) + + return check(predicate) + + +def _create_cooldown_decorator( + key: CooldownFunction[Hashable], factory: CooldownFunction[Optional[Cooldown]] +) -> Callable[[T], T]: + mapping: Dict[Any, Cooldown] = {} + + async def get_bucket( + interaction: Interaction, + *, + mapping: Dict[Any, Cooldown] = mapping, + key: CooldownFunction[Hashable] = key, + factory: CooldownFunction[Optional[Cooldown]] = factory, + ) -> Optional[Cooldown]: + current = interaction.created_at.timestamp() + dead_keys = [k for k, v in mapping.items() if current > v._last + v.per] + for k in dead_keys: + del mapping[k] + + k = await maybe_coroutine(key, interaction) + if k not in mapping: + bucket: Optional[Cooldown] = await maybe_coroutine(factory, interaction) + if bucket is not None: + mapping[k] = bucket + else: + bucket = mapping[k] + + return bucket + + async def predicate(interaction: Interaction) -> bool: + bucket = await get_bucket(interaction) + if bucket is None: + return True + + retry_after = bucket.update_rate_limit(interaction.created_at.timestamp()) + if retry_after is None: + return True + + raise CommandOnCooldown(bucket, retry_after) + + return check(predicate) + + +def cooldown( + rate: float, + per: float, + *, + key: Optional[CooldownFunction[Hashable]] = MISSING, +) -> Callable[[T], T]: + """A decorator that adds a cooldown to a command. + + A cooldown allows a command to only be used a specific amount + of times in a specific time frame. These cooldowns are based off + of the ``key`` function provided. If a ``key`` is not provided + then it defaults to a user-level cooldown. The ``key`` function + must take a single parameter, the :class:`discord.Interaction` and + return a value that is used as a key to the internal cooldown mapping. + + The ``key`` function can optionally be a coroutine. + + If a cooldown is triggered, then :exc:`~discord.app_commands.CommandOnCooldown` is + raised to the error handlers. + + Examples + --------- + + Setting a one per 5 seconds per member cooldown on a command: + + .. code-block:: python3 + + @tree.command() + @app_commands.checks.cooldown(1, 5.0, key=lambda i: (i.guild_id, i.user.id)) + async def test(interaction: discord.Interaction): + await interaction.response.send_message('Hello') + + @test.error + async def on_test_error(interaction: discord.Interaction, error: app_commands.AppCommandError): + if isinstance(error, app_commands.CommandOnCooldown): + await interaction.response.send_message(str(error), ephemeral=True) + + Parameters + ------------ + rate: :class:`int` + The number of times a command can be used before triggering a cooldown. + per: :class:`float` + The amount of seconds to wait for a cooldown when it's been triggered. + key: Optional[Callable[[:class:`discord.Interaction`], :class:`collections.abc.Hashable`]] + A function that returns a key to the mapping denoting the type of cooldown. + Can optionally be a coroutine. If not given then defaults to a user-level + cooldown. If ``None`` is passed then it is interpreted as a "global" cooldown. + """ + + if key is MISSING: + key_func = lambda interaction: interaction.user.id + elif key is None: + key_func = lambda i: None + else: + key_func = key + + factory = lambda interaction: Cooldown(rate, per) + + return _create_cooldown_decorator(key_func, factory) + + +def dynamic_cooldown( + factory: CooldownFunction[Optional[Cooldown]], + *, + key: Optional[CooldownFunction[Hashable]] = MISSING, +) -> Callable[[T], T]: + """A decorator that adds a dynamic cooldown to a command. + + A cooldown allows a command to only be used a specific amount + of times in a specific time frame. These cooldowns are based off + of the ``key`` function provided. If a ``key`` is not provided + then it defaults to a user-level cooldown. The ``key`` function + must take a single parameter, the :class:`discord.Interaction` and + return a value that is used as a key to the internal cooldown mapping. + + If a ``factory`` function is given, it must be a function that + accepts a single parameter of type :class:`discord.Interaction` and must + return a :class:`~discord.app_commands.Cooldown` or ``None``. + If ``None`` is returned then that cooldown is effectively bypassed. + + Both ``key`` and ``factory`` can optionally be coroutines. + + If a cooldown is triggered, then :exc:`~discord.app_commands.CommandOnCooldown` is + raised to the error handlers. + + Examples + --------- + + Setting a cooldown for everyone but the owner. + + .. code-block:: python3 + + def cooldown_for_everyone_but_me(interaction: discord.Interaction) -> Optional[app_commands.Cooldown]: + if interaction.user.id == 80088516616269824: + return None + return app_commands.Cooldown(1, 10.0) + + @tree.command() + @app_commands.checks.dynamic_cooldown(cooldown_for_everyone_but_me) + async def test(interaction: discord.Interaction): + await interaction.response.send_message('Hello') + + @test.error + async def on_test_error(interaction: discord.Interaction, error: app_commands.AppCommandError): + if isinstance(error, app_commands.CommandOnCooldown): + await interaction.response.send_message(str(error), ephemeral=True) + + Parameters + ------------ + factory: Optional[Callable[[:class:`discord.Interaction`], Optional[:class:`~discord.app_commands.Cooldown`]]] + A function that takes an interaction and returns a cooldown that will apply to that interaction + or ``None`` if the interaction should not have a cooldown. + key: Optional[Callable[[:class:`discord.Interaction`], :class:`collections.abc.Hashable`]] + A function that returns a key to the mapping denoting the type of cooldown. + Can optionally be a coroutine. If not given then defaults to a user-level + cooldown. If ``None`` is passed then it is interpreted as a "global" cooldown. + """ + + if key is MISSING: + key_func = lambda interaction: interaction.user.id + elif key is None: + key_func = lambda i: None + else: + key_func = key + + return _create_cooldown_decorator(key_func, factory) diff --git a/discord/app_commands/commands.py b/discord/app_commands/commands.py new file mode 100644 index 000000000000..36d07d41c4e6 --- /dev/null +++ b/discord/app_commands/commands.py @@ -0,0 +1,2884 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +import inspect + +from typing import ( + Any, + Callable, + ClassVar, + Coroutine, + Dict, + Generator, + Generic, + List, + MutableMapping, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import re +from copy import copy as shallow_copy + +from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale +from .installs import AppCommandContext, AppInstallationType +from .models import Choice +from .transformers import annotation_to_parameter, CommandParameter, NoneType +from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered +from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str +from ..message import Message +from ..user import User +from ..member import Member +from ..permissions import Permissions +from ..utils import resolve_annotation, MISSING, is_inside_class, maybe_coroutine, async_all, _shorten, _to_kebab_case + +if TYPE_CHECKING: + from typing_extensions import ParamSpec, Concatenate, Unpack + from ..interactions import Interaction + from ..abc import Snowflake + from .namespace import Namespace + from .models import ChoiceT + from .tree import CommandTree + from .._types import ClientT + + # Generally, these two libraries are supposed to be separate from each other. + # However, for type hinting purposes it's unfortunately necessary for one to + # reference the other to prevent type checking errors in callbacks + from discord.ext import commands + from discord.permissions import _PermissionsKwargs + + ErrorFunc = Callable[[Interaction, AppCommandError], Coroutine[Any, Any, None]] + +__all__ = ( + 'Command', + 'ContextMenu', + 'Group', + 'Parameter', + 'context_menu', + 'command', + 'describe', + 'check', + 'rename', + 'choices', + 'autocomplete', + 'guilds', + 'guild_only', + 'dm_only', + 'private_channel_only', + 'allowed_contexts', + 'guild_install', + 'user_install', + 'allowed_installs', + 'default_permissions', +) + +if TYPE_CHECKING: + P = ParamSpec('P') +else: + P = TypeVar('P') + +T = TypeVar('T') +F = TypeVar('F', bound=Callable[..., Any]) +GroupT = TypeVar('GroupT', bound='Binding') +Coro = Coroutine[Any, Any, T] +UnboundError = Callable[['Interaction[Any]', AppCommandError], Coro[Any]] +Error = Union[ + Callable[[GroupT, 'Interaction[Any]', AppCommandError], Coro[Any]], + UnboundError, +] +Check = Callable[['Interaction[Any]'], Union[bool, Coro[bool]]] +Binding = Union['Group', 'commands.Cog'] + + +if TYPE_CHECKING: + CommandCallback = Union[ + Callable[Concatenate[GroupT, 'Interaction[Any]', P], Coro[T]], + Callable[Concatenate['Interaction[Any]', P], Coro[T]], + ] + + ContextMenuCallback = Union[ + # If groups end up support context menus these would be uncommented + # Callable[[GroupT, 'Interaction', Member], Coro[Any]], + # Callable[[GroupT, 'Interaction', User], Coro[Any]], + # Callable[[GroupT, 'Interaction', Message], Coro[Any]], + # Callable[[GroupT, 'Interaction', Union[Member, User]], Coro[Any]], + Callable[['Interaction[Any]', Member], Coro[Any]], + Callable[['Interaction[Any]', User], Coro[Any]], + Callable[['Interaction[Any]', Message], Coro[Any]], + Callable[['Interaction[Any]', Union[Member, User]], Coro[Any]], + ] + + AutocompleteCallback = Union[ + Callable[[GroupT, 'Interaction[Any]', str], Coro[List[Choice[ChoiceT]]]], + Callable[['Interaction[Any]', str], Coro[List[Choice[ChoiceT]]]], + ] +else: + CommandCallback = Callable[..., Coro[T]] + ContextMenuCallback = Callable[..., Coro[T]] + AutocompleteCallback = Callable[..., Coro[T]] + + +CheckInputParameter = Union['Command[Any, ..., Any]', 'ContextMenu', 'CommandCallback[Any, ..., Any]', ContextMenuCallback] + +# The re module doesn't support \p{} so we have to list characters from Thai and Devanagari manually. +THAI_COMBINING = r'\u0e31-\u0e3a\u0e47-\u0e4e' +DEVANAGARI_COMBINING = r'\u0900-\u0903\u093a\u093b\u093c\u093e\u093f\u0940-\u094f\u0955\u0956\u0957\u0962\u0963' +VALID_SLASH_COMMAND_NAME = re.compile(r'^[-_\w' + THAI_COMBINING + DEVANAGARI_COMBINING + r']{1,32}$') + +ARG_NAME_SUBREGEX = r'(?:\\?\*){0,2}(?P\w+)' + +ARG_DESCRIPTION_SUBREGEX = r'(?P(?:.|\n)+?(?:\Z|\r?\n(?=[\S\r\n])))' + +ARG_TYPE_SUBREGEX = r'(?:.+)' + +GOOGLE_DOCSTRING_ARG_REGEX = re.compile( + rf'^{ARG_NAME_SUBREGEX}[ \t]*(?:\({ARG_TYPE_SUBREGEX}\))?[ \t]*:[ \t]*{ARG_DESCRIPTION_SUBREGEX}', + re.MULTILINE, +) + +SPHINX_DOCSTRING_ARG_REGEX = re.compile( + rf'^:param {ARG_NAME_SUBREGEX}:[ \t]+{ARG_DESCRIPTION_SUBREGEX}', + re.MULTILINE, +) + +NUMPY_DOCSTRING_ARG_REGEX = re.compile( + rf'^{ARG_NAME_SUBREGEX}(?:[ \t]*:)?(?:[ \t]+{ARG_TYPE_SUBREGEX})?[ \t]*\r?\n[ \t]+{ARG_DESCRIPTION_SUBREGEX}', + re.MULTILINE, +) + + +def _parse_args_from_docstring(func: Callable[..., Any], params: Dict[str, CommandParameter]) -> Dict[str, str]: + docstring = inspect.getdoc(func) + + if docstring is None: + return {} + + # Extract the arguments + # Note: These are loose regexes, but they are good enough for our purposes + # For Google-style, look only at the lines that are indented + section_lines = inspect.cleandoc('\n'.join(line for line in docstring.splitlines() if line.startswith(' '))) + docstring_styles = ( + GOOGLE_DOCSTRING_ARG_REGEX.finditer(section_lines), + SPHINX_DOCSTRING_ARG_REGEX.finditer(docstring), + NUMPY_DOCSTRING_ARG_REGEX.finditer(docstring), + ) + + return { + m.group('name'): m.group('description') for matches in docstring_styles for m in matches if m.group('name') in params + } + + +def validate_name(name: str) -> str: + match = VALID_SLASH_COMMAND_NAME.match(name) + if match is None: + raise ValueError( + f'{name!r} must be between 1-32 characters and contain only lower-case letters, numbers, hyphens, or underscores.' + ) + + # Ideally, name.islower() would work instead but since certain characters + # are Lo (e.g. CJK) those don't pass the test. I'd use `casefold` instead as + # well, but chances are the server-side check is probably something similar to + # this code anyway. + if name.lower() != name: + raise ValueError(f'{name!r} must be all lower-case') + return name + + +def validate_context_menu_name(name: str) -> str: + if not name or len(name) > 32: + raise ValueError('context menu names must be between 1-32 characters') + return name + + +def validate_auto_complete_callback( + callback: AutocompleteCallback[GroupT, ChoiceT], +) -> AutocompleteCallback[GroupT, ChoiceT]: + # This function needs to ensure the following is true: + # If self.foo is passed then don't pass command.binding to the callback + # If Class.foo is passed then it is assumed command.binding has to be passed + # If free_function_foo is passed then no binding should be passed at all + # Passing command.binding is mandated by pass_command_binding + + binding = getattr(callback, '__self__', None) + pass_command_binding = binding is None and is_inside_class(callback) + + # 'method' objects can't have dynamic attributes + if binding is None: + callback.pass_command_binding = pass_command_binding + + required_parameters = 2 + pass_command_binding + params = inspect.signature(callback).parameters + if len(params) != required_parameters: + raise TypeError(f'autocomplete callback {callback.__qualname__!r} requires either 2 or 3 parameters to be passed') + + return callback + + +def _context_menu_annotation(annotation: Any, *, _none: type = NoneType) -> AppCommandType: + if annotation is Message: + return AppCommandType.message + + supported_types: Set[Any] = {Member, User} + if annotation in supported_types: + return AppCommandType.user + + # Check if there's an origin + origin = getattr(annotation, '__origin__', None) + if origin is not Union: + # Only Union is supported so bail early + msg = ( + f'unsupported type annotation {annotation!r}, must be either discord.Member, ' + 'discord.User, discord.Message, or a typing.Union of discord.Member and discord.User' + ) + raise TypeError(msg) + + # Only Union[Member, User] is supported + if not all(arg in supported_types for arg in annotation.__args__): + raise TypeError(f'unsupported types given inside {annotation!r}') + + return AppCommandType.user + + +def _populate_descriptions(params: Dict[str, CommandParameter], descriptions: Dict[str, Any]) -> None: + for name, param in params.items(): + description = descriptions.pop(name, MISSING) + if description is MISSING: + param.description = '…' + continue + + if not isinstance(description, (str, locale_str)): + raise TypeError('description must be a string') + + if isinstance(description, str): + param.description = _shorten(description) + else: + param.description = description + + if descriptions: + first = next(iter(descriptions)) + raise TypeError(f'unknown parameter given: {first}') + + +def _populate_renames(params: Dict[str, CommandParameter], renames: Dict[str, Union[str, locale_str]]) -> None: + rename_map: Dict[str, Union[str, locale_str]] = {} + + # original name to renamed name + + for name in params.keys(): + new_name = renames.pop(name, MISSING) + + if new_name is MISSING: + rename_map[name] = name + continue + + if name in rename_map: + raise ValueError(f'{new_name} is already used') + + if isinstance(new_name, str): + new_name = validate_name(new_name) + else: + validate_name(new_name.message) + + rename_map[name] = new_name + params[name]._rename = new_name + + if renames: + first = next(iter(renames)) + raise ValueError(f'unknown parameter given: {first}') + + +def _populate_choices(params: Dict[str, CommandParameter], all_choices: Dict[str, List[Choice]]) -> None: + for name, param in params.items(): + choices = all_choices.pop(name, MISSING) + if choices is MISSING: + continue + + if not isinstance(choices, list): + raise TypeError('choices must be a list of Choice') + + if not all(isinstance(choice, Choice) for choice in choices): + raise TypeError('choices must be a list of Choice') + + if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer): + raise TypeError('choices are only supported for integer, string, or number option types') + + if not all(param.type == choice._option_type for choice in choices): + raise TypeError('choices must all have the same inner option type as the parameter choice type') + + param.choices = choices + + if all_choices: + first = next(iter(all_choices)) + raise TypeError(f'unknown parameter given: {first}') + + +def _populate_autocomplete(params: Dict[str, CommandParameter], autocomplete: Dict[str, Any]) -> None: + for name, param in params.items(): + callback = autocomplete.pop(name, MISSING) + if callback is MISSING: + continue + + if not inspect.iscoroutinefunction(callback): + raise TypeError('autocomplete callback must be a coroutine function') + + if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer): + raise TypeError('autocomplete is only supported for integer, string, or number option types') + + if param.is_choice_annotation(): + raise TypeError( + 'Choice annotation unsupported for autocomplete parameters, consider using a regular annotation instead' + ) + + param.autocomplete = validate_auto_complete_callback(callback) + + if autocomplete: + first = next(iter(autocomplete)) + raise TypeError(f'unknown parameter given: {first}') + + +def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, CommandParameter]: + params = inspect.signature(func).parameters + cache = {} + required_params = is_inside_class(func) + 1 + if len(params) < required_params: + raise TypeError(f'callback {func.__qualname__!r} must have more than {required_params - 1} parameter(s)') + + iterator = iter(params.values()) + for _ in range(0, required_params): + next(iterator) + + parameters: List[CommandParameter] = [] + for parameter in iterator: + if parameter.annotation is parameter.empty: + raise TypeError(f'parameter {parameter.name!r} is missing a type annotation in callback {func.__qualname__!r}') + + resolved = resolve_annotation(parameter.annotation, globalns, globalns, cache) + param = annotation_to_parameter(resolved, parameter) + parameters.append(param) + + values = sorted(parameters, key=lambda a: a.required, reverse=True) + result = {v.name: v for v in values} + + descriptions = _parse_args_from_docstring(func, result) + + try: + descriptions.update(func.__discord_app_commands_param_description__) + except AttributeError: + for param in values: + if param.description is MISSING: + param.description = '…' + if descriptions: + _populate_descriptions(result, descriptions) + + try: + renames = func.__discord_app_commands_param_rename__ + except AttributeError: + pass + else: + _populate_renames(result, renames.copy()) + + try: + choices = func.__discord_app_commands_param_choices__ + except AttributeError: + pass + else: + _populate_choices(result, choices.copy()) + + try: + autocomplete = func.__discord_app_commands_param_autocomplete__ + except AttributeError: + pass + else: + _populate_autocomplete(result, autocomplete.copy()) + + return result + + +def _get_context_menu_parameter(func: ContextMenuCallback) -> Tuple[str, Any, AppCommandType]: + params = inspect.signature(func).parameters + if is_inside_class(func) and not hasattr(func, '__self__'): + raise TypeError('context menus cannot be defined inside a class') + + if len(params) != 2: + msg = ( + f'context menu callback {func.__qualname__!r} requires 2 parameters, ' + 'the first one being the interaction and the other one explicitly ' + 'annotated with either discord.Message, discord.User, discord.Member, ' + 'or a typing.Union of discord.Member and discord.User' + ) + raise TypeError(msg) + + iterator = iter(params.values()) + next(iterator) # skip interaction + parameter = next(iterator) + if parameter.annotation is parameter.empty: + msg = ( + f'second parameter of context menu callback {func.__qualname__!r} must be explicitly ' + 'annotated with either discord.Message, discord.User, discord.Member, or ' + 'a typing.Union of discord.Member and discord.User' + ) + raise TypeError(msg) + + resolved = resolve_annotation(parameter.annotation, func.__globals__, func.__globals__, {}) + type = _context_menu_annotation(resolved) + return (parameter.name, resolved, type) + + +def mark_overrideable(func: F) -> F: + func.__discord_app_commands_base_function__ = None + return func + + +class Parameter: + """A class that contains the parameter information of a :class:`Command` callback. + + .. versionadded:: 2.0 + + Attributes + ----------- + name: :class:`str` + The name of the parameter. This is the Python identifier for the parameter. + display_name: :class:`str` + The displayed name of the parameter on Discord. + description: :class:`str` + The description of the parameter. + autocomplete: :class:`bool` + Whether the parameter has an autocomplete handler. + locale_name: Optional[:class:`locale_str`] + The display name's locale string, if available. + locale_description: Optional[:class:`locale_str`] + The description's locale string, if available. + required: :class:`bool` + Whether the parameter is required + choices: List[:class:`~discord.app_commands.Choice`] + A list of choices this parameter takes, if any. + type: :class:`~discord.AppCommandOptionType` + The underlying type of this parameter. + channel_types: List[:class:`~discord.ChannelType`] + The channel types that are allowed for this parameter. + min_value: Optional[Union[:class:`int`, :class:`float`]] + The minimum supported value for this parameter. + max_value: Optional[Union[:class:`int`, :class:`float`]] + The maximum supported value for this parameter. + default: Any + The default value of the parameter, if given. + If not given then this is :data:`~discord.utils.MISSING`. + command: :class:`Command` + The command this parameter is attached to. + """ + + def __init__(self, parent: CommandParameter, command: Command[Any, ..., Any]) -> None: + self.__parent: CommandParameter = parent + self.__command: Command[Any, ..., Any] = command + + @property + def command(self) -> Command[Any, ..., Any]: + return self.__command + + @property + def name(self) -> str: + return self.__parent.name + + @property + def display_name(self) -> str: + return self.__parent.display_name + + @property + def required(self) -> bool: + return self.__parent.required + + @property + def description(self) -> str: + return str(self.__parent.description) + + @property + def locale_name(self) -> Optional[locale_str]: + if isinstance(self.__parent._rename, locale_str): + return self.__parent._rename + return None + + @property + def locale_description(self) -> Optional[locale_str]: + if isinstance(self.__parent.description, locale_str): + return self.__parent.description + return None + + @property + def autocomplete(self) -> bool: + return self.__parent.autocomplete is not None + + @property + def default(self) -> Any: + return self.__parent.default + + @property + def type(self) -> AppCommandOptionType: + return self.__parent.type + + @property + def choices(self) -> List[Choice[Union[int, float, str]]]: + choices = self.__parent.choices + if choices is MISSING: + return [] + return choices.copy() + + @property + def channel_types(self) -> List[ChannelType]: + channel_types = self.__parent.channel_types + if channel_types is MISSING: + return [] + return channel_types.copy() + + @property + def min_value(self) -> Optional[Union[int, float]]: + return self.__parent.min_value + + @property + def max_value(self) -> Optional[Union[int, float]]: + return self.__parent.max_value + + +class Command(Generic[GroupT, P, T]): + """A class that implements an application command. + + These are usually not created manually, instead they are created using + one of the following decorators: + + - :func:`~discord.app_commands.command` + - :meth:`Group.command ` + - :meth:`CommandTree.command ` + + .. versionadded:: 2.0 + + Parameters + ----------- + name: Union[:class:`str`, :class:`locale_str`] + The name of the application command. + description: Union[:class:`str`, :class:`locale_str`] + The description of the application command. This shows up in the UI to describe + the application command. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`locale_str` rather than :class:`str`. This could + avoid some repetition and be more ergonomic for certain defaults such + as default command names, command descriptions, and parameter names. + Defaults to ``True``. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. + Defaults to ``False``. + + Due to a Discord limitation, this does not work on subcommands. + parent: Optional[:class:`Group`] + The parent application command. ``None`` if there isn't one. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + + Attributes + ------------ + name: :class:`str` + The name of the application command. + description: :class:`str` + The description of the application command. This shows up in the UI to describe + the application command. + checks + A list of predicates that take a :class:`~discord.Interaction` parameter + to indicate whether the command callback should be executed. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`AppCommandError` should be used. If all the checks fail without + propagating an exception, :exc:`CheckFailure` is raised. + default_permissions: Optional[:class:`~discord.Permissions`] + The default permissions that can execute this command on Discord. Note + that server administrators can override this value in the client. + Setting an empty permissions field will disallow anyone except server + administrators from using the command in a guild. + + Due to a Discord limitation, this does not work on subcommands. + guild_only: :class:`bool` + Whether the command should only be usable in guild contexts. + + Due to a Discord limitation, this does not work on subcommands. + allowed_contexts: Optional[:class:`~discord.app_commands.AppCommandContext`] + The contexts that the command is allowed to be used in. + Overrides ``guild_only`` if this is set. + + .. versionadded:: 2.4 + allowed_installs: Optional[:class:`~discord.app_commands.AppInstallationType`] + The installation contexts that the command is allowed to be installed + on. + + .. versionadded:: 2.4 + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. + + Due to a Discord limitation, this does not work on subcommands. + parent: Optional[:class:`Group`] + The parent application command. ``None`` if there isn't one. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + """ + + def __init__( + self, + *, + name: Union[str, locale_str], + description: Union[str, locale_str], + callback: CommandCallback[GroupT, P, T], + nsfw: bool = False, + parent: Optional[Group] = None, + guild_ids: Optional[List[int]] = None, + allowed_contexts: Optional[AppCommandContext] = None, + allowed_installs: Optional[AppInstallationType] = None, + auto_locale_strings: bool = True, + extras: Dict[Any, Any] = MISSING, + ): + name, locale = (name.message, name) if isinstance(name, locale_str) else (name, None) + self.name: str = validate_name(name) + self._locale_name: Optional[locale_str] = locale + description, locale = ( + (description.message, description) if isinstance(description, locale_str) else (description, None) + ) + self.description: str = description + self._locale_description: Optional[locale_str] = locale + self._attr: Optional[str] = None + self._callback: CommandCallback[GroupT, P, T] = callback + self.parent: Optional[Group] = parent + self.binding: Optional[GroupT] = None + self.on_error: Optional[Error[GroupT]] = None + self.module: Optional[str] = callback.__module__ + + # Unwrap __self__ for bound methods + try: + self.binding = callback.__self__ + self._callback = callback = callback.__func__ + except AttributeError: + pass + + self._params: Dict[str, CommandParameter] = _extract_parameters_from_callback(callback, callback.__globals__) + self.checks: List[Check] = getattr(callback, '__discord_app_commands_checks__', []) + self._guild_ids: Optional[List[int]] = guild_ids + if self._guild_ids is None: + self._guild_ids = getattr(callback, '__discord_app_commands_default_guilds__', None) + self.default_permissions: Optional[Permissions] = getattr( + callback, '__discord_app_commands_default_permissions__', None + ) + self.guild_only: bool = getattr(callback, '__discord_app_commands_guild_only__', False) + self.allowed_contexts: Optional[AppCommandContext] = allowed_contexts or getattr( + callback, '__discord_app_commands_contexts__', None + ) + self.allowed_installs: Optional[AppInstallationType] = allowed_installs or getattr( + callback, '__discord_app_commands_installation_types__', None + ) + + self.nsfw: bool = nsfw + self.extras: Dict[Any, Any] = extras or {} + + if self._guild_ids is not None and self.parent is not None: + raise ValueError('child commands cannot have default guilds set, consider setting them in the parent instead') + + if auto_locale_strings: + self._convert_to_locale_strings() + + def _convert_to_locale_strings(self) -> None: + if self._locale_name is None: + self._locale_name = locale_str(self.name) + if self._locale_description is None: + self._locale_description = locale_str(self.description) + + for param in self._params.values(): + param._convert_to_locale_strings() + + def __set_name__(self, owner: Type[Any], name: str) -> None: + self._attr = name + + @property + def callback(self) -> CommandCallback[GroupT, P, T]: + """:ref:`coroutine `: The coroutine that is executed when the command is called.""" + return self._callback + + def _copy_with( + self, + *, + parent: Optional[Group], + binding: GroupT, + bindings: MutableMapping[GroupT, GroupT] = MISSING, + set_on_binding: bool = True, + ) -> Command: + bindings = {} if bindings is MISSING else bindings + + copy = shallow_copy(self) + copy._params = self._params.copy() + copy.parent = parent + copy.binding = bindings.get(self.binding) if self.binding is not None else binding + + if copy._attr and set_on_binding: + setattr(copy.binding, copy._attr, copy) + + return copy + + async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]: + base = self.to_dict(tree) + name_localizations: Dict[str, str] = {} + description_localizations: Dict[str, str] = {} + + # Prevent creating these objects in a heavy loop + name_context = TranslationContext(location=TranslationContextLocation.command_name, data=self) + description_context = TranslationContext(location=TranslationContextLocation.command_description, data=self) + + for locale in Locale: + if self._locale_name: + translation = await translator._checked_translate(self._locale_name, locale, name_context) + if translation is not None: + name_localizations[locale.value] = translation + + if self._locale_description: + translation = await translator._checked_translate(self._locale_description, locale, description_context) + if translation is not None: + description_localizations[locale.value] = translation + + base['name_localizations'] = name_localizations + base['description_localizations'] = description_localizations + base['options'] = [ + await param.get_translated_payload(translator, Parameter(param, self)) for param in self._params.values() + ] + return base + + def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]: + # If we have a parent then our type is a subcommand + # Otherwise, the type falls back to the specific command type (e.g. slash command or context menu) + option_type = AppCommandType.chat_input.value if self.parent is None else AppCommandOptionType.subcommand.value + base: Dict[str, Any] = { + 'name': self.name, + 'description': self.description, + 'type': option_type, + 'options': [param.to_dict() for param in self._params.values()], + } + + if self.parent is None: + base['nsfw'] = self.nsfw + base['dm_permission'] = not self.guild_only + base['default_member_permissions'] = None if self.default_permissions is None else self.default_permissions.value + base['contexts'] = tree.allowed_contexts._merge_to_array(self.allowed_contexts) + base['integration_types'] = tree.allowed_installs._merge_to_array(self.allowed_installs) + + return base + + async def _invoke_error_handlers(self, interaction: Interaction, error: AppCommandError) -> None: + # These type ignores are because the type checker can't narrow this type properly. + if self.on_error is not None: + if self.binding is not None: + await self.on_error(self.binding, interaction, error) # type: ignore + else: + await self.on_error(interaction, error) # type: ignore + + parent = self.parent + if parent is not None: + await parent.on_error(interaction, error) + + if parent.parent is not None: + await parent.parent.on_error(interaction, error) + + binding_error_handler = getattr(self.binding, '__discord_app_commands_error_handler__', None) + if binding_error_handler is not None: + await binding_error_handler(interaction, error) + + def _has_any_error_handlers(self) -> bool: + if self.on_error is not None: + return True + + parent = self.parent + if parent is not None: + # Check if the on_error is overridden + if not hasattr(parent.on_error, '__discord_app_commands_base_function__'): + return True + + if parent.parent is not None: + if not hasattr(parent.parent.on_error, '__discord_app_commands_base_function__'): + return True + + # Check if we have a bound error handler + if getattr(self.binding, '__discord_app_commands_error_handler__', None) is not None: + return True + + return False + + async def _transform_arguments(self, interaction: Interaction, namespace: Namespace) -> Dict[str, Any]: + values = namespace.__dict__ + transformed_values = {} + + for param in self._params.values(): + try: + value = values[param.display_name] + except KeyError: + if not param.required: + transformed_values[param.name] = param.default + else: + raise CommandSignatureMismatch(self) from None + else: + transformed_values[param.name] = await param.transform(interaction, value) + + return transformed_values + + async def _do_call(self, interaction: Interaction, params: Dict[str, Any]) -> T: + # These type ignores are because the type checker doesn't quite understand the narrowing here + # Likewise, it thinks we're missing positional arguments when there aren't any. + try: + if self.binding is not None: + return await self._callback(self.binding, interaction, **params) # type: ignore + return await self._callback(interaction, **params) # type: ignore + except TypeError as e: + # In order to detect mismatch from the provided signature and the Discord data, + # there are many ways it can go wrong yet all of them eventually lead to a TypeError + # from the Python compiler showcasing that the signature is incorrect. This lovely + # piece of code essentially checks the last frame of the caller and checks if the + # locals contains our `self` reference. + # + # This is because there is a possibility that a TypeError is raised within the body + # of the function, and in that case the locals wouldn't contain a reference to + # the command object under the name `self`. + frame = inspect.trace()[-1].frame + if frame.f_locals.get('self') is self: + raise CommandSignatureMismatch(self) from None + raise CommandInvokeError(self, e) from e + except AppCommandError: + raise + except Exception as e: + raise CommandInvokeError(self, e) from e + + async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T: + if not await self._check_can_run(interaction): + raise CheckFailure(f'The check functions for command {self.name!r} failed.') + + transformed_values = await self._transform_arguments(interaction, namespace) + return await self._do_call(interaction, transformed_values) + + async def _invoke_autocomplete(self, interaction: Interaction, name: str, namespace: Namespace): + # The namespace contains the Discord provided names so this will be fine + # even if the name is renamed + value = namespace.__dict__[name] + + try: + param = self._params[name] + except KeyError: + # Slow case, it might be a rename + params = {param.display_name: param for param in self._params.values()} + try: + param = params[name] + except KeyError: + raise CommandSignatureMismatch(self) from None + + if param.autocomplete is None: + raise CommandSignatureMismatch(self) + + predicates = getattr(param.autocomplete, '__discord_app_commands_checks__', []) + if predicates: + try: + passed = await async_all(f(interaction) for f in predicates) # type: ignore + except Exception: + passed = False + + if not passed: + if not interaction.response.is_done(): + await interaction.response.autocomplete([]) + return + + if getattr(param.autocomplete, 'pass_command_binding', False): + binding = self.binding + if binding is not None: + choices = await param.autocomplete(binding, interaction, value) + else: + raise TypeError('autocomplete parameter expected a bound self parameter but one was not provided') + else: + choices = await param.autocomplete(interaction, value) + + if interaction.response.is_done(): + return + + await interaction.response.autocomplete(choices) + + def _get_internal_command(self, name: str) -> Optional[Union[Command, Group]]: + return None + + @property + def parameters(self) -> List[Parameter]: + """Returns a list of parameters for this command. + + This does not include the ``self`` or ``interaction`` parameters. + + Returns + -------- + List[:class:`Parameter`] + The parameters of this command. + """ + return [Parameter(p, self) for p in self._params.values()] + + def get_parameter(self, name: str) -> Optional[Parameter]: + """Retrieves a parameter by its name. + + The name must be the Python identifier rather than the renamed + one for display on Discord. + + Parameters + ----------- + name: :class:`str` + The parameter name in the callback function. + + Returns + -------- + Optional[:class:`Parameter`] + The parameter or ``None`` if not found. + """ + + parent = self._params.get(name) + if parent is not None: + return Parameter(parent, self) + return None + + @property + def root_parent(self) -> Optional[Group]: + """Optional[:class:`Group`]: The root parent of this command.""" + if self.parent is None: + return None + parent = self.parent + return parent.parent or parent + + @property + def qualified_name(self) -> str: + """:class:`str`: Returns the fully qualified command name. + + The qualified name includes the parent name as well. For example, + in a command like ``/foo bar`` the qualified name is ``foo bar``. + """ + # A B C + # ^ self + # ^ parent + # ^ grandparent + if self.parent is None: + return self.name + + names = [self.name, self.parent.name] + grandparent = self.parent.parent + if grandparent is not None: + names.append(grandparent.name) + + return ' '.join(reversed(names)) + + async def _check_can_run(self, interaction: Interaction) -> bool: + if self.parent is not None and self.parent is not self.binding: + # For commands with a parent which isn't the binding, i.e. + # + # + # + # The parent check needs to be called first + if not await maybe_coroutine(self.parent.interaction_check, interaction): + return False + + if self.binding is not None: + check: Optional[Check] = getattr(self.binding, 'interaction_check', None) + if check: + ret = await maybe_coroutine(check, interaction) + if not ret: + return False + + predicates = self.checks + if not predicates: + return True + + return await async_all(f(interaction) for f in predicates) # type: ignore + + def error(self, coro: Error[GroupT]) -> Error[GroupT]: + """A decorator that registers a coroutine as a local error handler. + + The local error handler is called whenever an exception is raised in the body + of the command or during handling of the command. The error handler must take + 2 parameters, the interaction and the error. + + The error passed will be derived from :exc:`AppCommandError`. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + """ + + if not inspect.iscoroutinefunction(coro): + raise TypeError('The error handler must be a coroutine.') + + self.on_error = coro + return coro + + def autocomplete( + self, name: str + ) -> Callable[[AutocompleteCallback[GroupT, ChoiceT]], AutocompleteCallback[GroupT, ChoiceT]]: + """A decorator that registers a coroutine as an autocomplete prompt for a parameter. + + The coroutine callback must have 2 parameters, the :class:`~discord.Interaction`, + and the current value by the user (the string currently being typed by the user). + + To get the values from other parameters that may be filled in, accessing + :attr:`.Interaction.namespace` will give a :class:`Namespace` object with those + values. + + Parent :func:`checks ` are ignored within an autocomplete. However, checks can be added + to the autocomplete callback and the ones added will be called. If the checks fail for any reason + then an empty list is sent as the interaction response. + + The coroutine decorator **must** return a list of :class:`~discord.app_commands.Choice` objects. + Only up to 25 objects are supported. + + .. warning:: + The choices returned from this coroutine are suggestions. The user may ignore them and input their own value. + + Example: + + .. code-block:: python3 + + @app_commands.command() + async def fruits(interaction: discord.Interaction, fruit: str): + await interaction.response.send_message(f'Your favourite fruit seems to be {fruit}') + + @fruits.autocomplete('fruit') + async def fruits_autocomplete( + interaction: discord.Interaction, + current: str, + ) -> List[app_commands.Choice[str]]: + fruits = ['Banana', 'Pineapple', 'Apple', 'Watermelon', 'Melon', 'Cherry'] + return [ + app_commands.Choice(name=fruit, value=fruit) + for fruit in fruits if current.lower() in fruit.lower() + ] + + + Parameters + ----------- + name: :class:`str` + The parameter name to register as autocomplete. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine or + the parameter is not found or of an invalid type. + """ + + def decorator(coro: AutocompleteCallback[GroupT, ChoiceT]) -> AutocompleteCallback[GroupT, ChoiceT]: + if not inspect.iscoroutinefunction(coro): + raise TypeError('The autocomplete callback must be a coroutine function.') + + try: + param = self._params[name] + except KeyError: + raise TypeError(f'unknown parameter: {name!r}') from None + + if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer): + raise TypeError('autocomplete is only supported for integer, string, or number option types') + + if param.is_choice_annotation(): + raise TypeError( + 'Choice annotation unsupported for autocomplete parameters, consider using a regular annotation instead' + ) + + param.autocomplete = validate_auto_complete_callback(coro) + return coro + + return decorator + + def add_check(self, func: Check, /) -> None: + """Adds a check to the command. + + This is the non-decorator interface to :func:`check`. + + Parameters + ----------- + func + The function that will be used as a check. + """ + + self.checks.append(func) + + def remove_check(self, func: Check, /) -> None: + """Removes a check from the command. + + This function is idempotent and will not raise an exception + if the function is not in the command's checks. + + Parameters + ----------- + func + The function to remove from the checks. + """ + + try: + self.checks.remove(func) + except ValueError: + pass + + +class ContextMenu: + """A class that implements a context menu application command. + + These are usually not created manually, instead they are created using + one of the following decorators: + + - :func:`~discord.app_commands.context_menu` + - :meth:`CommandTree.context_menu ` + + .. versionadded:: 2.0 + + Parameters + ----------- + name: Union[:class:`str`, :class:`locale_str`] + The name of the context menu. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + type: :class:`.AppCommandType` + The type of context menu application command. By default, this is inferred + by the parameter of the callback. + auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`locale_str` rather than :class:`str`. This could + avoid some repetition and be more ergonomic for certain defaults such + as default command names, command descriptions, and parameter names. + Defaults to ``True``. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. + Defaults to ``False``. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + + Attributes + ------------ + name: :class:`str` + The name of the context menu. + type: :class:`.AppCommandType` + The type of context menu application command. By default, this is inferred + by the parameter of the callback. + default_permissions: Optional[:class:`~discord.Permissions`] + The default permissions that can execute this command on Discord. Note + that server administrators can override this value in the client. + Setting an empty permissions field will disallow anyone except server + administrators from using the command in a guild. + guild_only: :class:`bool` + Whether the command should only be usable in guild contexts. + Defaults to ``False``. + allowed_contexts: Optional[:class:`~discord.app_commands.AppCommandContext`] + The contexts that this context menu is allowed to be used in. + Overrides ``guild_only`` if set. + + .. versionadded:: 2.4 + allowed_installs: Optional[:class:`~discord.app_commands.AppInstallationType`] + The installation contexts that the command is allowed to be installed + on. + + .. versionadded:: 2.4 + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. + Defaults to ``False``. + checks + A list of predicates that take a :class:`~discord.Interaction` parameter + to indicate whether the command callback should be executed. If an exception + is necessary to be thrown to signal failure, then one inherited from + :exc:`AppCommandError` should be used. If all the checks fail without + propagating an exception, :exc:`CheckFailure` is raised. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + """ + + def __init__( + self, + *, + name: Union[str, locale_str], + callback: ContextMenuCallback, + type: AppCommandType = MISSING, + nsfw: bool = False, + guild_ids: Optional[List[int]] = None, + allowed_contexts: Optional[AppCommandContext] = None, + allowed_installs: Optional[AppInstallationType] = None, + auto_locale_strings: bool = True, + extras: Dict[Any, Any] = MISSING, + ): + name, locale = (name.message, name) if isinstance(name, locale_str) else (name, None) + self.name: str = validate_context_menu_name(name) + self._locale_name: Optional[locale_str] = locale + self._callback: ContextMenuCallback = callback + (param, annotation, actual_type) = _get_context_menu_parameter(callback) + if type is MISSING: + type = actual_type + + if actual_type != type: + raise ValueError(f'context menu callback implies a type of {actual_type} but {type} was passed.') + + self.type: AppCommandType = type + self._param_name = param + self._annotation = annotation + self.module: Optional[str] = callback.__module__ + self._guild_ids = guild_ids + if self._guild_ids is None: + self._guild_ids = getattr(callback, '__discord_app_commands_default_guilds__', None) + self.on_error: Optional[UnboundError] = None + self.default_permissions: Optional[Permissions] = getattr( + callback, '__discord_app_commands_default_permissions__', None + ) + self.nsfw: bool = nsfw + self.guild_only: bool = getattr(callback, '__discord_app_commands_guild_only__', False) + self.allowed_contexts: Optional[AppCommandContext] = allowed_contexts or getattr( + callback, '__discord_app_commands_contexts__', None + ) + self.allowed_installs: Optional[AppInstallationType] = allowed_installs or getattr( + callback, '__discord_app_commands_installation_types__', None + ) + self.checks: List[Check] = getattr(callback, '__discord_app_commands_checks__', []) + self.extras: Dict[Any, Any] = extras or {} + + if auto_locale_strings: + if self._locale_name is None: + self._locale_name = locale_str(self.name) + + @property + def callback(self) -> ContextMenuCallback: + """:ref:`coroutine `: The coroutine that is executed when the context menu is called.""" + return self._callback + + @property + def qualified_name(self) -> str: + """:class:`str`: Returns the fully qualified command name.""" + return self.name + + async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]: + base = self.to_dict(tree) + context = TranslationContext(location=TranslationContextLocation.command_name, data=self) + if self._locale_name: + name_localizations: Dict[str, str] = {} + for locale in Locale: + translation = await translator._checked_translate(self._locale_name, locale, context) + if translation is not None: + name_localizations[locale.value] = translation + + base['name_localizations'] = name_localizations + return base + + def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]: + return { + 'name': self.name, + 'type': self.type.value, + 'dm_permission': not self.guild_only, + 'contexts': tree.allowed_contexts._merge_to_array(self.allowed_contexts), + 'integration_types': tree.allowed_installs._merge_to_array(self.allowed_installs), + 'default_member_permissions': None if self.default_permissions is None else self.default_permissions.value, + 'nsfw': self.nsfw, + } + + async def _check_can_run(self, interaction: Interaction) -> bool: + predicates = self.checks + if not predicates: + return True + + return await async_all(f(interaction) for f in predicates) # type: ignore + + def _has_any_error_handlers(self) -> bool: + return self.on_error is not None + + async def _invoke(self, interaction: Interaction, arg: Any): + try: + if not await self._check_can_run(interaction): + raise CheckFailure(f'The check functions for context menu {self.name!r} failed.') + + await self._callback(interaction, arg) + except AppCommandError: + raise + except Exception as e: + raise CommandInvokeError(self, e) from e + + def error(self, coro: UnboundError) -> UnboundError: + """A decorator that registers a coroutine as a local error handler. + + The local error handler is called whenever an exception is raised in the body + of the command or during handling of the command. The error handler must take + 2 parameters, the interaction and the error. + + The error passed will be derived from :exc:`AppCommandError`. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine. + """ + + if not inspect.iscoroutinefunction(coro): + raise TypeError('The error handler must be a coroutine.') + + self.on_error = coro + return coro + + def add_check(self, func: Check, /) -> None: + """Adds a check to the command. + + This is the non-decorator interface to :func:`check`. + + Parameters + ----------- + func + The function that will be used as a check. + """ + + self.checks.append(func) + + def remove_check(self, func: Check, /) -> None: + """Removes a check from the command. + + This function is idempotent and will not raise an exception + if the function is not in the command's checks. + + Parameters + ----------- + func + The function to remove from the checks. + """ + + try: + self.checks.remove(func) + except ValueError: + pass + + +class Group: + """A class that implements an application command group. + + These are usually inherited rather than created manually. + + Decorators such as :func:`guild_only`, :func:`guilds`, and :func:`default_permissions` + will apply to the group if used on top of a subclass. For example: + + .. code-block:: python3 + + from discord import app_commands + + @app_commands.guild_only() + class MyGroup(app_commands.Group): + pass + + .. versionadded:: 2.0 + + Parameters + ----------- + name: Union[:class:`str`, :class:`locale_str`] + The name of the group. If not given, it defaults to a lower-case + kebab-case version of the class name. + description: Union[:class:`str`, :class:`locale_str`] + The description of the group. This shows up in the UI to describe + the group. If not given, it defaults to the docstring of the + class shortened to 100 characters. + auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`locale_str` rather than :class:`str`. This could + avoid some repetition and be more ergonomic for certain defaults such + as default command names, command descriptions, and parameter names. + Defaults to ``True``. + default_permissions: Optional[:class:`~discord.Permissions`] + The default permissions that can execute this group on Discord. Note + that server administrators can override this value in the client. + Setting an empty permissions field will disallow anyone except server + administrators from using the command in a guild. + + Due to a Discord limitation, this does not work on subcommands. + guild_only: :class:`bool` + Whether the group should only be usable in guild contexts. + Defaults to ``False``. + + Due to a Discord limitation, this does not work on subcommands. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. + Defaults to ``False``. + + Due to a Discord limitation, this does not work on subcommands. + parent: Optional[:class:`Group`] + The parent application command. ``None`` if there isn't one. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + + Attributes + ------------ + name: :class:`str` + The name of the group. + description: :class:`str` + The description of the group. This shows up in the UI to describe + the group. + default_permissions: Optional[:class:`~discord.Permissions`] + The default permissions that can execute this group on Discord. Note + that server administrators can override this value in the client. + Setting an empty permissions field will disallow anyone except server + administrators from using the command in a guild. + + Due to a Discord limitation, this does not work on subcommands. + guild_only: :class:`bool` + Whether the group should only be usable in guild contexts. + + Due to a Discord limitation, this does not work on subcommands. + allowed_contexts: Optional[:class:`~discord.app_commands.AppCommandContext`] + The contexts that this group is allowed to be used in. Overrides + guild_only if set. + + .. versionadded:: 2.4 + allowed_installs: Optional[:class:`~discord.app_commands.AppInstallationType`] + The installation contexts that the command is allowed to be installed + on. + + .. versionadded:: 2.4 + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. + + Due to a Discord limitation, this does not work on subcommands. + parent: Optional[:class:`Group`] + The parent group. ``None`` if there isn't one. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + """ + + __discord_app_commands_group_children__: ClassVar[List[Union[Command[Any, ..., Any], Group]]] = [] + __discord_app_commands_skip_init_binding__: bool = False + __discord_app_commands_group_name__: str = MISSING + __discord_app_commands_group_description__: str = MISSING + __discord_app_commands_group_locale_name__: Optional[locale_str] = None + __discord_app_commands_group_locale_description__: Optional[locale_str] = None + __discord_app_commands_group_nsfw__: bool = False + __discord_app_commands_guild_only__: bool = MISSING + __discord_app_commands_contexts__: Optional[AppCommandContext] = MISSING + __discord_app_commands_installation_types__: Optional[AppInstallationType] = MISSING + __discord_app_commands_default_permissions__: Optional[Permissions] = MISSING + __discord_app_commands_has_module__: bool = False + __discord_app_commands_error_handler__: Optional[Callable[[Interaction, AppCommandError], Coroutine[Any, Any, None]]] = ( + None + ) + + def __init_subclass__( + cls, + *, + name: Union[str, locale_str] = MISSING, + description: Union[str, locale_str] = MISSING, + guild_only: bool = MISSING, + nsfw: bool = False, + default_permissions: Optional[Permissions] = MISSING, + ) -> None: + if not cls.__discord_app_commands_group_children__: + children: List[Union[Command[Any, ..., Any], Group]] = [ + member for member in cls.__dict__.values() if isinstance(member, (Group, Command)) and member.parent is None + ] + + cls.__discord_app_commands_group_children__ = children + + found = set() + for child in children: + if child.name in found: + raise TypeError(f'Command {child.name!r} is a duplicate') + found.add(child.name) + + if len(children) > 25: + raise TypeError('groups cannot have more than 25 commands') + + if name is MISSING: + cls.__discord_app_commands_group_name__ = validate_name(_to_kebab_case(cls.__name__)) + elif isinstance(name, str): + cls.__discord_app_commands_group_name__ = validate_name(name) + else: + cls.__discord_app_commands_group_name__ = validate_name(name.message) + cls.__discord_app_commands_group_locale_name__ = name + + if description is MISSING: + if cls.__doc__ is None: + cls.__discord_app_commands_group_description__ = '…' + else: + cls.__discord_app_commands_group_description__ = _shorten(cls.__doc__) + elif isinstance(description, str): + cls.__discord_app_commands_group_description__ = description + else: + cls.__discord_app_commands_group_description__ = description.message + cls.__discord_app_commands_group_locale_description__ = description + + if guild_only is not MISSING: + cls.__discord_app_commands_guild_only__ = guild_only + + if default_permissions is not MISSING: + cls.__discord_app_commands_default_permissions__ = default_permissions + + if cls.__module__ != __name__: + cls.__discord_app_commands_has_module__ = True + cls.__discord_app_commands_group_nsfw__ = nsfw + + def __init__( + self, + *, + name: Union[str, locale_str] = MISSING, + description: Union[str, locale_str] = MISSING, + parent: Optional[Group] = None, + guild_ids: Optional[List[int]] = None, + guild_only: bool = MISSING, + allowed_contexts: Optional[AppCommandContext] = MISSING, + allowed_installs: Optional[AppInstallationType] = MISSING, + nsfw: bool = MISSING, + auto_locale_strings: bool = True, + default_permissions: Optional[Permissions] = MISSING, + extras: Dict[Any, Any] = MISSING, + ): + cls = self.__class__ + + if name is MISSING: + name, locale = cls.__discord_app_commands_group_name__, cls.__discord_app_commands_group_locale_name__ + elif isinstance(name, str): + name, locale = validate_name(name), None + else: + name, locale = validate_name(name.message), name + self.name: str = name + self._locale_name: Optional[locale_str] = locale + + if description is MISSING: + description, locale = ( + cls.__discord_app_commands_group_description__, + cls.__discord_app_commands_group_locale_description__, + ) + elif isinstance(description, str): + description, locale = description, None + else: + description, locale = description.message, description + self.description: str = description + self._locale_description: Optional[locale_str] = locale + + self._attr: Optional[str] = None + self._owner_cls: Optional[Type[Any]] = None + self._guild_ids: Optional[List[int]] = guild_ids + if self._guild_ids is None: + self._guild_ids = getattr(cls, '__discord_app_commands_default_guilds__', None) + + if default_permissions is MISSING: + if cls.__discord_app_commands_default_permissions__ is MISSING: + default_permissions = None + else: + default_permissions = cls.__discord_app_commands_default_permissions__ + + self.default_permissions: Optional[Permissions] = default_permissions + + if guild_only is MISSING: + if cls.__discord_app_commands_guild_only__ is MISSING: + guild_only = False + else: + guild_only = cls.__discord_app_commands_guild_only__ + + self.guild_only: bool = guild_only + + if allowed_contexts is MISSING: + if cls.__discord_app_commands_contexts__ is MISSING: + allowed_contexts = None + else: + allowed_contexts = cls.__discord_app_commands_contexts__ + + self.allowed_contexts: Optional[AppCommandContext] = allowed_contexts + + if allowed_installs is MISSING: + if cls.__discord_app_commands_installation_types__ is MISSING: + allowed_installs = None + else: + allowed_installs = cls.__discord_app_commands_installation_types__ + + self.allowed_installs: Optional[AppInstallationType] = allowed_installs + + if nsfw is MISSING: + nsfw = cls.__discord_app_commands_group_nsfw__ + + self.nsfw: bool = nsfw + + if not self.description: + raise TypeError('groups must have a description') + + if not self.name: + raise TypeError('groups must have a name') + + self.parent: Optional[Group] = parent + self.module: Optional[str] + if cls.__discord_app_commands_has_module__: + self.module = cls.__module__ + else: + try: + # This is pretty hacky + # It allows the module to be fetched if someone just constructs a bare Group object though. + self.module = inspect.currentframe().f_back.f_globals['__name__'] # type: ignore + except (AttributeError, IndexError, KeyError): + self.module = None + + self._children: Dict[str, Union[Command, Group]] = {} + self.extras: Dict[Any, Any] = extras or {} + + bindings: Dict[Group, Group] = {} + + for child in self.__discord_app_commands_group_children__: + # commands and groups created directly in this class (no parent) + copy = ( + child._copy_with(parent=self, binding=self, bindings=bindings, set_on_binding=False) + if not cls.__discord_app_commands_skip_init_binding__ + else child + ) + + self._children[copy.name] = copy + if copy._attr and not cls.__discord_app_commands_skip_init_binding__: + setattr(self, copy._attr, copy) + + if parent is not None: + if parent.parent is not None: + raise ValueError('groups can only be nested at most one level') + parent.add_command(self) + + if auto_locale_strings: + self._convert_to_locale_strings() + + def _convert_to_locale_strings(self) -> None: + if self._locale_name is None: + self._locale_name = locale_str(self.name) + if self._locale_description is None: + self._locale_description = locale_str(self.description) + + # I don't know if propagating to the children is the right behaviour here. + + def __set_name__(self, owner: Type[Any], name: str) -> None: + self._attr = name + self.module = owner.__module__ + self._owner_cls = owner + + def _copy_with( + self, + *, + parent: Optional[Group], + binding: Binding, + bindings: MutableMapping[Group, Group] = MISSING, + set_on_binding: bool = True, + ) -> Group: + bindings = {} if bindings is MISSING else bindings + + copy = shallow_copy(self) + copy.parent = parent + copy._children = {} + + bindings[self] = copy + + for child in self._children.values(): + child_copy = child._copy_with(parent=copy, binding=binding, bindings=bindings) + child_copy.parent = copy + copy._children[child_copy.name] = child_copy + + if isinstance(child_copy, Group) and child_copy._attr and set_on_binding: + if binding.__class__ is child_copy._owner_cls: + setattr(binding, child_copy._attr, child_copy) + elif child_copy._owner_cls is copy.__class__: + setattr(copy, child_copy._attr, child_copy) + + if copy._attr and set_on_binding: + setattr(parent or binding, copy._attr, copy) + + return copy + + async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]: + base = self.to_dict(tree) + name_localizations: Dict[str, str] = {} + description_localizations: Dict[str, str] = {} + + # Prevent creating these objects in a heavy loop + name_context = TranslationContext(location=TranslationContextLocation.group_name, data=self) + description_context = TranslationContext(location=TranslationContextLocation.group_description, data=self) + for locale in Locale: + if self._locale_name: + translation = await translator._checked_translate(self._locale_name, locale, name_context) + if translation is not None: + name_localizations[locale.value] = translation + + if self._locale_description: + translation = await translator._checked_translate(self._locale_description, locale, description_context) + if translation is not None: + description_localizations[locale.value] = translation + + base['name_localizations'] = name_localizations + base['description_localizations'] = description_localizations + base['options'] = [await child.get_translated_payload(tree, translator) for child in self._children.values()] + return base + + def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]: + # If this has a parent command then it's part of a subcommand group + # Otherwise, it's just a regular command + option_type = 1 if self.parent is None else AppCommandOptionType.subcommand_group.value + base: Dict[str, Any] = { + 'name': self.name, + 'description': self.description, + 'type': option_type, + 'options': [child.to_dict(tree) for child in self._children.values()], + } + + if self.parent is None: + base['nsfw'] = self.nsfw + base['dm_permission'] = not self.guild_only + base['default_member_permissions'] = None if self.default_permissions is None else self.default_permissions.value + base['contexts'] = tree.allowed_contexts._merge_to_array(self.allowed_contexts) + base['integration_types'] = tree.allowed_installs._merge_to_array(self.allowed_installs) + + return base + + @property + def root_parent(self) -> Optional[Group]: + """Optional[:class:`Group`]: The parent of this group.""" + return self.parent + + @property + def qualified_name(self) -> str: + """:class:`str`: Returns the fully qualified group name. + + The qualified name includes the parent name as well. For example, + in a group like ``/foo bar`` the qualified name is ``foo bar``. + """ + + if self.parent is None: + return self.name + return f'{self.parent.name} {self.name}' + + def _get_internal_command(self, name: str) -> Optional[Union[Command[Any, ..., Any], Group]]: + return self._children.get(name) + + @property + def commands(self) -> List[Union[Command[Any, ..., Any], Group]]: + """List[Union[:class:`Command`, :class:`Group`]]: The commands that this group contains.""" + return list(self._children.values()) + + def walk_commands(self) -> Generator[Union[Command[Any, ..., Any], Group], None, None]: + """An iterator that recursively walks through all commands that this group contains. + + Yields + --------- + Union[:class:`Command`, :class:`Group`] + The commands in this group. + """ + + for command in self._children.values(): + yield command + if isinstance(command, Group): + yield from command.walk_commands() + + @mark_overrideable + async def on_error(self, interaction: Interaction, error: AppCommandError, /) -> None: + """|coro| + + A callback that is called when a child's command raises an :exc:`AppCommandError`. + + To get the command that failed, :attr:`discord.Interaction.command` should be used. + + The default implementation does nothing. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction that is being handled. + error: :exc:`AppCommandError` + The exception that was raised. + """ + + pass + + def error(self, coro: ErrorFunc) -> ErrorFunc: + """A decorator that registers a coroutine as a local error handler. + + The local error handler is called whenever an exception is raised in a child command. + The error handler must take 2 parameters, the interaction and the error. + + The error passed will be derived from :exc:`AppCommandError`. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine, or is an invalid coroutine. + """ + + if not inspect.iscoroutinefunction(coro): + raise TypeError('The error handler must be a coroutine.') + + params = inspect.signature(coro).parameters + if len(params) != 2: + raise TypeError('The error handler must have 2 parameters.') + + self.on_error = coro # type: ignore + return coro + + async def interaction_check(self, interaction: Interaction, /) -> bool: + """|coro| + + A callback that is called when an interaction happens within the group + that checks whether a command inside the group should be executed. + + This is useful to override if, for example, you want to ensure that the + interaction author is a given user. + + The default implementation of this returns ``True``. + + .. note:: + + If an exception occurs within the body then the check + is considered a failure and error handlers such as + :meth:`on_error` is called. See :exc:`AppCommandError` + for more information. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction that occurred. + + Returns + --------- + :class:`bool` + Whether the view children's callbacks should be called. + """ + + return True + + def add_command(self, command: Union[Command[Any, ..., Any], Group], /, *, override: bool = False) -> None: + """Adds a command or group to this group's internal list of commands. + + Parameters + ----------- + command: Union[:class:`Command`, :class:`Group`] + The command or group to add. + override: :class:`bool` + Whether to override a pre-existing command or group with the same name. + If ``False`` then an exception is raised. + + Raises + ------- + CommandAlreadyRegistered + The command or group is already registered. Note that the :attr:`CommandAlreadyRegistered.guild_id` + attribute will always be ``None`` in this case. + ValueError + There are too many commands already registered or the group is too + deeply nested. + TypeError + The wrong command type was passed. + """ + + if not isinstance(command, (Command, Group)): + raise TypeError(f'expected Command or Group not {command.__class__.__name__}') + + if isinstance(command, Group) and self.parent is not None: + # In a tree like so: + # + # + # + # this needs to be forbidden + raise ValueError(f'{command.name!r} is too nested, groups can only be nested at most one level') + + if not override and command.name in self._children: + raise CommandAlreadyRegistered(command.name, guild_id=None) + + self._children[command.name] = command + command.parent = self + if len(self._children) > 25: + raise ValueError('maximum number of child commands exceeded') + + def remove_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]: + """Removes a command or group from the internal list of commands. + + Parameters + ----------- + name: :class:`str` + The name of the command or group to remove. + + Returns + -------- + Optional[Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.Group`]] + The command that was removed. If nothing was removed + then ``None`` is returned instead. + """ + + self._children.pop(name, None) + + def get_command(self, name: str, /) -> Optional[Union[Command[Any, ..., Any], Group]]: + """Retrieves a command or group from its name. + + Parameters + ----------- + name: :class:`str` + The name of the command or group to retrieve. + + Returns + -------- + Optional[Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.Group`]] + The command or group that was retrieved. If nothing was found + then ``None`` is returned instead. + """ + return self._children.get(name) + + def command( + self, + *, + name: Union[str, locale_str] = MISSING, + description: Union[str, locale_str] = MISSING, + nsfw: bool = False, + auto_locale_strings: bool = True, + extras: Dict[Any, Any] = MISSING, + ) -> Callable[[CommandCallback[GroupT, P, T]], Command[GroupT, P, T]]: + """A decorator that creates an application command from a regular function under this group. + + Parameters + ------------ + name: Union[:class:`str`, :class:`locale_str`] + The name of the application command. If not given, it defaults to a lower-case + version of the callback name. + description: Union[:class:`str`, :class:`locale_str`] + The description of the application command. This shows up in the UI to describe + the application command. If not given, it defaults to the first line of the docstring + of the callback shortened to 100 characters. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. Defaults to ``False``. + auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`locale_str` rather than :class:`str`. This could + avoid some repetition and be more ergonomic for certain defaults such + as default command names, command descriptions, and parameter names. + Defaults to ``True``. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + """ + + def decorator(func: CommandCallback[GroupT, P, T]) -> Command[GroupT, P, T]: + if not inspect.iscoroutinefunction(func): + raise TypeError('command function must be a coroutine function') + + if description is MISSING: + if func.__doc__ is None: + desc = '…' + else: + desc = _shorten(func.__doc__) + else: + desc = description + + command = Command( + name=name if name is not MISSING else func.__name__, + description=desc, + callback=func, + nsfw=nsfw, + parent=self, + auto_locale_strings=auto_locale_strings, + extras=extras, + ) + self.add_command(command) + return command + + return decorator + + +def command( + *, + name: Union[str, locale_str] = MISSING, + description: Union[str, locale_str] = MISSING, + nsfw: bool = False, + auto_locale_strings: bool = True, + extras: Dict[Any, Any] = MISSING, +) -> Callable[[CommandCallback[GroupT, P, T]], Command[GroupT, P, T]]: + """Creates an application command from a regular function. + + Parameters + ------------ + name: :class:`str` + The name of the application command. If not given, it defaults to a lower-case + version of the callback name. + description: :class:`str` + The description of the application command. This shows up in the UI to describe + the application command. If not given, it defaults to the first line of the docstring + of the callback shortened to 100 characters. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. Defaults to ``False``. + + Due to a Discord limitation, this does not work on subcommands. + auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`locale_str` rather than :class:`str`. This could + avoid some repetition and be more ergonomic for certain defaults such + as default command names, command descriptions, and parameter names. + Defaults to ``True``. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + """ + + def decorator(func: CommandCallback[GroupT, P, T]) -> Command[GroupT, P, T]: + if not inspect.iscoroutinefunction(func): + raise TypeError('command function must be a coroutine function') + + if description is MISSING: + if func.__doc__ is None: + desc = '…' + else: + desc = _shorten(func.__doc__) + else: + desc = description + + return Command( + name=name if name is not MISSING else func.__name__, + description=desc, + callback=func, + parent=None, + nsfw=nsfw, + auto_locale_strings=auto_locale_strings, + extras=extras, + ) + + return decorator + + +def context_menu( + *, + name: Union[str, locale_str] = MISSING, + nsfw: bool = False, + auto_locale_strings: bool = True, + extras: Dict[Any, Any] = MISSING, +) -> Callable[[ContextMenuCallback], ContextMenu]: + """Creates an application command context menu from a regular function. + + This function must have a signature of :class:`~discord.Interaction` as its first parameter + and taking either a :class:`~discord.Member`, :class:`~discord.User`, or :class:`~discord.Message`, + or a :obj:`typing.Union` of ``Member`` and ``User`` as its second parameter. + + Examples + --------- + + .. code-block:: python3 + + @app_commands.context_menu() + async def react(interaction: discord.Interaction, message: discord.Message): + await interaction.response.send_message('Very cool message!', ephemeral=True) + + @app_commands.context_menu() + async def ban(interaction: discord.Interaction, user: discord.Member): + await interaction.response.send_message(f'Should I actually ban {user}...', ephemeral=True) + + Parameters + ------------ + name: Union[:class:`str`, :class:`locale_str`] + The name of the context menu command. If not given, it defaults to a title-case + version of the callback name. Note that unlike regular slash commands this can + have spaces and upper case characters in the name. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. Defaults to ``False``. + + Due to a Discord limitation, this does not work on subcommands. + auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`locale_str` rather than :class:`str`. This could + avoid some repetition and be more ergonomic for certain defaults such + as default command names, command descriptions, and parameter names. + Defaults to ``True``. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + """ + + def decorator(func: ContextMenuCallback) -> ContextMenu: + if not inspect.iscoroutinefunction(func): + raise TypeError('context menu function must be a coroutine function') + + actual_name = func.__name__.title() if name is MISSING else name + return ContextMenu( + name=actual_name, + nsfw=nsfw, + callback=func, + auto_locale_strings=auto_locale_strings, + extras=extras, + ) + + return decorator + + +def describe(**parameters: Union[str, locale_str]) -> Callable[[T], T]: + r'''Describes the given parameters by their name using the key of the keyword argument + as the name. + + Example: + + .. code-block:: python3 + + @app_commands.command(description='Bans a member') + @app_commands.describe(member='the member to ban') + async def ban(interaction: discord.Interaction, member: discord.Member): + await interaction.response.send_message(f'Banned {member}') + + Alternatively, you can describe parameters using Google, Sphinx, or Numpy style docstrings. + + Example: + + .. code-block:: python3 + + @app_commands.command() + async def ban(interaction: discord.Interaction, member: discord.Member): + """Bans a member + + Parameters + ----------- + member: discord.Member + the member to ban + """ + await interaction.response.send_message(f'Banned {member}') + + Parameters + ----------- + \*\*parameters: Union[:class:`str`, :class:`locale_str`] + The description of the parameters. + + Raises + -------- + TypeError + The parameter name is not found. + ''' + + def decorator(inner: T) -> T: + if isinstance(inner, Command): + _populate_descriptions(inner._params, parameters) + else: + try: + inner.__discord_app_commands_param_description__.update(parameters) # type: ignore # Runtime attribute access + except AttributeError: + inner.__discord_app_commands_param_description__ = parameters # type: ignore # Runtime attribute assignment + + return inner + + return decorator + + +def rename(**parameters: Union[str, locale_str]) -> Callable[[T], T]: + r"""Renames the given parameters by their name using the key of the keyword argument + as the name. + + This renames the parameter within the Discord UI. When referring to the parameter in other + decorators, the parameter name used in the function is used instead of the renamed one. + + Example: + + .. code-block:: python3 + + @app_commands.command() + @app_commands.rename(the_member_to_ban='member') + async def ban(interaction: discord.Interaction, the_member_to_ban: discord.Member): + await interaction.response.send_message(f'Banned {the_member_to_ban}') + + Parameters + ----------- + \*\*parameters: Union[:class:`str`, :class:`locale_str`] + The name of the parameters. + + Raises + -------- + ValueError + The parameter name is already used by another parameter. + TypeError + The parameter name is not found. + """ + + def decorator(inner: T) -> T: + if isinstance(inner, Command): + _populate_renames(inner._params, parameters) + else: + try: + inner.__discord_app_commands_param_rename__.update(parameters) # type: ignore # Runtime attribute access + except AttributeError: + inner.__discord_app_commands_param_rename__ = parameters # type: ignore # Runtime attribute assignment + + return inner + + return decorator + + +def choices(**parameters: List[Choice[ChoiceT]]) -> Callable[[T], T]: + r"""Instructs the given parameters by their name to use the given choices for their choices. + + Example: + + .. code-block:: python3 + + @app_commands.command() + @app_commands.describe(fruits='fruits to choose from') + @app_commands.choices(fruits=[ + Choice(name='apple', value=1), + Choice(name='banana', value=2), + Choice(name='cherry', value=3), + ]) + async def fruit(interaction: discord.Interaction, fruits: Choice[int]): + await interaction.response.send_message(f'Your favourite fruit is {fruits.name}.') + + .. note:: + + This is not the only way to provide choices to a command. There are two more ergonomic ways + of doing this. The first one is to use a :obj:`typing.Literal` annotation: + + .. code-block:: python3 + + @app_commands.command() + @app_commands.describe(fruits='fruits to choose from') + async def fruit(interaction: discord.Interaction, fruits: Literal['apple', 'banana', 'cherry']): + await interaction.response.send_message(f'Your favourite fruit is {fruits}.') + + The second way is to use an :class:`enum.Enum`: + + .. code-block:: python3 + + class Fruits(enum.Enum): + apple = 1 + banana = 2 + cherry = 3 + + @app_commands.command() + @app_commands.describe(fruits='fruits to choose from') + async def fruit(interaction: discord.Interaction, fruits: Fruits): + await interaction.response.send_message(f'Your favourite fruit is {fruits}.') + + + Parameters + ----------- + \*\*parameters + The choices of the parameters. + + Raises + -------- + TypeError + The parameter name is not found or the parameter type was incorrect. + """ + + def decorator(inner: T) -> T: + if isinstance(inner, Command): + _populate_choices(inner._params, parameters) + else: + try: + inner.__discord_app_commands_param_choices__.update(parameters) # type: ignore # Runtime attribute access + except AttributeError: + inner.__discord_app_commands_param_choices__ = parameters # type: ignore # Runtime attribute assignment + + return inner + + return decorator + + +def autocomplete(**parameters: AutocompleteCallback[GroupT, ChoiceT]) -> Callable[[T], T]: + r"""Associates the given parameters with the given autocomplete callback. + + Autocomplete is only supported on types that have :class:`str`, :class:`int`, or :class:`float` + values. + + :func:`Checks ` are supported, however they must be attached to the autocomplete + callback in order to work. Checks attached to the command are ignored when invoking the autocomplete + callback. + + For more information, see the :meth:`Command.autocomplete` documentation. + + .. warning:: + The choices returned from this coroutine are suggestions. The user may ignore them and input their own value. + + Example: + + .. code-block:: python3 + + async def fruit_autocomplete( + interaction: discord.Interaction, + current: str, + ) -> List[app_commands.Choice[str]]: + fruits = ['Banana', 'Pineapple', 'Apple', 'Watermelon', 'Melon', 'Cherry'] + return [ + app_commands.Choice(name=fruit, value=fruit) + for fruit in fruits if current.lower() in fruit.lower() + ] + + @app_commands.command() + @app_commands.autocomplete(fruit=fruit_autocomplete) + async def fruits(interaction: discord.Interaction, fruit: str): + await interaction.response.send_message(f'Your favourite fruit seems to be {fruit}') + + Parameters + ----------- + \*\*parameters + The parameters to mark as autocomplete. + + Raises + -------- + TypeError + The parameter name is not found or the parameter type was incorrect. + """ + + def decorator(inner: T) -> T: + if isinstance(inner, Command): + _populate_autocomplete(inner._params, parameters) + else: + try: + inner.__discord_app_commands_param_autocomplete__.update(parameters) # type: ignore # Runtime attribute access + except AttributeError: + inner.__discord_app_commands_param_autocomplete__ = parameters # type: ignore # Runtime attribute assignment + + return inner + + return decorator + + +def guilds(*guild_ids: Union[Snowflake, int]) -> Callable[[T], T]: + r"""Associates the given guilds with the command. + + When the command instance is added to a :class:`CommandTree`, the guilds that are + specified by this decorator become the default guilds that it's added to rather + than being a global command. + + If no arguments are given, then the command will not be synced anywhere. This may + be modified later using the :meth:`CommandTree.add_command` method. + + .. note:: + + Due to an implementation quirk and Python limitation, if this is used in conjunction + with the :meth:`CommandTree.command` or :meth:`CommandTree.context_menu` decorator + then this must go below that decorator. + + .. note :: + + Due to a Discord limitation, this decorator cannot be used in conjunction with + contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types + (e.g. :func:`.app_commands.allowed_installs`). + + Example: + + .. code-block:: python3 + + MY_GUILD_ID = discord.Object(...) # Guild ID here + + @app_commands.command() + @app_commands.guilds(MY_GUILD_ID) + async def bonk(interaction: discord.Interaction): + await interaction.response.send_message('Bonk', ephemeral=True) + + Parameters + ----------- + \*guild_ids: Union[:class:`int`, :class:`~discord.abc.Snowflake`] + The guilds to associate this command with. The command tree will + use this as the default when added rather than adding it as a global + command. + """ + + defaults: List[int] = [g if isinstance(g, int) else g.id for g in guild_ids] + + def decorator(inner: T) -> T: + if isinstance(inner, (Group, ContextMenu)): + inner._guild_ids = defaults + elif isinstance(inner, Command): + if inner.parent is not None: + raise ValueError('child commands of a group cannot have default guilds set') + + inner._guild_ids = defaults + else: + # Runtime attribute assignment + inner.__discord_app_commands_default_guilds__ = defaults # type: ignore + + return inner + + return decorator + + +def check(predicate: Check) -> Callable[[T], T]: + r"""A decorator that adds a check to an application command. + + These checks should be predicates that take in a single parameter taking + a :class:`~discord.Interaction`. If the check returns a ``False``\-like value then + during invocation a :exc:`CheckFailure` exception is raised and sent to + the appropriate error handlers. + + These checks can be either a coroutine or not. + + Examples + --------- + + Creating a basic check to see if the command invoker is you. + + .. code-block:: python3 + + def check_if_it_is_me(interaction: discord.Interaction) -> bool: + return interaction.user.id == 85309593344815104 + + @tree.command() + @app_commands.check(check_if_it_is_me) + async def only_for_me(interaction: discord.Interaction): + await interaction.response.send_message('I know you!', ephemeral=True) + + Transforming common checks into its own decorator: + + .. code-block:: python3 + + def is_me(): + def predicate(interaction: discord.Interaction) -> bool: + return interaction.user.id == 85309593344815104 + return app_commands.check(predicate) + + @tree.command() + @is_me() + async def only_me(interaction: discord.Interaction): + await interaction.response.send_message('Only you!') + + Parameters + ----------- + predicate: Callable[[:class:`~discord.Interaction`], :class:`bool`] + The predicate to check if the command should be invoked. + """ + + def decorator(func: CheckInputParameter) -> CheckInputParameter: + if isinstance(func, (Command, ContextMenu)): + func.checks.append(predicate) + else: + if not hasattr(func, '__discord_app_commands_checks__'): + func.__discord_app_commands_checks__ = [] + + func.__discord_app_commands_checks__.append(predicate) + + return func + + return decorator # type: ignore + + +@overload +def guild_only(func: None = ...) -> Callable[[T], T]: ... + + +@overload +def guild_only(func: T) -> T: ... + + +def guild_only(func: Optional[T] = None) -> Union[T, Callable[[T], T]]: + """A decorator that indicates this command can only be used in a guild context. + + This is **not** implemented as a :func:`check`, and is instead verified by Discord server side. + Therefore, there is no error handler called when a command is used within a private message. + + This decorator can be called with or without parentheses. + + Due to a Discord limitation, this decorator does nothing in subcommands and is ignored. + + Examples + --------- + + .. code-block:: python3 + + @app_commands.command() + @app_commands.guild_only() + async def my_guild_only_command(interaction: discord.Interaction) -> None: + await interaction.response.send_message('I am only available in guilds!') + """ + + def inner(f: T) -> T: + if isinstance(f, (Command, Group, ContextMenu)): + f.guild_only = True + allowed_contexts = f.allowed_contexts or AppCommandContext() + f.allowed_contexts = allowed_contexts + else: + f.__discord_app_commands_guild_only__ = True # type: ignore # Runtime attribute assignment + + allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext() + f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment + + allowed_contexts.guild = True + + return f + + # Check if called with parentheses or not + if func is None: + # Called with parentheses + return inner + else: + return inner(func) + + +@overload +def private_channel_only(func: None = ...) -> Callable[[T], T]: ... + + +@overload +def private_channel_only(func: T) -> T: ... + + +def private_channel_only(func: Optional[T] = None) -> Union[T, Callable[[T], T]]: + """A decorator that indicates this command can only be used in the context of DMs and group DMs. + + This is **not** implemented as a :func:`check`, and is instead verified by Discord server side. + Therefore, there is no error handler called when a command is used within a guild. + + This decorator can be called with or without parentheses. + + Due to a Discord limitation, this decorator does nothing in subcommands and is ignored. + + .. versionadded:: 2.4 + + Examples + --------- + + .. code-block:: python3 + + @app_commands.command() + @app_commands.private_channel_only() + async def my_private_channel_only_command(interaction: discord.Interaction) -> None: + await interaction.response.send_message('I am only available in DMs and GDMs!') + """ + + def inner(f: T) -> T: + if isinstance(f, (Command, Group, ContextMenu)): + f.guild_only = False + allowed_contexts = f.allowed_contexts or AppCommandContext() + f.allowed_contexts = allowed_contexts + else: + allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext() + f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment + + allowed_contexts.private_channel = True + + return f + + # Check if called with parentheses or not + if func is None: + # Called with parentheses + return inner + else: + return inner(func) + + +@overload +def dm_only(func: None = ...) -> Callable[[T], T]: ... + + +@overload +def dm_only(func: T) -> T: ... + + +def dm_only(func: Optional[T] = None) -> Union[T, Callable[[T], T]]: + """A decorator that indicates this command can only be used in the context of bot DMs. + + This is **not** implemented as a :func:`check`, and is instead verified by Discord server side. + Therefore, there is no error handler called when a command is used within a guild or group DM. + + This decorator can be called with or without parentheses. + + Due to a Discord limitation, this decorator does nothing in subcommands and is ignored. + + Examples + --------- + + .. code-block:: python3 + + @app_commands.command() + @app_commands.dm_only() + async def my_dm_only_command(interaction: discord.Interaction) -> None: + await interaction.response.send_message('I am only available in DMs!') + """ + + def inner(f: T) -> T: + if isinstance(f, (Command, Group, ContextMenu)): + f.guild_only = False + allowed_contexts = f.allowed_contexts or AppCommandContext() + f.allowed_contexts = allowed_contexts + else: + allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext() + f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment + + allowed_contexts.dm_channel = True + return f + + # Check if called with parentheses or not + if func is None: + # Called with parentheses + return inner + else: + return inner(func) + + +def allowed_contexts(guilds: bool = MISSING, dms: bool = MISSING, private_channels: bool = MISSING) -> Callable[[T], T]: + """A decorator that indicates this command can only be used in certain contexts. + Valid contexts are guilds, DMs and private channels. + + This is **not** implemented as a :func:`check`, and is instead verified by Discord server side. + + Due to a Discord limitation, this decorator does nothing in subcommands and is ignored. + + .. versionadded:: 2.4 + + Examples + --------- + + .. code-block:: python3 + + @app_commands.command() + @app_commands.allowed_contexts(guilds=True, dms=False, private_channels=True) + async def my_command(interaction: discord.Interaction) -> None: + await interaction.response.send_message('I am only available in guilds and private channels!') + """ + + def inner(f: T) -> T: + if isinstance(f, (Command, Group, ContextMenu)): + f.guild_only = False + allowed_contexts = f.allowed_contexts or AppCommandContext() + f.allowed_contexts = allowed_contexts + else: + allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext() + f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment + + if guilds is not MISSING: + allowed_contexts.guild = guilds + + if dms is not MISSING: + allowed_contexts.dm_channel = dms + + if private_channels is not MISSING: + allowed_contexts.private_channel = private_channels + + return f + + return inner + + +@overload +def guild_install(func: None = ...) -> Callable[[T], T]: ... + + +@overload +def guild_install(func: T) -> T: ... + + +def guild_install(func: Optional[T] = None) -> Union[T, Callable[[T], T]]: + """A decorator that indicates this command should be installed in guilds. + + This is **not** implemented as a :func:`check`, and is instead verified by Discord server side. + + Due to a Discord limitation, this decorator does nothing in subcommands and is ignored. + + .. versionadded:: 2.4 + + Examples + --------- + + .. code-block:: python3 + + @app_commands.command() + @app_commands.guild_install() + async def my_guild_install_command(interaction: discord.Interaction) -> None: + await interaction.response.send_message('I am installed in guilds by default!') + """ + + def inner(f: T) -> T: + if isinstance(f, (Command, Group, ContextMenu)): + allowed_installs = f.allowed_installs or AppInstallationType() + f.allowed_installs = allowed_installs + else: + allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType() + f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment + + allowed_installs.guild = True + + return f + + # Check if called with parentheses or not + if func is None: + # Called with parentheses + return inner + else: + return inner(func) + + +@overload +def user_install(func: None = ...) -> Callable[[T], T]: ... + + +@overload +def user_install(func: T) -> T: ... + + +def user_install(func: Optional[T] = None) -> Union[T, Callable[[T], T]]: + """A decorator that indicates this command should be installed for users. + + This is **not** implemented as a :func:`check`, and is instead verified by Discord server side. + + Due to a Discord limitation, this decorator does nothing in subcommands and is ignored. + + .. versionadded:: 2.4 + + Examples + --------- + + .. code-block:: python3 + + @app_commands.command() + @app_commands.user_install() + async def my_user_install_command(interaction: discord.Interaction) -> None: + await interaction.response.send_message('I am installed in users by default!') + """ + + def inner(f: T) -> T: + if isinstance(f, (Command, Group, ContextMenu)): + allowed_installs = f.allowed_installs or AppInstallationType() + f.allowed_installs = allowed_installs + else: + allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType() + f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment + + allowed_installs.user = True + + return f + + # Check if called with parentheses or not + if func is None: + # Called with parentheses + return inner + else: + return inner(func) + + +def allowed_installs( + guilds: bool = MISSING, + users: bool = MISSING, +) -> Callable[[T], T]: + """A decorator that indicates this command should be installed in certain contexts. + Valid contexts are guilds and users. + + This is **not** implemented as a :func:`check`, and is instead verified by Discord server side. + + Due to a Discord limitation, this decorator does nothing in subcommands and is ignored. + + .. versionadded:: 2.4 + + Examples + --------- + + .. code-block:: python3 + + @app_commands.command() + @app_commands.allowed_installs(guilds=False, users=True) + async def my_command(interaction: discord.Interaction) -> None: + await interaction.response.send_message('I am installed in users by default!') + """ + + def inner(f: T) -> T: + if isinstance(f, (Command, Group, ContextMenu)): + allowed_installs = f.allowed_installs or AppInstallationType() + f.allowed_installs = allowed_installs + else: + allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType() + f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment + + if guilds is not MISSING: + allowed_installs.guild = guilds + + if users is not MISSING: + allowed_installs.user = users + + return f + + return inner + + +def default_permissions(perms_obj: Optional[Permissions] = None, /, **perms: Unpack[_PermissionsKwargs]) -> Callable[[T], T]: + r"""A decorator that sets the default permissions needed to execute this command. + + When this decorator is used, by default users must have these permissions to execute the command. + However, an administrator can change the permissions needed to execute this command using the official + client. Therefore, this only serves as a hint. + + Setting an empty permissions field, including via calling this with no arguments, will disallow anyone + except server administrators from using the command in a guild. + + This is sent to Discord server side, and is not a :func:`check`. Therefore, error handlers are not called. + + Due to a Discord limitation, this decorator does nothing in subcommands and is ignored. + + .. warning:: + + This serves as a *hint* and members are *not* required to have the permissions given to actually + execute this command. If you want to ensure that members have the permissions needed, consider using + :func:`~discord.app_commands.checks.has_permissions` instead. + + Parameters + ----------- + \*\*perms: :class:`bool` + Keyword arguments denoting the permissions to set as the default. + perms_obj: :class:`~discord.Permissions` + A permissions object as positional argument. This can be used in combination with ``**perms``. + + .. versionadded:: 2.5 + + Examples + --------- + + .. code-block:: python3 + + @app_commands.command() + @app_commands.default_permissions(manage_messages=True) + async def test(interaction: discord.Interaction): + await interaction.response.send_message('You may or may not have manage messages.') + + .. code-block:: python3 + + ADMIN_PERMS = discord.Permissions(administrator=True) + + @app_commands.command() + @app_commands.default_permissions(ADMIN_PERMS, manage_messages=True) + async def test(interaction: discord.Interaction): + await interaction.response.send_message('You may or may not have manage messages.') + """ + + if perms_obj is not None: + permissions = perms_obj | Permissions(**perms) + else: + permissions = Permissions(**perms) + + def decorator(func: T) -> T: + if isinstance(func, (Command, Group, ContextMenu)): + func.default_permissions = permissions + else: + func.__discord_app_commands_default_permissions__ = permissions # type: ignore # Runtime attribute assignment + + return func + + return decorator diff --git a/discord/app_commands/errors.py b/discord/app_commands/errors.py new file mode 100644 index 000000000000..2efb4e5b008b --- /dev/null +++ b/discord/app_commands/errors.py @@ -0,0 +1,519 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING, List, Optional, Sequence, Union + +from ..enums import AppCommandOptionType, AppCommandType, Locale +from ..errors import DiscordException, HTTPException, _flatten_error_dict, MissingApplicationID as MissingApplicationID +from ..utils import _human_join + +__all__ = ( + 'AppCommandError', + 'CommandInvokeError', + 'TransformerError', + 'TranslationError', + 'CheckFailure', + 'CommandAlreadyRegistered', + 'CommandSignatureMismatch', + 'CommandNotFound', + 'CommandLimitReached', + 'NoPrivateMessage', + 'MissingRole', + 'MissingAnyRole', + 'MissingPermissions', + 'BotMissingPermissions', + 'CommandOnCooldown', + 'MissingApplicationID', + 'CommandSyncFailure', +) + +if TYPE_CHECKING: + from .commands import Command, Group, ContextMenu, Parameter + from .transformers import Transformer + from .translator import TranslationContextTypes, locale_str + from ..types.snowflake import Snowflake, SnowflakeList + from .checks import Cooldown + + CommandTypes = Union[Command[Any, ..., Any], Group, ContextMenu] + + +class AppCommandError(DiscordException): + """The base exception type for all application command related errors. + + This inherits from :exc:`discord.DiscordException`. + + This exception and exceptions inherited from it are handled + in a special way as they are caught and passed into various error handlers + in this order: + + - :meth:`Command.error ` + - :meth:`Group.on_error ` + - :meth:`CommandTree.on_error ` + + .. versionadded:: 2.0 + """ + + pass + + +class CommandInvokeError(AppCommandError): + """An exception raised when the command being invoked raised an exception. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + .. versionadded:: 2.0 + + Attributes + ----------- + original: :exc:`Exception` + The original exception that was raised. You can also get this via + the ``__cause__`` attribute. + command: Union[:class:`Command`, :class:`ContextMenu`] + The command that failed. + """ + + def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu], e: Exception) -> None: + self.original: Exception = e + self.command: Union[Command[Any, ..., Any], ContextMenu] = command + super().__init__(f'Command {command.name!r} raised an exception: {e.__class__.__name__}: {e}') + + +class TransformerError(AppCommandError): + """An exception raised when a :class:`Transformer` or type annotation fails to + convert to its target type. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + If an exception occurs while converting that does not subclass + :exc:`AppCommandError` then the exception is wrapped into this exception. + The original exception can be retrieved using the ``__cause__`` attribute. + Otherwise if the exception derives from :exc:`AppCommandError` then it will + be propagated as-is. + + .. versionadded:: 2.0 + + Attributes + ----------- + value: Any + The value that failed to convert. + type: :class:`~discord.AppCommandOptionType` + The type of argument that failed to convert. + transformer: :class:`Transformer` + The transformer that failed the conversion. + """ + + def __init__(self, value: Any, opt_type: AppCommandOptionType, transformer: Transformer): + self.value: Any = value + self.type: AppCommandOptionType = opt_type + self.transformer: Transformer = transformer + + super().__init__(f'Failed to convert {value} to {transformer._error_display_name!s}') + + +class TranslationError(AppCommandError): + """An exception raised when the library fails to translate a string. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + If an exception occurs while calling :meth:`Translator.translate` that does + not subclass this then the exception is wrapped into this exception. + The original exception can be retrieved using the ``__cause__`` attribute. + Otherwise it will be propagated as-is. + + .. versionadded:: 2.0 + + Attributes + ----------- + string: Optional[Union[:class:`str`, :class:`locale_str`]] + The string that caused the error, if any. + locale: Optional[:class:`~discord.Locale`] + The locale that caused the error, if any. + context: :class:`~discord.app_commands.TranslationContext` + The context of the translation that triggered the error. + """ + + def __init__( + self, + *msg: str, + string: Optional[Union[str, locale_str]] = None, + locale: Optional[Locale] = None, + context: TranslationContextTypes, + ) -> None: + self.string: Optional[Union[str, locale_str]] = string + self.locale: Optional[Locale] = locale + self.context: TranslationContextTypes = context + + if msg: + super().__init__(*msg) + else: + ctx = context.location.name.replace('_', ' ') + fmt = f'Failed to translate {self.string!r} in a {ctx}' + if self.locale is not None: + fmt = f'{fmt} in the {self.locale.value} locale' + + super().__init__(fmt) + + +class CheckFailure(AppCommandError): + """An exception raised when check predicates in a command have failed. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + .. versionadded:: 2.0 + """ + + pass + + +class NoPrivateMessage(CheckFailure): + """An exception raised when a command does not work in a direct message. + + This inherits from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + """ + + def __init__(self, message: Optional[str] = None) -> None: + super().__init__(message or 'This command cannot be used in direct messages.') + + +class MissingRole(CheckFailure): + """An exception raised when the command invoker lacks a role to run a command. + + This inherits from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + + Attributes + ----------- + missing_role: Union[:class:`str`, :class:`int`] + The required role that is missing. + This is the parameter passed to :func:`~discord.app_commands.checks.has_role`. + """ + + def __init__(self, missing_role: Snowflake) -> None: + self.missing_role: Snowflake = missing_role + message = f'Role {missing_role!r} is required to run this command.' + super().__init__(message) + + +class MissingAnyRole(CheckFailure): + """An exception raised when the command invoker lacks any of the roles + specified to run a command. + + This inherits from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + + Attributes + ----------- + missing_roles: List[Union[:class:`str`, :class:`int`]] + The roles that the invoker is missing. + These are the parameters passed to :func:`~discord.app_commands.checks.has_any_role`. + """ + + def __init__(self, missing_roles: SnowflakeList) -> None: + self.missing_roles: SnowflakeList = missing_roles + + fmt = _human_join([f"'{role}'" for role in missing_roles]) + message = f'You are missing at least one of the required roles: {fmt}' + super().__init__(message) + + +class MissingPermissions(CheckFailure): + """An exception raised when the command invoker lacks permissions to run a + command. + + This inherits from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + + Attributes + ----------- + missing_permissions: List[:class:`str`] + The required permissions that are missing. + """ + + def __init__(self, missing_permissions: List[str], *args: Any) -> None: + self.missing_permissions: List[str] = missing_permissions + + missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] + fmt = _human_join(missing, final='and') + message = f'You are missing {fmt} permission(s) to run this command.' + super().__init__(message, *args) + + +class BotMissingPermissions(CheckFailure): + """An exception raised when the bot's member lacks permissions to run a + command. + + This inherits from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + + Attributes + ----------- + missing_permissions: List[:class:`str`] + The required permissions that are missing. + """ + + def __init__(self, missing_permissions: List[str], *args: Any) -> None: + self.missing_permissions: List[str] = missing_permissions + + missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] + fmt = _human_join(missing, final='and') + message = f'Bot requires {fmt} permission(s) to run this command.' + super().__init__(message, *args) + + +class CommandOnCooldown(CheckFailure): + """An exception raised when the command being invoked is on cooldown. + + This inherits from :exc:`~discord.app_commands.CheckFailure`. + + .. versionadded:: 2.0 + + Attributes + ----------- + cooldown: :class:`~discord.app_commands.Cooldown` + The cooldown that was triggered. + retry_after: :class:`float` + The amount of seconds to wait before you can retry again. + """ + + def __init__(self, cooldown: Cooldown, retry_after: float) -> None: + self.cooldown: Cooldown = cooldown + self.retry_after: float = retry_after + super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s') + + +class CommandAlreadyRegistered(AppCommandError): + """An exception raised when a command is already registered. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + .. versionadded:: 2.0 + + Attributes + ----------- + name: :class:`str` + The name of the command already registered. + guild_id: Optional[:class:`int`] + The guild ID this command was already registered at. + If ``None`` then it was a global command. + """ + + def __init__(self, name: str, guild_id: Optional[int]): + self.name: str = name + self.guild_id: Optional[int] = guild_id + super().__init__(f'Command {name!r} already registered.') + + +class CommandNotFound(AppCommandError): + """An exception raised when an application command could not be found. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + .. versionadded:: 2.0 + + Attributes + ------------ + name: :class:`str` + The name of the application command not found. + parents: List[:class:`str`] + A list of parent command names that were previously found + prior to the application command not being found. + type: :class:`~discord.AppCommandType` + The type of command that was not found. + """ + + def __init__(self, name: str, parents: List[str], type: AppCommandType = AppCommandType.chat_input): + self.name: str = name + self.parents: List[str] = parents + self.type: AppCommandType = type + super().__init__(f'Application command {name!r} not found') + + +class CommandLimitReached(AppCommandError): + """An exception raised when the maximum number of application commands was reached + either globally or in a guild. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + .. versionadded:: 2.0 + + Attributes + ------------ + type: :class:`~discord.AppCommandType` + The type of command that reached the limit. + guild_id: Optional[:class:`int`] + The guild ID that reached the limit or ``None`` if it was global. + limit: :class:`int` + The limit that was hit. + """ + + def __init__(self, guild_id: Optional[int], limit: int, type: AppCommandType = AppCommandType.chat_input): + self.guild_id: Optional[int] = guild_id + self.limit: int = limit + self.type: AppCommandType = type + + lookup = { + AppCommandType.chat_input: 'slash commands', + AppCommandType.message: 'message context menu commands', + AppCommandType.user: 'user context menu commands', + } + desc = lookup.get(type, 'application commands') + ns = 'globally' if self.guild_id is None else f'for guild ID {self.guild_id}' + super().__init__(f'maximum number of {desc} exceeded {limit} {ns}') + + +class CommandSignatureMismatch(AppCommandError): + """An exception raised when an application command from Discord has a different signature + from the one provided in the code. This happens because your command definition differs + from the command definition you provided Discord. Either your code is out of date or the + data from Discord is out of sync. + + This inherits from :exc:`~discord.app_commands.AppCommandError`. + + .. versionadded:: 2.0 + + Attributes + ------------ + command: Union[:class:`~.app_commands.Command`, :class:`~.app_commands.ContextMenu`, :class:`~.app_commands.Group`] + The command that had the signature mismatch. + """ + + def __init__(self, command: Union[Command[Any, ..., Any], ContextMenu, Group]): + self.command: Union[Command[Any, ..., Any], ContextMenu, Group] = command + msg = ( + f'The signature for command {command.name!r} is different from the one provided by Discord. ' + 'This can happen because either your code is out of date or you have not synced the ' + 'commands with Discord, causing the mismatch in data. It is recommended to sync the ' + 'command tree to fix this issue.' + ) + super().__init__(msg) + + +def _get_command_error( + index: str, + inner: Any, + objects: Sequence[Union[Parameter, CommandTypes]], + messages: List[str], + indent: int = 0, +) -> None: + # Import these here to avoid circular imports + from .commands import Command, Group, ContextMenu + + indentation = ' ' * indent + + # Top level errors are: + # : { : } + # The dicts could be nested, e.g. + # : { : { : } } + # Luckily, this is already handled by the flatten_error_dict utility + if not index.isdigit(): + errors = _flatten_error_dict(inner, index) + messages.extend(f'In {k}: {v}' for k, v in errors.items()) + return + + idx = int(index) + try: + obj = objects[idx] + except IndexError: + dedent_one_level = ' ' * (indent - 2) + errors = _flatten_error_dict(inner, index) + messages.extend(f'{dedent_one_level}In {k}: {v}' for k, v in errors.items()) + return + + children: Sequence[Union[Parameter, CommandTypes]] = [] + if isinstance(obj, Command): + messages.append(f'{indentation}In command {obj.qualified_name!r} defined in function {obj.callback.__qualname__!r}') + children = obj.parameters + elif isinstance(obj, Group): + messages.append(f'{indentation}In group {obj.qualified_name!r} defined in module {obj.module!r}') + children = obj.commands + elif isinstance(obj, ContextMenu): + messages.append( + f'{indentation}In context menu {obj.qualified_name!r} defined in function {obj.callback.__qualname__!r}' + ) + else: + messages.append(f'{indentation}In parameter {obj.name!r}') + + for key, remaining in inner.items(): + # Special case the 'options' key since they have well defined meanings + if key == 'options': + for index, d in remaining.items(): + _get_command_error(index, d, children, messages, indent=indent + 2) + elif key == '_errors': + errors = [x.get('message', '') for x in remaining] + + messages.extend(f'{indentation} {message}' for message in errors) + else: + if isinstance(remaining, dict): + try: + inner_errors = remaining['_errors'] + except KeyError: + errors = _flatten_error_dict(remaining, key=key) + else: + errors = {key: ' '.join(x.get('message', '') for x in inner_errors)} + + if isinstance(errors, dict): + messages.extend(f'{indentation} {k}: {v}' for k, v in errors.items()) + + +class CommandSyncFailure(AppCommandError, HTTPException): + """An exception raised when :meth:`CommandTree.sync` failed. + + This provides syncing failures in a slightly more readable format. + + This inherits from :exc:`~discord.app_commands.AppCommandError` + and :exc:`~discord.HTTPException`. + + .. versionadded:: 2.0 + """ + + def __init__(self, child: HTTPException, commands: List[CommandTypes]) -> None: + # Consume the child exception and make it seem as if we are that exception + self.__dict__.update(child.__dict__) + + messages = [f'Failed to upload commands to Discord (HTTP status {self.status}, error code {self.code})'] + + if self._errors: + # Handle case where the errors dict has no actual chain such as APPLICATION_COMMAND_TOO_LARGE + if len(self._errors) == 1 and '_errors' in self._errors: + errors = self._errors['_errors'] + if len(errors) == 1: + extra = errors[0].get('message') + if extra: + messages[0] += f': {extra}' + else: + messages.extend(f'Error {e.get("code", "")}: {e.get("message", "")}' for e in errors) + else: + for index, inner in self._errors.items(): + _get_command_error(index, inner, commands, messages) + + # Equivalent to super().__init__(...) but skips other constructors + self.args = ('\n'.join(messages),) diff --git a/discord/app_commands/installs.py b/discord/app_commands/installs.py new file mode 100644 index 000000000000..e00d13724031 --- /dev/null +++ b/discord/app_commands/installs.py @@ -0,0 +1,213 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +from typing import TYPE_CHECKING, ClassVar, List, Optional, Sequence + +__all__ = ( + 'AppInstallationType', + 'AppCommandContext', +) + +if TYPE_CHECKING: + from typing_extensions import Self + from ..types.interactions import InteractionContextType, InteractionInstallationType + + +class AppInstallationType: + r"""Represents the installation location of an application command. + + .. versionadded:: 2.4 + + Parameters + ----------- + guild: Optional[:class:`bool`] + Whether the integration is a guild install. + user: Optional[:class:`bool`] + Whether the integration is a user install. + """ + + __slots__ = ('_guild', '_user') + + GUILD: ClassVar[int] = 0 + USER: ClassVar[int] = 1 + + def __init__(self, *, guild: Optional[bool] = None, user: Optional[bool] = None): + self._guild: Optional[bool] = guild + self._user: Optional[bool] = user + + def __repr__(self): + return f'' + + @property + def guild(self) -> bool: + """:class:`bool`: Whether the integration is a guild install.""" + return bool(self._guild) + + @guild.setter + def guild(self, value: bool) -> None: + self._guild = bool(value) + + @property + def user(self) -> bool: + """:class:`bool`: Whether the integration is a user install.""" + return bool(self._user) + + @user.setter + def user(self, value: bool) -> None: + self._user = bool(value) + + def merge(self, other: AppInstallationType) -> AppInstallationType: + # Merging is similar to AllowedMentions where `self` is the base + # and the `other` is the override preference + guild = self._guild if other._guild is None else other._guild + user = self._user if other._user is None else other._user + return AppInstallationType(guild=guild, user=user) + + def _is_unset(self) -> bool: + return all(x is None for x in (self._guild, self._user)) + + def _merge_to_array(self, other: Optional[AppInstallationType]) -> Optional[List[InteractionInstallationType]]: + result = self.merge(other) if other is not None else self + if result._is_unset(): + return None + return result.to_array() + + @classmethod + def _from_value(cls, value: Sequence[InteractionInstallationType]) -> Self: + self = cls() + for x in value: + if x == cls.GUILD: + self._guild = True + elif x == cls.USER: + self._user = True + return self + + def to_array(self) -> List[InteractionInstallationType]: + values = [] + if self._guild: + values.append(self.GUILD) + if self._user: + values.append(self.USER) + return values + + +class AppCommandContext: + r"""Wraps up the Discord :class:`~discord.app_commands.Command` execution context. + + .. versionadded:: 2.4 + + Parameters + ----------- + guild: Optional[:class:`bool`] + Whether the context allows usage in a guild. + dm_channel: Optional[:class:`bool`] + Whether the context allows usage in a DM channel. + private_channel: Optional[:class:`bool`] + Whether the context allows usage in a DM or a GDM channel. + """ + + GUILD: ClassVar[int] = 0 + DM_CHANNEL: ClassVar[int] = 1 + PRIVATE_CHANNEL: ClassVar[int] = 2 + + __slots__ = ('_guild', '_dm_channel', '_private_channel') + + def __init__( + self, + *, + guild: Optional[bool] = None, + dm_channel: Optional[bool] = None, + private_channel: Optional[bool] = None, + ): + self._guild: Optional[bool] = guild + self._dm_channel: Optional[bool] = dm_channel + self._private_channel: Optional[bool] = private_channel + + def __repr__(self) -> str: + return f'' + + @property + def guild(self) -> bool: + """:class:`bool`: Whether the context allows usage in a guild.""" + return bool(self._guild) + + @guild.setter + def guild(self, value: bool) -> None: + self._guild = bool(value) + + @property + def dm_channel(self) -> bool: + """:class:`bool`: Whether the context allows usage in a DM channel.""" + return bool(self._dm_channel) + + @dm_channel.setter + def dm_channel(self, value: bool) -> None: + self._dm_channel = bool(value) + + @property + def private_channel(self) -> bool: + """:class:`bool`: Whether the context allows usage in a DM or a GDM channel.""" + return bool(self._private_channel) + + @private_channel.setter + def private_channel(self, value: bool) -> None: + self._private_channel = bool(value) + + def merge(self, other: AppCommandContext) -> AppCommandContext: + guild = self._guild if other._guild is None else other._guild + dm_channel = self._dm_channel if other._dm_channel is None else other._dm_channel + private_channel = self._private_channel if other._private_channel is None else other._private_channel + return AppCommandContext(guild=guild, dm_channel=dm_channel, private_channel=private_channel) + + def _is_unset(self) -> bool: + return all(x is None for x in (self._guild, self._dm_channel, self._private_channel)) + + def _merge_to_array(self, other: Optional[AppCommandContext]) -> Optional[List[InteractionContextType]]: + result = self.merge(other) if other is not None else self + if result._is_unset(): + return None + return result.to_array() + + @classmethod + def _from_value(cls, value: Sequence[InteractionContextType]) -> Self: + self = cls() + for x in value: + if x == cls.GUILD: + self._guild = True + elif x == cls.DM_CHANNEL: + self._dm_channel = True + elif x == cls.PRIVATE_CHANNEL: + self._private_channel = True + return self + + def to_array(self) -> List[InteractionContextType]: + values = [] + if self._guild: + values.append(self.GUILD) + if self._dm_channel: + values.append(self.DM_CHANNEL) + if self._private_channel: + values.append(self.PRIVATE_CHANNEL) + return values diff --git a/discord/app_commands/models.py b/discord/app_commands/models.py new file mode 100644 index 000000000000..b51339c2683d --- /dev/null +++ b/discord/app_commands/models.py @@ -0,0 +1,1293 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +from datetime import datetime + +from .errors import MissingApplicationID +from ..flags import AppCommandContext, AppInstallationType, ChannelFlags +from .translator import TranslationContextLocation, TranslationContext, locale_str, Translator +from ..permissions import Permissions +from ..enums import ( + AppCommandOptionType, + AppCommandType, + AppCommandPermissionType, + ChannelType, + Locale, + try_enum, +) +import array +from ..mixins import Hashable +from ..utils import _get_as_snowflake, parse_time, snowflake_time, MISSING +from ..object import Object +from ..role import Role +from ..member import Member + +from typing import Any, Dict, Generic, List, TYPE_CHECKING, Optional, TypeVar, Union + +__all__ = ( + 'AppCommand', + 'AppCommandGroup', + 'AppCommandChannel', + 'AppCommandThread', + 'AppCommandPermissions', + 'GuildAppCommandPermissions', + 'Argument', + 'Choice', + 'AllChannels', +) + +ChoiceT = TypeVar('ChoiceT', str, int, float, Union[str, int, float]) + + +def is_app_command_argument_type(value: int) -> bool: + return 11 >= value >= 3 + + +if TYPE_CHECKING: + from ..types.command import ( + ApplicationCommand as ApplicationCommandPayload, + ApplicationCommandOption, + ApplicationCommandOptionChoice, + ApplicationCommandPermissions, + GuildApplicationCommandPermissions, + ) + from ..types.interactions import ( + PartialChannel, + PartialThread, + ) + from ..types.threads import ( + ThreadMetadata, + ThreadArchiveDuration, + ) + + from ..abc import Snowflake + from ..state import ConnectionState + from ..guild import GuildChannel, Guild + from ..channel import TextChannel, ForumChannel, ForumTag + from ..threads import Thread + from ..user import User + + ApplicationCommandParent = Union['AppCommand', 'AppCommandGroup'] + + +class AllChannels: + """Represents all channels for application command permissions. + + .. versionadded:: 2.0 + + Attributes + ----------- + guild: :class:`~discord.Guild` + The guild the application command permission is for. + """ + + __slots__ = ('guild',) + + def __init__(self, guild: Guild): + self.guild: Guild = guild + + @property + def id(self) -> int: + """:class:`int`: The ID sentinel used to represent all channels. Equivalent to the guild's ID minus 1.""" + return self.guild.id - 1 + + def __repr__(self) -> str: + return f'' + + +def _to_locale_dict(data: Dict[str, str]) -> Dict[Locale, str]: + return {try_enum(Locale, key): value for key, value in data.items()} + + +class AppCommand(Hashable): + """Represents an application command. + + In common parlance this is referred to as a "Slash Command" or a + "Context Menu Command". + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two application commands are equal. + + .. describe:: x != y + + Checks if two application commands are not equal. + + .. describe:: hash(x) + + Returns the application command's hash. + + .. describe:: str(x) + + Returns the application command's name. + + Attributes + ----------- + id: :class:`int` + The application command's ID. + application_id: :class:`int` + The application command's application's ID. + type: :class:`~discord.AppCommandType` + The application command's type. + name: :class:`str` + The application command's name. + description: :class:`str` + The application command's description. + name_localizations: Dict[:class:`~discord.Locale`, :class:`str`] + The localised names of the application command. Used for display purposes. + description_localizations: Dict[:class:`~discord.Locale`, :class:`str`] + The localised descriptions of the application command. Used for display purposes. + options: List[Union[:class:`Argument`, :class:`AppCommandGroup`]] + A list of options. + default_member_permissions: Optional[:class:`~discord.Permissions`] + The default member permissions that can run this command. + dm_permission: :class:`bool` + A boolean that indicates whether this command can be run in direct messages. + allowed_contexts: Optional[:class:`~discord.app_commands.AppCommandContext`] + The contexts that this command is allowed to be used in. Overrides the ``dm_permission`` attribute. + + .. versionadded:: 2.4 + allowed_installs: Optional[:class:`~discord.app_commands.AppInstallationType`] + The installation contexts that this command is allowed to be installed in. + + .. versionadded:: 2.4 + guild_id: Optional[:class:`int`] + The ID of the guild this command is registered in. A value of ``None`` + denotes that it is a global command. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. + """ + + __slots__ = ( + 'id', + 'type', + 'application_id', + 'name', + 'description', + 'name_localizations', + 'description_localizations', + 'guild_id', + 'options', + 'default_member_permissions', + 'dm_permission', + 'allowed_contexts', + 'allowed_installs', + 'nsfw', + '_state', + ) + + def __init__(self, *, data: ApplicationCommandPayload, state: ConnectionState) -> None: + self._state: ConnectionState = state + self._from_data(data) + + def _from_data(self, data: ApplicationCommandPayload) -> None: + self.id: int = int(data['id']) + self.application_id: int = int(data['application_id']) + self.name: str = data['name'] + self.description: str = data['description'] + self.guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') + self.type: AppCommandType = try_enum(AppCommandType, data.get('type', 1)) + self.options: List[Union[Argument, AppCommandGroup]] = [ + app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', []) + ] + self.default_member_permissions: Optional[Permissions] + permissions = data.get('default_member_permissions') + if permissions is None: + self.default_member_permissions = None + else: + self.default_member_permissions = Permissions(int(permissions)) + + dm_permission = data.get('dm_permission') + # For some reason this field can be explicit null and mean True + if dm_permission is None: + dm_permission = True + + self.dm_permission: bool = dm_permission + + allowed_contexts = data.get('contexts') + if allowed_contexts is None: + self.allowed_contexts: Optional[AppCommandContext] = None + else: + self.allowed_contexts = AppCommandContext._from_value(allowed_contexts) + + allowed_installs = data.get('integration_types') + if allowed_installs is None: + self.allowed_installs: Optional[AppInstallationType] = None + else: + self.allowed_installs = AppInstallationType._from_value(allowed_installs) + + self.nsfw: bool = data.get('nsfw', False) + self.name_localizations: Dict[Locale, str] = _to_locale_dict(data.get('name_localizations') or {}) + self.description_localizations: Dict[Locale, str] = _to_locale_dict(data.get('description_localizations') or {}) + + def to_dict(self) -> ApplicationCommandPayload: + return { + 'id': self.id, + 'type': self.type.value, + 'application_id': self.application_id, + 'name': self.name, + 'description': self.description, + 'name_localizations': {str(k): v for k, v in self.name_localizations.items()}, + 'description_localizations': {str(k): v for k, v in self.description_localizations.items()}, + 'contexts': self.allowed_contexts.to_array() if self.allowed_contexts is not None else None, + 'integration_types': self.allowed_installs.to_array() if self.allowed_installs is not None else None, + 'options': [opt.to_dict() for opt in self.options], + } # type: ignore # Type checker does not understand this literal. + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} id={self.id!r} name={self.name!r} type={self.type!r}>' + + @property + def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the given AppCommand.""" + return f'' + + @property + def guild(self) -> Optional[Guild]: + """Optional[:class:`~discord.Guild`]: Returns the guild this command is registered to + if it exists. + """ + return self._state._get_guild(self.guild_id) + + async def delete(self) -> None: + """|coro| + + Deletes the application command. + + Raises + ------- + NotFound + The application command was not found. + Forbidden + You do not have permission to delete this application command. + HTTPException + Deleting the application command failed. + MissingApplicationID + The client does not have an application ID. + """ + state = self._state + if not state.application_id: + raise MissingApplicationID + + if self.guild_id: + await state.http.delete_guild_command( + state.application_id, + self.guild_id, + self.id, + ) + else: + await state.http.delete_global_command( + state.application_id, + self.id, + ) + + async def edit( + self, + *, + name: str = MISSING, + description: str = MISSING, + default_member_permissions: Optional[Permissions] = MISSING, + dm_permission: bool = MISSING, + options: List[Union[Argument, AppCommandGroup]] = MISSING, + ) -> AppCommand: + """|coro| + + Edits the application command. + + Parameters + ----------- + name: :class:`str` + The new name for the application command. + description: :class:`str` + The new description for the application command. + default_member_permissions: Optional[:class:`~discord.Permissions`] + The new default permissions needed to use this application command. + Pass value of ``None`` to remove any permission requirements. + dm_permission: :class:`bool` + Indicates if the application command can be used in DMs. + options: List[Union[:class:`Argument`, :class:`AppCommandGroup`]] + List of new options for this application command. + + Raises + ------- + NotFound + The application command was not found. + Forbidden + You do not have permission to edit this application command. + HTTPException + Editing the application command failed. + MissingApplicationID + The client does not have an application ID. + + Returns + -------- + :class:`AppCommand` + The newly edited application command. + """ + state = self._state + if not state.application_id: + raise MissingApplicationID + + payload = {} + + if name is not MISSING: + payload['name'] = name + + if description is not MISSING: + payload['description'] = description + + if default_member_permissions is not MISSING: + if default_member_permissions is not None: + payload['default_member_permissions'] = default_member_permissions.value + else: + payload['default_member_permissions'] = None + + if self.guild_id is None and dm_permission is not MISSING: + payload['dm_permission'] = dm_permission + + if options is not MISSING: + payload['options'] = [option.to_dict() for option in options] + + if not payload: + return self + + if self.guild_id: + data = await state.http.edit_guild_command( + state.application_id, + self.guild_id, + self.id, + payload, + ) + else: + data = await state.http.edit_global_command( + state.application_id, + self.id, + payload, + ) + return AppCommand(data=data, state=state) + + async def fetch_permissions(self, guild: Snowflake) -> GuildAppCommandPermissions: + """|coro| + + Retrieves this command's permission in the guild. + + Parameters + ----------- + guild: :class:`~discord.abc.Snowflake` + The guild to retrieve the permissions from. + + Raises + ------- + Forbidden + You do not have permission to fetch the application command's permissions. + HTTPException + Fetching the application command's permissions failed. + MissingApplicationID + The client does not have an application ID. + NotFound + The application command's permissions could not be found. + This can also indicate that the permissions are synced with the guild + (i.e. they are unchanged from the default). + + Returns + -------- + :class:`GuildAppCommandPermissions` + An object representing the application command's permissions in the guild. + """ + state = self._state + if not state.application_id: + raise MissingApplicationID + + data = await state.http.get_application_command_permissions( + state.application_id, + guild.id, + self.id, + ) + return GuildAppCommandPermissions(data=data, state=state, command=self) + + +class Choice(Generic[ChoiceT]): + """Represents an application command argument choice. + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two choices are equal. + + .. describe:: x != y + + Checks if two choices are not equal. + + .. describe:: hash(x) + + Returns the choice's hash. + + Parameters + ----------- + name: Union[:class:`str`, :class:`locale_str`] + The name of the choice. Used for display purposes. + Can only be up to 100 characters. + name_localizations: Dict[:class:`~discord.Locale`, :class:`str`] + The localised names of the choice. Used for display purposes. + value: Union[:class:`int`, :class:`str`, :class:`float`] + The value of the choice. If it's a string, it can only be + up to 100 characters long. + """ + + __slots__ = ('name', 'value', '_locale_name', 'name_localizations') + + def __init__(self, *, name: Union[str, locale_str], value: ChoiceT): + name, locale = (name.message, name) if isinstance(name, locale_str) else (name, None) + self.name: str = name + self._locale_name: Optional[locale_str] = locale + self.value: ChoiceT = value + self.name_localizations: Dict[Locale, str] = {} + + @classmethod + def from_dict(cls, data: ApplicationCommandOptionChoice) -> Choice[ChoiceT]: + self = cls.__new__(cls) + self.name = data['name'] + self.value = data['value'] # type: ignore # This seems to break every other pyright release + self.name_localizations = _to_locale_dict(data.get('name_localizations') or {}) + return self + + def __eq__(self, o: object) -> bool: + return isinstance(o, Choice) and self.name == o.name and self.value == o.value + + def __hash__(self) -> int: + return hash((self.name, self.value)) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(name={self.name!r}, value={self.value!r})' + + @property + def _option_type(self) -> AppCommandOptionType: + if isinstance(self.value, int): + return AppCommandOptionType.integer + elif isinstance(self.value, float): + return AppCommandOptionType.number + elif isinstance(self.value, str): + return AppCommandOptionType.string + else: + raise TypeError( + f'invalid Choice value type given, expected int, str, or float but received {self.value.__class__.__name__}' + ) + + async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]: + base = self.to_dict() + name_localizations: Dict[str, str] = {} + context = TranslationContext(location=TranslationContextLocation.choice_name, data=self) + if self._locale_name: + for locale in Locale: + translation = await translator._checked_translate(self._locale_name, locale, context) + if translation is not None: + name_localizations[locale.value] = translation + + if name_localizations: + base['name_localizations'] = name_localizations + + return base + + async def get_translated_payload_for_locale(self, translator: Translator, locale: Locale) -> Dict[str, Any]: + base = self.to_dict() + if self._locale_name: + context = TranslationContext(location=TranslationContextLocation.choice_name, data=self) + translation = await translator._checked_translate(self._locale_name, locale, context) + if translation is not None: + base['name'] = translation + + return base + + def to_dict(self) -> Dict[str, Any]: + base = { + 'name': self.name, + 'value': self.value, + } + if self.name_localizations: + base['name_localizations'] = {str(k): v for k, v in self.name_localizations.items()} + return base + + +class AppCommandChannel(Hashable): + """Represents an application command partially resolved channel object. + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ----------- + id: :class:`int` + The ID of the channel. + type: :class:`~discord.ChannelType` + The type of channel. + name: :class:`str` + The name of the channel. + permissions: :class:`~discord.Permissions` + The resolved permissions of the user who invoked + the application command in that channel. + guild_id: :class:`int` + The guild ID this channel belongs to. + category_id: Optional[:class:`int`] + The category channel ID this channel belongs to, if applicable. + + .. versionadded:: 2.6 + topic: Optional[:class:`str`] + The channel's topic. ``None`` if it doesn't exist. + + .. versionadded:: 2.6 + position: :class:`int` + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. + + .. versionadded:: 2.6 + last_message_id: Optional[:class:`int`] + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + + .. versionadded:: 2.6 + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of ``0`` denotes that it is disabled. + Bots and users with :attr:`~discord.Permissions.manage_channels` or + :attr:`~discord.Permissions.manage_messages` bypass slowmode. + + .. versionadded:: 2.6 + nsfw: :class:`bool` + If the channel is marked as "not safe for work" or "age restricted". + + .. versionadded:: 2.6 + """ + + __slots__ = ( + 'id', + 'type', + 'name', + 'permissions', + 'guild_id', + 'topic', + 'nsfw', + 'position', + 'category_id', + 'slowmode_delay', + 'last_message_id', + '_last_pin', + '_flags', + '_state', + ) + + def __init__( + self, + *, + state: ConnectionState, + data: PartialChannel, + guild_id: int, + ): + self._state: ConnectionState = state + self.guild_id: int = guild_id + self.id: int = int(data['id']) + self.type: ChannelType = try_enum(ChannelType, data['type']) + self.name: str = data['name'] + self.permissions: Permissions = Permissions(int(data['permissions'])) + self.topic: Optional[str] = data.get('topic') + self.position: int = data.get('position') or 0 + self.nsfw: bool = data.get('nsfw') or False + self.category_id: Optional[int] = _get_as_snowflake(data, 'parent_id') + self.slowmode_delay: int = data.get('rate_limit_per_user') or 0 + self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id') + self._last_pin: Optional[datetime] = parse_time(data.get('last_pin_timestamp')) + self._flags: int = data.get('flags', 0) + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} id={self.id!r} name={self.name!r} type={self.type!r}>' + + @property + def guild(self) -> Optional[Guild]: + """Optional[:class:`~discord.Guild`]: The channel's guild, from cache, if found.""" + return self._state._get_guild(self.guild_id) + + @property + def flags(self) -> ChannelFlags: + """:class:`~discord.ChannelFlags`: The flags associated with this channel object. + + .. versionadded:: 2.6 + """ + return ChannelFlags._from_value(self._flags) + + def is_nsfw(self) -> bool: + """:class:`bool`: Checks if the channel is NSFW. + + .. versionadded:: 2.6 + """ + return self.nsfw + + def is_news(self) -> bool: + """:class:`bool`: Checks if the channel is a news channel. + + .. versionadded:: 2.6 + """ + return self.type == ChannelType.news + + def resolve(self) -> Optional[GuildChannel]: + """Resolves the application command channel to the appropriate channel + from cache if found. + + Returns + -------- + Optional[:class:`.abc.GuildChannel`] + The resolved guild channel or ``None`` if not found in cache. + """ + guild = self._state._get_guild(self.guild_id) + if guild is not None: + return guild.get_channel(self.id) + return None + + async def fetch(self) -> GuildChannel: + """|coro| + + Fetches the partial channel to a full :class:`.abc.GuildChannel`. + + Raises + -------- + NotFound + The channel was not found. + Forbidden + You do not have the permissions required to get a channel. + HTTPException + Retrieving the channel failed. + + Returns + -------- + :class:`.abc.GuildChannel` + The full channel. + """ + client = self._state._get_client() + return await client.fetch_channel(self.id) # type: ignore # This is explicit narrowing + + @property + def mention(self) -> str: + """:class:`str`: The string that allows you to mention the channel.""" + return f'<#{self.id}>' + + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.6 + """ + return f'https://discord.com/channels/{self.guild_id}/{self.id}' + + @property + def created_at(self) -> datetime: + """:class:`datetime.datetime`: An aware timestamp of when this channel was created in UTC.""" + return snowflake_time(self.id) + + +class AppCommandThread(Hashable): + """Represents an application command partially resolved thread object. + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two thread are equal. + + .. describe:: x != y + + Checks if two thread are not equal. + + .. describe:: hash(x) + + Returns the thread's hash. + + .. describe:: str(x) + + Returns the thread's name. + + Attributes + ----------- + id: :class:`int` + The ID of the thread. + type: :class:`~discord.ChannelType` + The type of thread. + name: :class:`str` + The name of the thread. + parent_id: :class:`int` + The parent text channel ID this thread belongs to. + owner_id: :class:`int` + The user's ID that created this thread. + + .. versionadded:: 2.6 + last_message_id: Optional[:class:`int`] + The last message ID of the message sent to this thread. It may + *not* point to an existing or valid message. + + .. versionadded:: 2.6 + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this thread. A value of ``0`` denotes that it is disabled. + Bots and users with :attr:`~discord.Permissions.manage_channels` or + :attr:`~discord.Permissions.manage_messages` bypass slowmode. + + .. versionadded:: 2.6 + message_count: :class:`int` + An approximate number of messages in this thread. + + .. versionadded:: 2.6 + member_count: :class:`int` + An approximate number of members in this thread. This caps at 50. + + .. versionadded:: 2.6 + total_message_sent: :class:`int` + The total number of messages sent, including deleted messages. + + .. versionadded:: 2.6 + permissions: :class:`~discord.Permissions` + The resolved permissions of the user who invoked + the application command in that thread. + guild_id: :class:`int` + The guild ID this thread belongs to. + archived: :class:`bool` + Whether the thread is archived. + locked: :class:`bool` + Whether the thread is locked. + invitable: :class:`bool` + Whether non-moderators can add other non-moderators to this thread. + This is always ``True`` for public threads. + archiver_id: Optional[:class:`int`] + The user's ID that archived this thread. + auto_archive_duration: :class:`int` + The duration in minutes until the thread is automatically hidden from the channel list. + Usually a value of 60, 1440, 4320 and 10080. + archive_timestamp: :class:`datetime.datetime` + An aware timestamp of when the thread's archived status was last updated in UTC. + """ + + __slots__ = ( + 'id', + 'type', + 'name', + 'permissions', + 'guild_id', + 'parent_id', + 'archived', + 'archiver_id', + 'auto_archive_duration', + 'archive_timestamp', + 'locked', + 'invitable', + 'owner_id', + 'message_count', + 'member_count', + 'slowmode_delay', + 'last_message_id', + 'total_message_sent', + '_applied_tags', + '_flags', + '_created_at', + '_state', + ) + + def __init__( + self, + *, + state: ConnectionState, + data: PartialThread, + guild_id: int, + ): + self._state: ConnectionState = state + self.guild_id: int = guild_id + self.id: int = int(data['id']) + self.parent_id: int = int(data['parent_id']) + self.type: ChannelType = try_enum(ChannelType, data['type']) + self.name: str = data['name'] + self.permissions: Permissions = Permissions(int(data['permissions'])) + self.owner_id: int = int(data['owner_id']) + self.member_count: int = int(data['member_count']) + self.message_count: int = int(data['message_count']) + self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id') + self.slowmode_delay: int = data.get('rate_limit_per_user', 0) + self.total_message_sent: int = data.get('total_message_sent', 0) + self._applied_tags: array.array[int] = array.array('Q', map(int, data.get('applied_tags', []))) + self._flags: int = data.get('flags', 0) + self._unroll_metadata(data['thread_metadata']) + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} id={self.id!r} name={self.name!r} archived={self.archived} type={self.type!r}>' + + @property + def guild(self) -> Optional[Guild]: + """Optional[:class:`~discord.Guild`]: The channel's guild, from cache, if found.""" + return self._state._get_guild(self.guild_id) + + def _unroll_metadata(self, data: ThreadMetadata) -> None: + self.archived: bool = data['archived'] + self.archiver_id: Optional[int] = _get_as_snowflake(data, 'archiver_id') + self.auto_archive_duration: ThreadArchiveDuration = data['auto_archive_duration'] + self.archive_timestamp: datetime = parse_time(data['archive_timestamp']) + self.locked: bool = data.get('locked', False) + self.invitable: bool = data.get('invitable', True) + self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp')) + + @property + def applied_tags(self) -> List[ForumTag]: + """List[:class:`~discord.ForumTag`]: A list of tags applied to this thread. + + .. versionadded:: 2.6 + """ + tags = [] + if self.parent is None or self.parent.type not in (ChannelType.forum, ChannelType.media): + return tags + + parent = self.parent + for tag_id in self._applied_tags: + tag = parent.get_tag(tag_id) # type: ignore # parent here will be ForumChannel instance + if tag is not None: + tags.append(tag) + + return tags + + @property + def parent(self) -> Optional[Union[ForumChannel, TextChannel]]: + """Optional[Union[:class:`~discord.ForumChannel`, :class:`~discord.TextChannel`]]: The parent channel + this thread belongs to.""" + return self.guild and self.guild.get_channel(self.parent_id) # type: ignore + + @property + def flags(self) -> ChannelFlags: + """:class:`~discord.ChannelFlags`: The flags associated with this thread. + + .. versionadded:: 2.6 + """ + return ChannelFlags._from_value(self._flags) + + @property + def owner(self) -> Optional[Member]: + """Optional[:class:`~discord.Member`]: The member this thread belongs to. + + .. versionadded:: 2.6 + """ + return self.guild and self.guild.get_member(self.owner_id) + + @property + def mention(self) -> str: + """:class:`str`: The string that allows you to mention the thread.""" + return f'<#{self.id}>' + + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the thread. + + .. versionadded:: 2.6 + """ + return f'https://discord.com/channels/{self.guild_id}/{self.id}' + + @property + def created_at(self) -> Optional[datetime]: + """An aware timestamp of when the thread was created in UTC. + + .. note:: + + This timestamp only exists for threads created after 9 January 2022, otherwise returns ``None``. + """ + return self._created_at + + def resolve(self) -> Optional[Thread]: + """Resolves the application command channel to the appropriate channel + from cache if found. + + Returns + -------- + Optional[:class:`.abc.GuildChannel`] + The resolved guild channel or ``None`` if not found in cache. + """ + guild = self._state._get_guild(self.guild_id) + if guild is not None: + return guild.get_thread(self.id) + return None + + async def fetch(self) -> Thread: + """|coro| + + Fetches the partial channel to a full :class:`~discord.Thread`. + + Raises + -------- + NotFound + The thread was not found. + Forbidden + You do not have the permissions required to get a thread. + HTTPException + Retrieving the thread failed. + + Returns + -------- + :class:`~discord.Thread` + The full thread. + """ + client = self._state._get_client() + return await client.fetch_channel(self.id) # type: ignore # This is explicit narrowing + + +class Argument: + """Represents an application command argument. + + .. versionadded:: 2.0 + + Attributes + ------------ + type: :class:`~discord.AppCommandOptionType` + The type of argument. + name: :class:`str` + The name of the argument. + description: :class:`str` + The description of the argument. + name_localizations: Dict[:class:`~discord.Locale`, :class:`str`] + The localised names of the argument. Used for display purposes. + description_localizations: Dict[:class:`~discord.Locale`, :class:`str`] + The localised descriptions of the argument. Used for display purposes. + required: :class:`bool` + Whether the argument is required. + choices: List[:class:`Choice`] + A list of choices for the command to choose from for this argument. + parent: Union[:class:`AppCommand`, :class:`AppCommandGroup`] + The parent application command that has this argument. + channel_types: List[:class:`~discord.ChannelType`] + The channel types that are allowed for this parameter. + min_value: Optional[Union[:class:`int`, :class:`float`]] + The minimum supported value for this parameter. + max_value: Optional[Union[:class:`int`, :class:`float`]] + The maximum supported value for this parameter. + min_length: Optional[:class:`int`] + The minimum allowed length for this parameter. + max_length: Optional[:class:`int`] + The maximum allowed length for this parameter. + autocomplete: :class:`bool` + Whether the argument has autocomplete. + """ + + __slots__ = ( + 'type', + 'name', + 'description', + 'name_localizations', + 'description_localizations', + 'required', + 'choices', + 'channel_types', + 'min_value', + 'max_value', + 'min_length', + 'max_length', + 'autocomplete', + 'parent', + '_state', + ) + + def __init__( + self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None + ) -> None: + self._state: Optional[ConnectionState] = state + self.parent: ApplicationCommandParent = parent + self._from_data(data) + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r} required={self.required}>' + + def _from_data(self, data: ApplicationCommandOption) -> None: + self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type']) + self.name: str = data['name'] + self.description: str = data['description'] + self.required: bool = data.get('required', False) + self.min_value: Optional[Union[int, float]] = data.get('min_value') + self.max_value: Optional[Union[int, float]] = data.get('max_value') + self.min_length: Optional[int] = data.get('min_length') + self.max_length: Optional[int] = data.get('max_length') + self.autocomplete: bool = data.get('autocomplete', False) + self.channel_types: List[ChannelType] = [try_enum(ChannelType, d) for d in data.get('channel_types', [])] + self.choices: List[Choice[Union[int, float, str]]] = [Choice.from_dict(d) for d in data.get('choices', [])] + self.name_localizations: Dict[Locale, str] = _to_locale_dict(data.get('name_localizations') or {}) + self.description_localizations: Dict[Locale, str] = _to_locale_dict(data.get('description_localizations') or {}) + + def to_dict(self) -> ApplicationCommandOption: + return { + 'name': self.name, + 'type': self.type.value, + 'description': self.description, + 'required': self.required, + 'choices': [choice.to_dict() for choice in self.choices], + 'channel_types': [channel_type.value for channel_type in self.channel_types], + 'min_value': self.min_value, + 'max_value': self.max_value, + 'min_length': self.min_length, + 'max_length': self.max_length, + 'autocomplete': self.autocomplete, + 'options': [], + 'name_localizations': {str(k): v for k, v in self.name_localizations.items()}, + 'description_localizations': {str(k): v for k, v in self.description_localizations.items()}, + } # type: ignore # Type checker does not understand this literal. + + +class AppCommandGroup: + """Represents an application command subcommand. + + .. versionadded:: 2.0 + + Attributes + ------------ + type: :class:`~discord.AppCommandOptionType` + The type of subcommand. + name: :class:`str` + The name of the subcommand. + description: :class:`str` + The description of the subcommand. + name_localizations: Dict[:class:`~discord.Locale`, :class:`str`] + The localised names of the subcommand. Used for display purposes. + description_localizations: Dict[:class:`~discord.Locale`, :class:`str`] + The localised descriptions of the subcommand. Used for display purposes. + options: List[Union[:class:`Argument`, :class:`AppCommandGroup`]] + A list of options. + parent: Union[:class:`AppCommand`, :class:`AppCommandGroup`] + The parent application command. + """ + + __slots__ = ( + 'type', + 'name', + 'description', + 'name_localizations', + 'description_localizations', + 'options', + 'parent', + '_state', + ) + + def __init__( + self, *, parent: ApplicationCommandParent, data: ApplicationCommandOption, state: Optional[ConnectionState] = None + ) -> None: + self.parent: ApplicationCommandParent = parent + self._state: Optional[ConnectionState] = state + self._from_data(data) + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} name={self.name!r} type={self.type!r}>' + + @property + def qualified_name(self) -> str: + """:class:`str`: Returns the fully qualified command name. + + The qualified name includes the parent name as well. For example, + in a command like ``/foo bar`` the qualified name is ``foo bar``. + """ + # A B C + # ^ self + # ^ parent + # ^ grandparent + names = [self.name, self.parent.name] + if isinstance(self.parent, AppCommandGroup): + names.append(self.parent.parent.name) + + return ' '.join(reversed(names)) + + @property + def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the given AppCommandGroup.""" + if isinstance(self.parent, AppCommand): + base_command = self.parent + else: + base_command = self.parent.parent + return f'' # type: ignore + + def _from_data(self, data: ApplicationCommandOption) -> None: + self.type: AppCommandOptionType = try_enum(AppCommandOptionType, data['type']) + self.name: str = data['name'] + self.description: str = data['description'] + self.options: List[Union[Argument, AppCommandGroup]] = [ + app_command_option_factory(data=d, parent=self, state=self._state) for d in data.get('options', []) + ] + self.name_localizations: Dict[Locale, str] = _to_locale_dict(data.get('name_localizations') or {}) + self.description_localizations: Dict[Locale, str] = _to_locale_dict(data.get('description_localizations') or {}) + + def to_dict(self) -> 'ApplicationCommandOption': + return { + 'name': self.name, + 'type': self.type.value, + 'description': self.description, + 'options': [arg.to_dict() for arg in self.options], + 'name_localizations': {str(k): v for k, v in self.name_localizations.items()}, + 'description_localizations': {str(k): v for k, v in self.description_localizations.items()}, + } # type: ignore # Type checker does not understand this literal. + + +class AppCommandPermissions: + """Represents the permissions for an application command. + + .. versionadded:: 2.0 + + Attributes + ----------- + guild: :class:`~discord.Guild` + The guild associated with this permission. + id: :class:`int` + The ID of the permission target, such as a role, channel, or guild. + The special ``guild_id - 1`` sentinel is used to represent "all channels". + target: Any + The role, user, or channel associated with this permission. This could also be the :class:`AllChannels` sentinel type. + Falls back to :class:`~discord.Object` if the target could not be found in the cache. + type: :class:`.AppCommandPermissionType` + The type of permission. + permission: :class:`bool` + The permission value. ``True`` for allow, ``False`` for deny. + """ + + __slots__ = ('id', 'type', 'permission', 'target', 'guild', '_state') + + def __init__(self, *, data: ApplicationCommandPermissions, guild: Guild, state: ConnectionState) -> None: + self._state: ConnectionState = state + self.guild: Guild = guild + + self.id: int = int(data['id']) + self.type: AppCommandPermissionType = try_enum(AppCommandPermissionType, data['type']) + self.permission: bool = data['permission'] + + _object = None + _type = MISSING + + if self.type is AppCommandPermissionType.user: + _object = guild.get_member(self.id) or self._state.get_user(self.id) + _type = Member + elif self.type is AppCommandPermissionType.channel: + if self.id == (guild.id - 1): + _object = AllChannels(guild) + else: + _object = guild.get_channel(self.id) + elif self.type is AppCommandPermissionType.role: + _object = guild.get_role(self.id) + _type = Role + + if _object is None: + _object = Object(id=self.id, type=_type) + + self.target: Union[Object, User, Member, Role, AllChannels, GuildChannel] = _object + + def __repr__(self) -> str: + return f'' + + def to_dict(self) -> ApplicationCommandPermissions: + return { + 'id': self.target.id, + 'type': self.type.value, + 'permission': self.permission, + } + + +class GuildAppCommandPermissions: + """Represents the permissions for an application command in a guild. + + .. versionadded:: 2.0 + + Attributes + ----------- + application_id: :class:`int` + The application ID. + command: :class:`.AppCommand` + The application command associated with the permissions. + id: :class:`int` + ID of the command or the application ID. + When this is the application ID instead of a command ID, + the permissions apply to all commands that do not contain explicit overwrites. + guild_id: :class:`int` + The guild ID associated with the permissions. + permissions: List[:class:`AppCommandPermissions`] + The permissions, this is a max of 100. + """ + + __slots__ = ('id', 'application_id', 'command', 'guild_id', 'permissions', '_state') + + def __init__(self, *, data: GuildApplicationCommandPermissions, state: ConnectionState, command: AppCommand) -> None: + self._state: ConnectionState = state + self.command: AppCommand = command + + self.id: int = int(data['id']) + self.application_id: int = int(data['application_id']) + self.guild_id: int = int(data['guild_id']) + guild = self.guild + self.permissions: List[AppCommandPermissions] = [ + AppCommandPermissions(data=value, guild=guild, state=self._state) for value in data['permissions'] + ] + + def __repr__(self) -> str: + return f'' + + def to_dict(self) -> Dict[str, Any]: + return {'permissions': [p.to_dict() for p in self.permissions]} + + @property + def guild(self) -> Guild: + """:class:`~discord.Guild`: The guild associated with the permissions.""" + return self._state._get_or_create_unavailable_guild(self.guild_id) + + +def app_command_option_factory( + parent: ApplicationCommandParent, data: ApplicationCommandOption, *, state: Optional[ConnectionState] = None +) -> Union[Argument, AppCommandGroup]: + if is_app_command_argument_type(data['type']): + return Argument(parent=parent, data=data, state=state) + else: + return AppCommandGroup(parent=parent, data=data, state=state) diff --git a/discord/app_commands/namespace.py b/discord/app_commands/namespace.py new file mode 100644 index 000000000000..0cac8cb24c85 --- /dev/null +++ b/discord/app_commands/namespace.py @@ -0,0 +1,263 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, NamedTuple, Tuple +from ..member import Member +from ..object import Object +from ..role import Role +from ..message import Message, Attachment +from ..channel import PartialMessageable +from ..enums import AppCommandOptionType +from .models import AppCommandChannel, AppCommandThread + +if TYPE_CHECKING: + from ..interactions import Interaction + from ..types.interactions import ResolvedData, ApplicationCommandInteractionDataOption + +__all__ = ('Namespace',) + + +class ResolveKey(NamedTuple): + id: str + # CommandOptionType does not use 0 or negative numbers so those can be safe for library + # internal use, if necessary. Likewise, only 6, 7, 8, and 11 are actually in use. + type: int + + @classmethod + def any_with(cls, id: str) -> ResolveKey: + return ResolveKey(id=id, type=-1) + + def __eq__(self, o: object) -> bool: + if not isinstance(o, ResolveKey): + return NotImplemented + if self.type == -1 or o.type == -1: + return self.id == o.id + return (self.id, self.type) == (o.id, o.type) + + def __hash__(self) -> int: + # Most of the time an ID lookup is all that is necessary + # In case of collision then we look up both the ID and the type. + return hash(self.id) + + +class Namespace: + """An object that holds the parameters being passed to a command in a mostly raw state. + + This class is deliberately simple and just holds the option name and resolved value as a simple + key-pair mapping. These attributes can be accessed using dot notation. For example, an option + with the name of ``example`` can be accessed using ``ns.example``. If an attribute is not found, + then ``None`` is returned rather than an attribute error. + + .. warning:: + + The key names come from the raw Discord data, which means that if a parameter was renamed then the + renamed key is used instead of the function parameter name. + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two namespaces are equal by checking if all attributes are equal. + .. describe:: x != y + + Checks if two namespaces are not equal. + .. describe:: x[key] + + Returns an attribute if it is found, otherwise raises + a :exc:`KeyError`. + .. describe:: key in x + + Checks if the attribute is in the namespace. + .. describe:: iter(x) + + Returns an iterator of ``(name, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + + This namespace object converts resolved objects into their appropriate form depending on their + type. Consult the table below for conversion information. + + +-------------------------------------------+-------------------------------------------------------------------------------+ + | Option Type | Resolved Type | + +===========================================+===============================================================================+ + | :attr:`.AppCommandOptionType.string` | :class:`str` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + | :attr:`.AppCommandOptionType.integer` | :class:`int` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + | :attr:`.AppCommandOptionType.boolean` | :class:`bool` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + | :attr:`.AppCommandOptionType.number` | :class:`float` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + | :attr:`.AppCommandOptionType.user` | :class:`~discord.User` or :class:`~discord.Member` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + | :attr:`.AppCommandOptionType.channel` | :class:`.AppCommandChannel` or :class:`.AppCommandThread` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + | :attr:`.AppCommandOptionType.role` | :class:`~discord.Role` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + | :attr:`.AppCommandOptionType.mentionable` | :class:`~discord.User` or :class:`~discord.Member`, or :class:`~discord.Role` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + | :attr:`.AppCommandOptionType.attachment` | :class:`~discord.Attachment` | + +-------------------------------------------+-------------------------------------------------------------------------------+ + + .. note:: + + In autocomplete interactions, the namespace might not be validated or filled in. Discord does not + send the resolved data as well, so this means that certain fields end up just as IDs rather than + the resolved data. In these cases, a :class:`discord.Object` is returned instead. + + This is a Discord limitation. + """ + + def __init__( + self, + interaction: Interaction, + resolved: ResolvedData, + options: List[ApplicationCommandInteractionDataOption], + ): + completed = self._get_resolved_items(interaction, resolved) + for option in options: + opt_type = option['type'] + name = option['name'] + focused = option.get('focused', False) + if opt_type in (3, 4, 5): # string, integer, boolean + value = option['value'] # type: ignore # Key is there + self.__dict__[name] = value + elif opt_type == 10: # number + value = option['value'] # type: ignore # Key is there + # This condition is written this way because 0 can be a valid float + if value is None or value == '': + self.__dict__[name] = float('nan') + else: + if not focused: + self.__dict__[name] = float(value) + else: + # Autocomplete focused values tend to be garbage in + self.__dict__[name] = value + elif opt_type in (6, 7, 8, 9, 11): + # Remaining ones should be snowflake based ones with resolved data + snowflake: str = option['value'] # type: ignore # Key is there + if opt_type == 9: # Mentionable + # Mentionable is User | Role, these do not cause any conflict + key = ResolveKey.any_with(snowflake) + else: + # The remaining keys can conflict, for example, a role and a channel + # could end up with the same ID in very old guilds since they used to default + # to sharing the guild ID. Old general channels no longer exist, but some old + # servers will still have them so this needs to be handled. + key = ResolveKey(id=snowflake, type=opt_type) + + value = completed.get(key) or Object(id=int(snowflake)) + self.__dict__[name] = value + + @classmethod + def _get_resolved_items(cls, interaction: Interaction, resolved: ResolvedData) -> Dict[ResolveKey, Any]: + completed: Dict[ResolveKey, Any] = {} + state = interaction._state + members = resolved.get('members', {}) + guild_id = interaction.guild_id + guild = interaction.guild + type = AppCommandOptionType.user.value + for user_id, user_data in resolved.get('users', {}).items(): + try: + member_data = members[user_id] + except KeyError: + completed[ResolveKey(id=user_id, type=type)] = state.create_user(user_data) + else: + member_data['user'] = user_data + # Guild ID can't be None in this case. + # There's a type mismatch here that I don't actually care about + member = Member(state=state, guild=guild, data=member_data) # type: ignore + completed[ResolveKey(id=user_id, type=type)] = member + + type = AppCommandOptionType.role.value + completed.update( + { + # The guild ID can't be None in this case. + ResolveKey(id=role_id, type=type): Role(guild=guild, state=state, data=role_data) # type: ignore + for role_id, role_data in resolved.get('roles', {}).items() + } + ) + + type = AppCommandOptionType.channel.value + for channel_id, channel_data in resolved.get('channels', {}).items(): + key = ResolveKey(id=channel_id, type=type) + if channel_data['type'] in (10, 11, 12): + # The guild ID can't be none in this case + completed[key] = AppCommandThread(state=state, data=channel_data, guild_id=guild_id) # type: ignore + else: + # The guild ID can't be none in this case + completed[key] = AppCommandChannel(state=state, data=channel_data, guild_id=guild_id) # type: ignore + + type = AppCommandOptionType.attachment.value + completed.update( + { + ResolveKey(id=attachment_id, type=type): Attachment(data=attachment_data, state=state) + for attachment_id, attachment_data in resolved.get('attachments', {}).items() + } + ) + + for message_id, message_data in resolved.get('messages', {}).items(): + channel_id = int(message_data['channel_id']) + if guild is None: + channel = PartialMessageable(state=state, guild_id=guild_id, id=channel_id) + else: + channel = guild.get_channel_or_thread(channel_id) or PartialMessageable( + state=state, guild_id=guild_id, id=channel_id + ) + + # Type checker doesn't understand this due to failure to narrow + message = Message(state=state, channel=channel, data=message_data) # type: ignore + message.guild = guild + key = ResolveKey(id=message_id, type=-1) + completed[key] = message + + return completed + + def __repr__(self) -> str: + items = (f'{k}={v!r}' for k, v in self.__dict__.items()) + return '<{} {}>'.format(self.__class__.__name__, ' '.join(items)) + + def __eq__(self, other: object) -> bool: + if isinstance(self, Namespace) and isinstance(other, Namespace): + return self.__dict__ == other.__dict__ + return NotImplemented + + def __getitem__(self, key: str) -> Any: + return self.__dict__[key] + + def __contains__(self, key: str) -> Any: + return key in self.__dict__ + + def __getattr__(self, attr: str) -> Any: + return None + + def __iter__(self) -> Iterator[Tuple[str, Any]]: + yield from self.__dict__.items() + + def _update_with_defaults(self, defaults: Iterable[Tuple[str, Any]]) -> None: + for key, value in defaults: + self.__dict__.setdefault(key, value) diff --git a/discord/app_commands/transformers.py b/discord/app_commands/transformers.py new file mode 100644 index 000000000000..212991cbe372 --- /dev/null +++ b/discord/app_commands/transformers.py @@ -0,0 +1,880 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +import inspect + +from dataclasses import dataclass +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Coroutine, + Dict, + Generic, + List, + Literal, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from .errors import AppCommandError, TransformerError +from .models import AppCommandChannel, AppCommandThread, Choice +from .translator import TranslationContextLocation, TranslationContext, Translator, locale_str +from ..channel import StageChannel, VoiceChannel, TextChannel, CategoryChannel, ForumChannel +from ..abc import GuildChannel +from ..threads import Thread +from ..enums import Enum as InternalEnum, AppCommandOptionType, ChannelType, Locale +from ..utils import MISSING, maybe_coroutine, _human_join +from ..user import User +from ..role import Role +from ..member import Member +from ..message import Attachment +from .._types import ClientT + +__all__ = ( + 'Transformer', + 'Transform', + 'Range', +) + +T = TypeVar('T') +FuncT = TypeVar('FuncT', bound=Callable[..., Any]) +ChoiceT = TypeVar('ChoiceT', str, int, float, Union[str, int, float]) +NoneType = type(None) + +if TYPE_CHECKING: + from ..interactions import Interaction + from .commands import Parameter + + +@dataclass +class CommandParameter: + # The name of the parameter is *always* the parameter name in the code + # Therefore, it can't be Union[str, locale_str] + name: str = MISSING + description: Union[str, locale_str] = MISSING + required: bool = MISSING + default: Any = MISSING + choices: List[Choice[Union[str, int, float]]] = MISSING + type: AppCommandOptionType = MISSING + channel_types: List[ChannelType] = MISSING + min_value: Optional[Union[int, float]] = None + max_value: Optional[Union[int, float]] = None + autocomplete: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None + _rename: Union[str, locale_str] = MISSING + _annotation: Any = MISSING + + async def get_translated_payload(self, translator: Translator, data: Parameter) -> Dict[str, Any]: + base = self.to_dict() + + rename = self._rename + description = self.description + needs_name_translations = isinstance(rename, locale_str) + needs_description_translations = isinstance(description, locale_str) + name_localizations: Dict[str, str] = {} + description_localizations: Dict[str, str] = {} + + # Prevent creating these objects in a heavy loop + name_context = TranslationContext(location=TranslationContextLocation.parameter_name, data=data) + description_context = TranslationContext(location=TranslationContextLocation.parameter_description, data=data) + for locale in Locale: + if needs_name_translations: + translation = await translator._checked_translate(rename, locale, name_context) + if translation is not None: + name_localizations[locale.value] = translation + + if needs_description_translations: + translation = await translator._checked_translate(description, locale, description_context) + if translation is not None: + description_localizations[locale.value] = translation + + if self.choices: + base['choices'] = [await choice.get_translated_payload(translator) for choice in self.choices] + + if name_localizations: + base['name_localizations'] = name_localizations + + if description_localizations: + base['description_localizations'] = description_localizations + + return base + + def to_dict(self) -> Dict[str, Any]: + base = { + 'type': self.type.value, + 'name': self.display_name, + 'description': str(self.description), + 'required': self.required, + } + + if self.choices: + base['choices'] = [choice.to_dict() for choice in self.choices] + if self.channel_types: + base['channel_types'] = [t.value for t in self.channel_types] + if self.autocomplete: + base['autocomplete'] = True + + min_key, max_key = ( + ('min_value', 'max_value') if self.type is not AppCommandOptionType.string else ('min_length', 'max_length') + ) + if self.min_value is not None: + base[min_key] = self.min_value + if self.max_value is not None: + base[max_key] = self.max_value + + return base + + def _convert_to_locale_strings(self) -> None: + if self._rename is MISSING: + self._rename = locale_str(self.name) + elif isinstance(self._rename, str): + self._rename = locale_str(self._rename) + + if isinstance(self.description, str): + self.description = locale_str(self.description) + + if self.choices: + for choice in self.choices: + if choice._locale_name is None: + choice._locale_name = locale_str(choice.name) + + def is_choice_annotation(self) -> bool: + return getattr(self._annotation, '__discord_app_commands_is_choice__', False) + + async def transform(self, interaction: Interaction, value: Any, /) -> Any: + if hasattr(self._annotation, '__discord_app_commands_transformer__'): + # This one needs special handling for type safety reasons + if self._annotation.__discord_app_commands_is_choice__: + choice = next((c for c in self.choices if c.value == value), None) + if choice is None: + raise TransformerError(value, self.type, self._annotation) + return choice + + try: + return await maybe_coroutine(self._annotation.transform, interaction, value) + except AppCommandError: + raise + except Exception as e: + raise TransformerError(value, self.type, self._annotation) from e + + return value + + @property + def display_name(self) -> str: + """:class:`str`: The name of the parameter as it should be displayed to the user.""" + return self.name if self._rename is MISSING else str(self._rename) + + +class Transformer(Generic[ClientT]): + """The base class that allows a type annotation in an application command parameter + to map into a :class:`~discord.AppCommandOptionType` and transform the raw value into one + from this type. + + This class is customisable through the overriding of methods and properties in the class + and by using it as the second type parameter of the :class:`~discord.app_commands.Transform` + class. For example, to convert a string into a custom pair type: + + .. code-block:: python3 + + class Point(typing.NamedTuple): + x: int + y: int + + class PointTransformer(app_commands.Transformer): + async def transform(self, interaction: discord.Interaction, value: str) -> Point: + (x, _, y) = value.partition(',') + return Point(x=int(x.strip()), y=int(y.strip())) + + @app_commands.command() + async def graph( + interaction: discord.Interaction, + point: app_commands.Transform[Point, PointTransformer], + ): + await interaction.response.send_message(str(point)) + + If a class is passed instead of an instance to the second type parameter, then it is + constructed with no arguments passed to the ``__init__`` method. + + .. versionadded:: 2.0 + """ + + __discord_app_commands_transformer__: ClassVar[bool] = True + __discord_app_commands_is_choice__: ClassVar[bool] = False + + # This is needed to pass typing's type checks. + # e.g. Optional[MyTransformer] + def __call__(self) -> None: + pass + + def __or__(self, rhs: Any) -> Any: + return Union[self, rhs] + + @property + def type(self) -> AppCommandOptionType: + """:class:`~discord.AppCommandOptionType`: The option type associated with this transformer. + + This must be a :obj:`property`. + + Defaults to :attr:`~discord.AppCommandOptionType.string`. + """ + return AppCommandOptionType.string + + @property + def channel_types(self) -> List[ChannelType]: + """List[:class:`~discord.ChannelType`]: A list of channel types that are allowed to this parameter. + + Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.channel`. + + This must be a :obj:`property`. + + Defaults to an empty list. + """ + return [] + + @property + def min_value(self) -> Optional[Union[int, float]]: + """Optional[:class:`int`]: The minimum supported value for this parameter. + + Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number` + :attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`. + + This must be a :obj:`property`. + + Defaults to ``None``. + """ + return None + + @property + def max_value(self) -> Optional[Union[int, float]]: + """Optional[:class:`int`]: The maximum supported value for this parameter. + + Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number` + :attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`. + + This must be a :obj:`property`. + + Defaults to ``None``. + """ + return None + + @property + def choices(self) -> Optional[List[Choice[Union[int, float, str]]]]: + """Optional[List[:class:`~discord.app_commands.Choice`]]: A list of up to 25 choices that are allowed to this parameter. + + Only valid if the :meth:`type` returns :attr:`~discord.AppCommandOptionType.number` + :attr:`~discord.AppCommandOptionType.integer`, or :attr:`~discord.AppCommandOptionType.string`. + + This must be a :obj:`property`. + + Defaults to ``None``. + """ + return None + + @property + def _error_display_name(self) -> str: + name = self.__class__.__name__ + if name.endswith('Transformer'): + return name[:-11] + else: + return name + + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: + """|maybecoro| + + Transforms the converted option value into another value. + + The value passed into this transform function is the same as the + one in the :class:`conversion table `. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction being handled. + value: Any + The value of the given argument after being resolved. + See the :class:`conversion table ` + for how certain option types correspond to certain values. + """ + raise NotImplementedError('Derived classes need to implement this.') + + async def autocomplete( + self, interaction: Interaction[ClientT], value: Union[int, float, str], / + ) -> List[Choice[Union[int, float, str]]]: + """|coro| + + An autocomplete prompt handler to be automatically used by options using this transformer. + + .. note:: + + Autocomplete is only supported for options with a :meth:`~discord.app_commands.Transformer.type` + of :attr:`~discord.AppCommandOptionType.string`, :attr:`~discord.AppCommandOptionType.integer`, + or :attr:`~discord.AppCommandOptionType.number`. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The autocomplete interaction being handled. + value: Union[:class:`str`, :class:`int`, :class:`float`] + The current value entered by the user. + + Returns + -------- + List[:class:`~discord.app_commands.Choice`] + A list of choices to be displayed to the user, a maximum of 25. + + """ + raise NotImplementedError('Derived classes can implement this.') + + +class IdentityTransformer(Transformer[ClientT]): + def __init__(self, type: AppCommandOptionType) -> None: + self._type = type + + @property + def type(self) -> AppCommandOptionType: + return self._type + + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: + return value + + +class RangeTransformer(IdentityTransformer): + def __init__( + self, + opt_type: AppCommandOptionType, + *, + min: Optional[Union[int, float]] = None, + max: Optional[Union[int, float]] = None, + ) -> None: + if min and max and min > max: + raise TypeError('minimum cannot be larger than maximum') + + self._min: Optional[Union[int, float]] = min + self._max: Optional[Union[int, float]] = max + super().__init__(opt_type) + + @property + def min_value(self) -> Optional[Union[int, float]]: + return self._min + + @property + def max_value(self) -> Optional[Union[int, float]]: + return self._max + + +class LiteralTransformer(IdentityTransformer): + def __init__(self, values: Tuple[Any, ...]) -> None: + first = type(values[0]) + if first is int: + opt_type = AppCommandOptionType.integer + elif first is float: + opt_type = AppCommandOptionType.number + elif first is str: + opt_type = AppCommandOptionType.string + else: + raise TypeError(f'expected int, str, or float values not {first!r}') + + self._choices = [Choice(name=str(v), value=v) for v in values] + super().__init__(opt_type) + + @property + def choices(self): + return self._choices + + +class ChoiceTransformer(IdentityTransformer): + __discord_app_commands_is_choice__: ClassVar[bool] = True + + def __init__(self, inner_type: Any) -> None: + if inner_type is int: + opt_type = AppCommandOptionType.integer + elif inner_type is float: + opt_type = AppCommandOptionType.number + elif inner_type is str: + opt_type = AppCommandOptionType.string + else: + raise TypeError(f'expected int, str, or float values not {inner_type!r}') + + super().__init__(opt_type) + + +class EnumValueTransformer(Transformer): + def __init__(self, enum: Any) -> None: + super().__init__() + + values = list(enum) + if len(values) < 2: + raise TypeError('enum.Enum requires at least two values.') + + first = type(values[0].value) + if first is int: + opt_type = AppCommandOptionType.integer + elif first is float: + opt_type = AppCommandOptionType.number + elif first is str: + opt_type = AppCommandOptionType.string + else: + raise TypeError(f'expected int, str, or float values not {first!r}') + + self._type: AppCommandOptionType = opt_type + self._enum: Any = enum + self._choices = [Choice(name=v.name, value=v.value) for v in values] + + @property + def _error_display_name(self) -> str: + return self._enum.__name__ + + @property + def type(self) -> AppCommandOptionType: + return self._type + + @property + def choices(self): + return self._choices + + async def transform(self, interaction: Interaction, value: Any, /) -> Any: + return self._enum(value) + + +class EnumNameTransformer(Transformer): + def __init__(self, enum: Any) -> None: + super().__init__() + + values = list(enum) + if len(values) < 2: + raise TypeError('enum.Enum requires at least two values.') + + self._enum: Any = enum + self._choices = [Choice(name=v.name, value=v.name) for v in values] + + @property + def _error_display_name(self) -> str: + return self._enum.__name__ + + @property + def type(self) -> AppCommandOptionType: + return AppCommandOptionType.string + + @property + def choices(self): + return self._choices + + async def transform(self, interaction: Interaction, value: Any, /) -> Any: + return self._enum[value] + + +class InlineTransformer(Transformer[ClientT]): + def __init__(self, annotation: Any) -> None: + super().__init__() + self.annotation: Any = annotation + + @property + def _error_display_name(self) -> str: + return self.annotation.__name__ + + @property + def type(self) -> AppCommandOptionType: + return AppCommandOptionType.string + + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Any: + return await self.annotation.transform(interaction, value) + + +if TYPE_CHECKING: + from typing_extensions import Annotated as Transform + from typing_extensions import Annotated as Range +else: + + class Transform: + """A type annotation that can be applied to a parameter to customise the behaviour of + an option type by transforming with the given :class:`Transformer`. This requires + the usage of two generic parameters, the first one is the type you're converting to and the second + one is the type of the :class:`Transformer` actually doing the transformation. + + During type checking time this is equivalent to :obj:`typing.Annotated` so type checkers understand + the intent of the code. + + For example usage, check :class:`Transformer`. + + .. versionadded:: 2.0 + """ + + def __class_getitem__(cls, items) -> Transformer: + if not isinstance(items, tuple): + raise TypeError(f'expected tuple for arguments, received {items.__class__.__name__} instead') + + if len(items) != 2: + raise TypeError('Transform only accepts exactly two arguments') + + _, transformer = items + + if inspect.isclass(transformer): + if not issubclass(transformer, Transformer): + raise TypeError(f'second argument of Transform must be a Transformer class not {transformer!r}') + transformer = transformer() + elif not isinstance(transformer, Transformer): + raise TypeError(f'second argument of Transform must be a Transformer not {transformer.__class__.__name__}') + + return transformer + + class Range: + """A type annotation that can be applied to a parameter to require a numeric or string + type to fit within the range provided. + + During type checking time this is equivalent to :obj:`typing.Annotated` so type checkers understand + the intent of the code. + + Some example ranges: + + - ``Range[int, 10]`` means the minimum is 10 with no maximum. + - ``Range[int, None, 10]`` means the maximum is 10 with no minimum. + - ``Range[int, 1, 10]`` means the minimum is 1 and the maximum is 10. + - ``Range[float, 1.0, 5.0]`` means the minimum is 1.0 and the maximum is 5.0. + - ``Range[str, 1, 10]`` means the minimum length is 1 and the maximum length is 10. + + .. versionadded:: 2.0 + + Examples + ---------- + + .. code-block:: python3 + + @app_commands.command() + async def range(interaction: discord.Interaction, value: app_commands.Range[int, 10, 12]): + await interaction.response.send_message(f'Your value is {value}', ephemeral=True) + """ + + def __class_getitem__(cls, obj) -> RangeTransformer: + if not isinstance(obj, tuple): + raise TypeError(f'expected tuple for arguments, received {obj.__class__.__name__} instead') + + if len(obj) == 2: + obj = (*obj, None) + elif len(obj) != 3: + raise TypeError('Range accepts either two or three arguments with the first being the type of range.') + + obj_type, min, max = obj + + if min is None and max is None: + raise TypeError('Range must not be empty') + + if min is not None and max is not None: + # At this point max and min are both not none + if type(min) != type(max): + raise TypeError('Both min and max in Range must be the same type') + + if obj_type is int: + opt_type = AppCommandOptionType.integer + elif obj_type is float: + opt_type = AppCommandOptionType.number + elif obj_type is str: + opt_type = AppCommandOptionType.string + else: + raise TypeError(f'expected int, float, or str as range type, received {obj_type!r} instead') + + if obj_type in (str, int): + cast = int + else: + cast = float + + transformer = RangeTransformer( + opt_type, + min=cast(min) if min is not None else None, + max=cast(max) if max is not None else None, + ) + return transformer + + +class MemberTransformer(Transformer[ClientT]): + @property + def type(self) -> AppCommandOptionType: + return AppCommandOptionType.user + + async def transform(self, interaction: Interaction[ClientT], value: Any, /) -> Member: + if not isinstance(value, Member): + raise TransformerError(value, self.type, self) + return value + + +class BaseChannelTransformer(Transformer[ClientT]): + def __init__(self, *channel_types: Type[Any]) -> None: + super().__init__() + if len(channel_types) == 1: + display_name = channel_types[0].__name__ + types = CHANNEL_TO_TYPES[channel_types[0]] + else: + display_name = _human_join([t.__name__ for t in channel_types]) + types = [] + + for t in channel_types: + try: + types.extend(CHANNEL_TO_TYPES[t]) + except KeyError: + raise TypeError('Union type of channels must be entirely made up of channels') from None + + self._types: Tuple[Type[Any], ...] = channel_types + self._channel_types: List[ChannelType] = types + self._display_name = display_name + + @property + def _error_display_name(self) -> str: + return self._display_name + + @property + def type(self) -> AppCommandOptionType: + return AppCommandOptionType.channel + + @property + def channel_types(self) -> List[ChannelType]: + return self._channel_types + + async def transform(self, interaction: Interaction[ClientT], value: Any, /): + resolved = value.resolve() + if resolved is None or not isinstance(resolved, self._types): + raise TransformerError(value, AppCommandOptionType.channel, self) + return resolved + + +class RawChannelTransformer(BaseChannelTransformer[ClientT]): + async def transform(self, interaction: Interaction[ClientT], value: Any, /): + if not isinstance(value, self._types): + raise TransformerError(value, AppCommandOptionType.channel, self) + return value + + +class UnionChannelTransformer(BaseChannelTransformer[ClientT]): + async def transform(self, interaction: Interaction[ClientT], value: Any, /): + if isinstance(value, self._types): + return value + + resolved = value.resolve() + if resolved is None or not isinstance(resolved, self._types): + raise TransformerError(value, AppCommandOptionType.channel, self) + return resolved + + +CHANNEL_TO_TYPES: Dict[Any, List[ChannelType]] = { + AppCommandChannel: [ + ChannelType.stage_voice, + ChannelType.voice, + ChannelType.text, + ChannelType.news, + ChannelType.category, + ChannelType.forum, + ChannelType.media, + ], + GuildChannel: [ + ChannelType.stage_voice, + ChannelType.voice, + ChannelType.text, + ChannelType.news, + ChannelType.category, + ChannelType.forum, + ChannelType.media, + ], + AppCommandThread: [ChannelType.news_thread, ChannelType.private_thread, ChannelType.public_thread], + Thread: [ChannelType.news_thread, ChannelType.private_thread, ChannelType.public_thread], + StageChannel: [ChannelType.stage_voice], + VoiceChannel: [ChannelType.voice], + TextChannel: [ChannelType.text, ChannelType.news], + CategoryChannel: [ChannelType.category], + ForumChannel: [ChannelType.forum, ChannelType.media], +} + +BUILT_IN_TRANSFORMERS: Dict[Any, Transformer] = { + str: IdentityTransformer(AppCommandOptionType.string), + int: IdentityTransformer(AppCommandOptionType.integer), + float: IdentityTransformer(AppCommandOptionType.number), + bool: IdentityTransformer(AppCommandOptionType.boolean), + User: IdentityTransformer(AppCommandOptionType.user), + Member: MemberTransformer(), + Role: IdentityTransformer(AppCommandOptionType.role), + AppCommandChannel: RawChannelTransformer(AppCommandChannel), + AppCommandThread: RawChannelTransformer(AppCommandThread), + GuildChannel: BaseChannelTransformer(GuildChannel), + Thread: BaseChannelTransformer(Thread), + StageChannel: BaseChannelTransformer(StageChannel), + VoiceChannel: BaseChannelTransformer(VoiceChannel), + TextChannel: BaseChannelTransformer(TextChannel), + CategoryChannel: BaseChannelTransformer(CategoryChannel), + ForumChannel: BaseChannelTransformer(ForumChannel), + Attachment: IdentityTransformer(AppCommandOptionType.attachment), +} + +ALLOWED_DEFAULTS: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = { + AppCommandOptionType.string: (str, NoneType), + AppCommandOptionType.integer: (int, NoneType), + AppCommandOptionType.boolean: (bool, NoneType), + AppCommandOptionType.number: (float, NoneType), +} + + +def get_supported_annotation( + annotation: Any, + *, + _none: type = NoneType, + _mapping: Dict[Any, Transformer] = BUILT_IN_TRANSFORMERS, +) -> Tuple[Any, Any, bool]: + """Returns an appropriate, yet supported, annotation along with an optional default value. + + The third boolean element of the tuple indicates if default values should be validated. + + This differs from the built in mapping by supporting a few more things. + Likewise, this returns a "transformed" annotation that is ready to use with CommandParameter.transform. + """ + + try: + return (_mapping[annotation], MISSING, True) + except (KeyError, TypeError): + pass + + if isinstance(annotation, Transformer): + return (annotation, MISSING, False) + + if inspect.isclass(annotation): + if issubclass(annotation, Transformer): + return (annotation(), MISSING, False) + if issubclass(annotation, (Enum, InternalEnum)): + if all(isinstance(v.value, (str, int, float)) for v in annotation): + return (EnumValueTransformer(annotation), MISSING, False) + else: + return (EnumNameTransformer(annotation), MISSING, False) + if annotation is Choice: + raise TypeError('Choice requires a type argument of int, str, or float') + + # Check if a transform @classmethod is given to the class + # These flatten into simple "inline" transformers with implicit strings + transform_classmethod = annotation.__dict__.get('transform', None) + if isinstance(transform_classmethod, classmethod): + params = inspect.signature(transform_classmethod.__func__).parameters + if len(params) != 3: + raise TypeError('Inline transformer with transform classmethod requires 3 parameters') + if not inspect.iscoroutinefunction(transform_classmethod.__func__): + raise TypeError('Inline transformer with transform classmethod must be a coroutine') + return (InlineTransformer(annotation), MISSING, False) + + # Check if there's an origin + origin = getattr(annotation, '__origin__', None) + if origin is Literal: + args = annotation.__args__ + return (LiteralTransformer(args), MISSING, True) + + if origin is Choice: + arg = annotation.__args__[0] + return (ChoiceTransformer(arg), MISSING, True) + + if origin is not Union: + # Only Union/Optional is supported right now so bail early + raise TypeError(f'unsupported type annotation {annotation!r}') + + default = MISSING + args = annotation.__args__ + if args[-1] is _none: + if len(args) == 2: + underlying = args[0] + inner, _, validate_default = get_supported_annotation(underlying) + if inner is None: + raise TypeError(f'unsupported inner optional type {underlying!r}') + return (inner, None, validate_default) + else: + args = args[:-1] + default = None + + # Check for channel union types + if any(arg in CHANNEL_TO_TYPES for arg in args): + # If any channel type is given, then *all* must be channel types + return (UnionChannelTransformer(*args), default, True) + + # The only valid transformations here are: + # [Member, User] => user + # [Member, User, Role] => mentionable + # [Member | User, Role] => mentionable + supported_types: Set[Any] = {Role, Member, User} + if not all(arg in supported_types for arg in args): + raise TypeError(f'unsupported types given inside {annotation!r}') + if args == (User, Member) or args == (Member, User): + return (IdentityTransformer(AppCommandOptionType.user), default, True) + + return (IdentityTransformer(AppCommandOptionType.mentionable), default, True) + + +def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter: + """Returns the appropriate :class:`CommandParameter` for the given annotation. + + The resulting ``_annotation`` attribute might not match the one given here and might + be transformed in order to be easier to call from the ``transform`` asynchronous function + of a command parameter. + """ + + (inner, default, validate_default) = get_supported_annotation(annotation) + type = inner.type + + if default is MISSING or default is None: + param_default = parameter.default + if param_default is not parameter.empty: + default = param_default + + # Verify validity of the default parameter + if default is not MISSING and validate_default: + valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,)) + if not isinstance(default, valid_types): + raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}') + + result = CommandParameter( + type=type, + _annotation=inner, + default=default, + required=default is MISSING, + name=parameter.name, + ) + + choices = inner.choices + if choices is not None: + result.choices = choices + + # These methods should be duck typed + if type in (AppCommandOptionType.number, AppCommandOptionType.string, AppCommandOptionType.integer): + result.min_value = inner.min_value + result.max_value = inner.max_value + + if type is AppCommandOptionType.channel: + result.channel_types = inner.channel_types + + if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_KEYWORD, parameter.VAR_POSITIONAL): + raise TypeError(f'unsupported parameter kind in callback: {parameter.kind!s}') + + # Check if the method is overridden + if inner.autocomplete.__func__ is not Transformer.autocomplete: + from .commands import validate_auto_complete_callback + + result.autocomplete = validate_auto_complete_callback(inner.autocomplete) + + return result diff --git a/discord/app_commands/translator.py b/discord/app_commands/translator.py new file mode 100644 index 000000000000..36b1b923c1b1 --- /dev/null +++ b/discord/app_commands/translator.py @@ -0,0 +1,299 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload +from .errors import TranslationError +from ..enums import Enum, Locale + + +if TYPE_CHECKING: + from .commands import Command, ContextMenu, Group, Parameter + from .models import Choice + + +__all__ = ( + 'TranslationContextLocation', + 'TranslationContextTypes', + 'TranslationContext', + 'Translator', + 'locale_str', +) + + +class TranslationContextLocation(Enum): + command_name = 0 + command_description = 1 + group_name = 2 + group_description = 3 + parameter_name = 4 + parameter_description = 5 + choice_name = 6 + other = 7 + + +_L = TypeVar('_L', bound=TranslationContextLocation) +_D = TypeVar('_D') + + +class TranslationContext(Generic[_L, _D]): + """A class that provides context for the :class:`locale_str` being translated. + + This is useful to determine where exactly the string is located and aid in looking + up the actual translation. + + Attributes + ----------- + location: :class:`TranslationContextLocation` + The location where this string is located. + data: Any + The extraneous data that is being translated. + """ + + __slots__ = ('location', 'data') + + @overload + def __init__( + self, location: Literal[TranslationContextLocation.command_name], data: Union[Command[Any, ..., Any], ContextMenu] + ) -> None: ... + + @overload + def __init__( + self, location: Literal[TranslationContextLocation.command_description], data: Command[Any, ..., Any] + ) -> None: ... + + @overload + def __init__( + self, + location: Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], + data: Group, + ) -> None: ... + + @overload + def __init__( + self, + location: Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], + data: Parameter, + ) -> None: ... + + @overload + def __init__(self, location: Literal[TranslationContextLocation.choice_name], data: Choice[Any]) -> None: ... + + @overload + def __init__(self, location: Literal[TranslationContextLocation.other], data: Any) -> None: ... + + def __init__(self, location: _L, data: _D) -> None: # type: ignore # pyright doesn't like the overloads + self.location: _L = location + self.data: _D = data + + +# For type checking purposes, it makes sense to allow the user to leverage type narrowing +# So code like this works as expected: +# +# if context.type == TranslationContextLocation.command_name: +# reveal_type(context.data) # Revealed type is Command | ContextMenu +# +# This requires a union of types +CommandNameTranslationContext = TranslationContext[ + Literal[TranslationContextLocation.command_name], Union['Command[Any, ..., Any]', 'ContextMenu'] +] +CommandDescriptionTranslationContext = TranslationContext[ + Literal[TranslationContextLocation.command_description], 'Command[Any, ..., Any]' +] +GroupTranslationContext = TranslationContext[ + Literal[TranslationContextLocation.group_name, TranslationContextLocation.group_description], 'Group' +] +ParameterTranslationContext = TranslationContext[ + Literal[TranslationContextLocation.parameter_name, TranslationContextLocation.parameter_description], 'Parameter' +] +ChoiceTranslationContext = TranslationContext[Literal[TranslationContextLocation.choice_name], 'Choice[Any]'] +OtherTranslationContext = TranslationContext[Literal[TranslationContextLocation.other], Any] + +TranslationContextTypes = Union[ + CommandNameTranslationContext, + CommandDescriptionTranslationContext, + GroupTranslationContext, + ParameterTranslationContext, + ChoiceTranslationContext, + OtherTranslationContext, +] + + +class Translator: + """A class that handles translations for commands, parameters, and choices. + + Translations are done lazily in order to allow for async enabled translations as well + as supporting a wide array of translation systems such as :mod:`gettext` and + `Project Fluent `_. + + In order for a translator to be used, it must be set using the :meth:`CommandTree.set_translator` + method. The translation flow for a string is as follows: + + 1. Use :class:`locale_str` instead of :class:`str` in areas of a command you want to be translated. + - Currently, these are command names, command descriptions, parameter names, parameter descriptions, and choice names. + - This can also be used inside the :func:`~discord.app_commands.describe` decorator. + 2. Call :meth:`CommandTree.set_translator` to the translator instance that will handle the translations. + 3. Call :meth:`CommandTree.sync` + 4. The library will call :meth:`Translator.translate` on all the relevant strings being translated. + + .. versionadded:: 2.0 + """ + + async def load(self) -> None: + """|coro| + + An asynchronous setup function for loading the translation system. + + The default implementation does nothing. + + This is invoked when :meth:`CommandTree.set_translator` is called. + """ + pass + + async def unload(self) -> None: + """|coro| + + An asynchronous teardown function for unloading the translation system. + + The default implementation does nothing. + + This is invoked when :meth:`CommandTree.set_translator` is called + if a tree already has a translator or when :meth:`discord.Client.close` is called. + """ + pass + + async def _checked_translate( + self, string: locale_str, locale: Locale, context: TranslationContextTypes + ) -> Optional[str]: + try: + return await self.translate(string, locale, context) + except TranslationError: + raise + except Exception as e: + raise TranslationError(string=string, locale=locale, context=context) from e + + async def translate(self, string: locale_str, locale: Locale, context: TranslationContextTypes) -> Optional[str]: + """|coro| + + Translates the given string to the specified locale. + + If the string cannot be translated, ``None`` should be returned. + + The default implementation returns ``None``. + + If an exception is raised in this method, it should inherit from :exc:`TranslationError`. + If it doesn't, then when this is called the exception will be chained with it instead. + + Parameters + ------------ + string: :class:`locale_str` + The string being translated. + locale: :class:`~discord.Locale` + The locale being requested for translation. + context: :class:`TranslationContext` + The translation context where the string originated from. + For better type checking ergonomics, the ``TranslationContextTypes`` + type can be used instead to aid with type narrowing. It is functionally + equivalent to :class:`TranslationContext`. + """ + + return None + + +class locale_str: + """Marks a string as ready for translation. + + This is done lazily and is not actually translated until :meth:`CommandTree.sync` is called. + + The sync method then ultimately defers the responsibility of translating to the :class:`Translator` + instance used by the :class:`CommandTree`. For more information on the translation flow, see the + :class:`Translator` documentation. + + .. container:: operations + + .. describe:: str(x) + + Returns the message passed to the string. + + .. describe:: x == y + + Checks if the string is equal to another string. + + .. describe:: x != y + + Checks if the string is not equal to another string. + + .. describe:: hash(x) + + Returns the hash of the string. + + .. versionadded:: 2.0 + + Attributes + ------------ + message: :class:`str` + The message being translated. Once set, this cannot be changed. + + .. warning:: + + This must be the default "message" that you send to Discord. + Discord sends this message back to the library and the library + uses it to access the data in order to dispatch commands. + + For example, in a command name context, if the command + name is ``foo`` then the message *must* also be ``foo``. + For other translation systems that require a message ID such + as Fluent, consider using a keyword argument to pass it in. + extras: :class:`dict` + A dict of user provided extras to attach to the translated string. + This can be used to add more context, information, or any metadata necessary + to aid in actually translating the string. + + Since these are passed via keyword arguments, the keys are strings. + """ + + __slots__ = ('__message', 'extras') + + def __init__(self, message: str, /, **kwargs: Any) -> None: + self.__message: str = message + self.extras: dict[str, Any] = kwargs + + @property + def message(self) -> str: + return self.__message + + def __str__(self) -> str: + return self.__message + + def __repr__(self) -> str: + kwargs = ', '.join(f'{k}={v!r}' for k, v in self.extras.items()) + if kwargs: + return f'{self.__class__.__name__}({self.__message!r}, {kwargs})' + return f'{self.__class__.__name__}({self.__message!r})' + + def __eq__(self, obj: object) -> bool: + return isinstance(obj, locale_str) and self.message == obj.message + + def __hash__(self) -> int: + return hash(self.__message) diff --git a/discord/app_commands/tree.py b/discord/app_commands/tree.py new file mode 100644 index 000000000000..aa446a01f2b9 --- /dev/null +++ b/discord/app_commands/tree.py @@ -0,0 +1,1304 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +import logging +import inspect + +from typing import ( + Any, + TYPE_CHECKING, + Callable, + Coroutine, + Dict, + Generator, + Generic, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Union, + overload, +) +from collections import Counter + + +from .namespace import Namespace, ResolveKey +from .models import AppCommand +from .commands import Command, ContextMenu, Group +from .errors import ( + AppCommandError, + CommandAlreadyRegistered, + CommandNotFound, + CommandSignatureMismatch, + CommandLimitReached, + CommandSyncFailure, + MissingApplicationID, +) +from .installs import AppCommandContext, AppInstallationType +from .translator import Translator, locale_str +from ..errors import ClientException, HTTPException +from ..enums import AppCommandType, InteractionType +from ..utils import MISSING, _get_as_snowflake, _is_submodule, _shorten +from .._types import ClientT + + +if TYPE_CHECKING: + from ..types.interactions import ApplicationCommandInteractionData, ApplicationCommandInteractionDataOption + from ..interactions import Interaction + from ..abc import Snowflake + from .commands import ContextMenuCallback, CommandCallback, P, T + + ErrorFunc = Callable[ + [Interaction[ClientT], AppCommandError], + Coroutine[Any, Any, Any], + ] + +__all__ = ('CommandTree',) + +_log = logging.getLogger(__name__) + + +def _retrieve_guild_ids( + command: Any, guild: Optional[Snowflake] = MISSING, guilds: Sequence[Snowflake] = MISSING +) -> Optional[Set[int]]: + if guild is not MISSING and guilds is not MISSING: + raise TypeError('cannot mix guild and guilds keyword arguments') + + # guilds=[] or guilds=[...] + if guild is MISSING: + # If no arguments are given then it should default to the ones + # given to the guilds(...) decorator or None for global. + if guilds is MISSING: + return getattr(command, '_guild_ids', None) + + # guilds=[] is the same as global + if len(guilds) == 0: + return None + + return {g.id for g in guilds} + + # At this point it should be... + # guild=None or guild=Object + if guild is None: + return None + return {guild.id} + + +class CommandTree(Generic[ClientT]): + """Represents a container that holds application command information. + + Parameters + ----------- + client: :class:`~discord.Client` + The client instance to get application command information from. + fallback_to_global: :class:`bool` + If a guild-specific command is not found when invoked, then try falling back into + a global command in the tree. For example, if the tree locally has a ``/ping`` command + under the global namespace but the guild has a guild-specific ``/ping``, instead of failing + to find the guild-specific ``/ping`` command it will fall back to the global ``/ping`` command. + This has the potential to raise more :exc:`~discord.app_commands.CommandSignatureMismatch` errors + than usual. Defaults to ``True``. + allowed_contexts: :class:`~discord.app_commands.AppCommandContext` + The default allowed contexts that applies to all commands in this tree. + Note that you can override this on a per command basis. + + .. versionadded:: 2.4 + allowed_installs: :class:`~discord.app_commands.AppInstallationType` + The default allowed install locations that apply to all commands in this tree. + Note that you can override this on a per command basis. + + .. versionadded:: 2.4 + """ + + def __init__( + self, + client: ClientT, + *, + fallback_to_global: bool = True, + allowed_contexts: AppCommandContext = MISSING, + allowed_installs: AppInstallationType = MISSING, + ): + self.client: ClientT = client + self._http = client.http + self._state = client._connection + + if self._state._command_tree is not None: + raise ClientException('This client already has an associated command tree.') + + self._state._command_tree = self + self.fallback_to_global: bool = fallback_to_global + self.allowed_contexts = AppCommandContext() if allowed_contexts is MISSING else allowed_contexts + self.allowed_installs = AppInstallationType() if allowed_installs is MISSING else allowed_installs + self._guild_commands: Dict[int, Dict[str, Union[Command, Group]]] = {} + self._global_commands: Dict[str, Union[Command, Group]] = {} + # (name, guild_id, command_type): Command + # The above two mappings can use this structure too but we need fast retrieval + # by name and guild_id in the above case while here it isn't as important since + # it's uncommon and N=5 anyway. + self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {} + + async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake] = None) -> AppCommand: + """|coro| + + Fetches an application command from the application. + + Parameters + ----------- + command_id: :class:`int` + The ID of the command to fetch. + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to fetch the command from. If not passed then the global command + is fetched instead. + + Raises + ------- + HTTPException + Fetching the command failed. + MissingApplicationID + The application ID could not be found. + NotFound + The application command was not found. + This could also be because the command is a guild command + and the guild was not specified and vice versa. + + Returns + -------- + :class:`~discord.app_commands.AppCommand` + The application command. + """ + if self.client.application_id is None: + raise MissingApplicationID + + if guild is None: + command = await self._http.get_global_command(self.client.application_id, command_id) + else: + command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id) + + return AppCommand(data=command, state=self._state) + + async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: + """|coro| + + Fetches the application's current commands. + + If no guild is passed then global commands are fetched, otherwise + the guild's commands are fetched instead. + + .. note:: + + This includes context menu commands. + + Parameters + ----------- + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to fetch the commands from. If not passed then global commands + are fetched instead. + + Raises + ------- + HTTPException + Fetching the commands failed. + MissingApplicationID + The application ID could not be found. + + Returns + -------- + List[:class:`~discord.app_commands.AppCommand`] + The application's commands. + """ + if self.client.application_id is None: + raise MissingApplicationID + + if guild is None: + commands = await self._http.get_global_commands(self.client.application_id) + else: + commands = await self._http.get_guild_commands(self.client.application_id, guild.id) + + return [AppCommand(data=data, state=self._state) for data in commands] + + def copy_global_to(self, *, guild: Snowflake) -> None: + """Copies all global commands to the specified guild. + + This method is mainly available for development purposes, as it allows you + to copy your global commands over to a testing guild easily. + + Note that this method will *override* pre-existing guild commands that would conflict. + + Parameters + ----------- + guild: :class:`~discord.abc.Snowflake` + The guild to copy the commands to. + + Raises + -------- + CommandLimitReached + The maximum number of commands was reached for that guild. + This is currently 100 for slash commands and 5 for context menu commands. + """ + + try: + mapping = self._guild_commands[guild.id].copy() + except KeyError: + mapping = {} + + mapping.update(self._global_commands) + if len(mapping) > 100: + raise CommandLimitReached(guild_id=guild.id, limit=100) + + ctx_menu: Dict[Tuple[str, Optional[int], int], ContextMenu] = { + (name, guild.id, cmd_type): cmd + for ((name, g, cmd_type), cmd) in self._context_menus.items() + if g is None or g == guild.id + } + + counter = Counter(cmd_type for _, _, cmd_type in ctx_menu) + for cmd_type, count in counter.items(): + if count > 5: + as_enum = AppCommandType(cmd_type) + raise CommandLimitReached(guild_id=guild.id, limit=5, type=as_enum) + + self._context_menus.update(ctx_menu) + self._guild_commands[guild.id] = mapping + + def add_command( + self, + command: Union[Command[Any, ..., Any], ContextMenu, Group], + /, + *, + guild: Optional[Snowflake] = MISSING, + guilds: Sequence[Snowflake] = MISSING, + override: bool = False, + ) -> None: + """Adds an application command to the tree. + + This only adds the command locally -- in order to sync the commands + and enable them in the client, :meth:`sync` must be called. + + The root parent of the command is added regardless of the type passed. + + Parameters + ----------- + command: Union[:class:`Command`, :class:`Group`] + The application command or group to add. + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to add the command to. If not given or ``None`` then it + becomes a global command instead. + + .. note :: + + Due to a Discord limitation, this keyword argument cannot be used in conjunction with + contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types + (e.g. :func:`.app_commands.allowed_installs`). + + guilds: List[:class:`~discord.abc.Snowflake`] + The list of guilds to add the command to. This cannot be mixed + with the ``guild`` parameter. If no guilds are given at all + then it becomes a global command instead. + + .. note :: + + Due to a Discord limitation, this keyword argument cannot be used in conjunction with + contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types + (e.g. :func:`.app_commands.allowed_installs`). + + override: :class:`bool` + Whether to override a command with the same name. If ``False`` + an exception is raised. Default is ``False``. + + Raises + -------- + ~discord.app_commands.CommandAlreadyRegistered + The command was already registered and no override was specified. + TypeError + The application command passed is not a valid application command. + Or, ``guild`` and ``guilds`` were both given. + CommandLimitReached + The maximum number of commands was reached globally or for that guild. + This is currently 100 for slash commands and 5 for context menu commands. + """ + + guild_ids = _retrieve_guild_ids(command, guild, guilds) + if isinstance(command, ContextMenu): + type = command.type.value + name = command.name + + def _context_menu_add_helper( + guild_id: Optional[int], + data: Dict[Tuple[str, Optional[int], int], ContextMenu], + name: str = name, + type: int = type, + ) -> None: + key = (name, guild_id, type) + found = key in self._context_menus + if found and not override: + raise CommandAlreadyRegistered(name, guild_id) + + # If the key is found and overridden then it shouldn't count as an extra addition + # read as `0 if override and found else 1` if confusing + to_add = not (override and found) + total = sum(1 for _, g, t in self._context_menus if g == guild_id and t == type) + if total + to_add > 5: + raise CommandLimitReached(guild_id=guild_id, limit=5, type=AppCommandType(type)) + data[key] = command + + if guild_ids is None: + _context_menu_add_helper(None, self._context_menus) + else: + current: Dict[Tuple[str, Optional[int], int], ContextMenu] = {} + for guild_id in guild_ids: + _context_menu_add_helper(guild_id, current) + + # Update at the end in order to make sure the update is atomic. + # An error during addition could end up making the context menu mapping + # have a partial state + self._context_menus.update(current) + return + elif not isinstance(command, (Command, Group)): + raise TypeError(f'Expected an application command, received {command.__class__.__name__} instead') + + # todo: validate application command groups having children (required) + + root = command.root_parent or command + name = root.name + if guild_ids is not None: + # Validate that the command can be added first, before actually + # adding it into the mapping. This ensures atomicity. + for guild_id in guild_ids: + commands = self._guild_commands.get(guild_id, {}) + found = name in commands + if found and not override: + raise CommandAlreadyRegistered(name, guild_id) + + to_add = not (override and found) + if len(commands) + to_add > 100: + raise CommandLimitReached(guild_id=guild_id, limit=100) + + # Actually add the command now that it has been verified to be okay. + for guild_id in guild_ids: + commands = self._guild_commands.setdefault(guild_id, {}) + commands[name] = root + else: + found = name in self._global_commands + if found and not override: + raise CommandAlreadyRegistered(name, None) + + to_add = not (override and found) + if len(self._global_commands) + to_add > 100: + raise CommandLimitReached(guild_id=None, limit=100) + self._global_commands[name] = root + + @overload + def remove_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.message, AppCommandType.user], + ) -> Optional[ContextMenu]: ... + + @overload + def remove_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.chat_input] = ..., + ) -> Optional[Union[Command[Any, ..., Any], Group]]: ... + + @overload + def remove_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: AppCommandType, + ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]: ... + + def remove_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = None, + type: AppCommandType = AppCommandType.chat_input, + ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]: + """Removes an application command from the tree. + + This only removes the command locally -- in order to sync the commands + and remove them in the client, :meth:`sync` must be called. + + Parameters + ----------- + command: :class:`str` + The name of the root command to remove. + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to remove the command from. If not given or ``None`` then it + removes a global command instead. + type: :class:`~discord.AppCommandType` + The type of command to remove. Defaults to :attr:`~discord.AppCommandType.chat_input`, + i.e. slash commands. + + Returns + --------- + Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]] + The application command that got removed. + If nothing was removed then ``None`` is returned instead. + """ + + if type is AppCommandType.chat_input: + if guild is None: + return self._global_commands.pop(command, None) + else: + try: + commands = self._guild_commands[guild.id] + except KeyError: + return None + else: + return commands.pop(command, None) + elif type in (AppCommandType.user, AppCommandType.message): + guild_id = None if guild is None else guild.id + key = (command, guild_id, type.value) + return self._context_menus.pop(key, None) + + def clear_commands(self, *, guild: Optional[Snowflake], type: Optional[AppCommandType] = None) -> None: + """Clears all application commands from the tree. + + This only removes the commands locally -- in order to sync the commands + and remove them in the client, :meth:`sync` must be called. + + Parameters + ----------- + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to remove the commands from. If ``None`` then it + removes all global commands instead. + type: :class:`~discord.AppCommandType` + The type of command to clear. If not given or ``None`` then it removes all commands + regardless of the type. + """ + + if type is None or type is AppCommandType.chat_input: + if guild is None: + self._global_commands.clear() + else: + try: + commands = self._guild_commands[guild.id] + except KeyError: + pass + else: + commands.clear() + + guild_id = None if guild is None else guild.id + if type is None: + self._context_menus = { + (name, _guild_id, value): cmd + for (name, _guild_id, value), cmd in self._context_menus.items() + if _guild_id != guild_id + } + elif type in (AppCommandType.user, AppCommandType.message): + self._context_menus = { + (name, _guild_id, value): cmd + for (name, _guild_id, value), cmd in self._context_menus.items() + if _guild_id != guild_id or value != type.value + } + + @overload + def get_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.message, AppCommandType.user], + ) -> Optional[ContextMenu]: ... + + @overload + def get_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.chat_input] = ..., + ) -> Optional[Union[Command[Any, ..., Any], Group]]: ... + + @overload + def get_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = ..., + type: AppCommandType, + ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]: ... + + def get_command( + self, + command: str, + /, + *, + guild: Optional[Snowflake] = None, + type: AppCommandType = AppCommandType.chat_input, + ) -> Optional[Union[Command[Any, ..., Any], ContextMenu, Group]]: + """Gets an application command from the tree. + + Parameters + ----------- + command: :class:`str` + The name of the root command to get. + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to get the command from. If not given or ``None`` then it + gets a global command instead. + type: :class:`~discord.AppCommandType` + The type of command to get. Defaults to :attr:`~discord.AppCommandType.chat_input`, + i.e. slash commands. + + Returns + --------- + Optional[Union[:class:`Command`, :class:`ContextMenu`, :class:`Group`]] + The application command that was found. + If nothing was found then ``None`` is returned instead. + """ + + if type is AppCommandType.chat_input: + if guild is None: + return self._global_commands.get(command) + else: + try: + commands = self._guild_commands[guild.id] + except KeyError: + return None + else: + return commands.get(command) + elif type in (AppCommandType.user, AppCommandType.message): + guild_id = None if guild is None else guild.id + key = (command, guild_id, type.value) + return self._context_menus.get(key) + + @overload + def get_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.message, AppCommandType.user], + ) -> List[ContextMenu]: ... + + @overload + def get_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.chat_input], + ) -> List[Union[Command[Any, ..., Any], Group]]: ... + + @overload + def get_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: AppCommandType, + ) -> Union[List[Union[Command[Any, ..., Any], Group]], List[ContextMenu]]: ... + + @overload + def get_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: Optional[AppCommandType] = ..., + ) -> List[Union[Command[Any, ..., Any], Group, ContextMenu]]: ... + + def get_commands( + self, + *, + guild: Optional[Snowflake] = None, + type: Optional[AppCommandType] = None, + ) -> Union[ + List[ContextMenu], + List[Union[Command[Any, ..., Any], Group]], + List[Union[Command[Any, ..., Any], Group, ContextMenu]], + ]: + """Gets all application commands from the tree. + + Parameters + ----------- + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to get the commands from, not including global commands. + If not given or ``None`` then only global commands are returned. + type: Optional[:class:`~discord.AppCommandType`] + The type of commands to get. When not given or ``None``, then all + command types are returned. + + Returns + --------- + List[Union[:class:`ContextMenu`, :class:`Command`, :class:`Group`]] + The application commands from the tree. + """ + if type is None: + return self._get_all_commands(guild=guild) + + if type is AppCommandType.chat_input: + if guild is None: + return list(self._global_commands.values()) + else: + try: + commands = self._guild_commands[guild.id] + except KeyError: + return [] + else: + return list(commands.values()) + else: + guild_id = None if guild is None else guild.id + value = type.value + return [command for ((_, g, t), command) in self._context_menus.items() if g == guild_id and t == value] + + @overload + def walk_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.message, AppCommandType.user], + ) -> Generator[ContextMenu, None, None]: ... + + @overload + def walk_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: Literal[AppCommandType.chat_input] = ..., + ) -> Generator[Union[Command[Any, ..., Any], Group], None, None]: ... + + @overload + def walk_commands( + self, + *, + guild: Optional[Snowflake] = ..., + type: AppCommandType, + ) -> Union[Generator[Union[Command[Any, ..., Any], Group], None, None], Generator[ContextMenu, None, None]]: ... + + def walk_commands( + self, + *, + guild: Optional[Snowflake] = None, + type: AppCommandType = AppCommandType.chat_input, + ) -> Union[Generator[Union[Command[Any, ..., Any], Group], None, None], Generator[ContextMenu, None, None]]: + """An iterator that recursively walks through all application commands and child commands from the tree. + + Parameters + ----------- + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to iterate the commands from, not including global commands. + If not given or ``None`` then only global commands are iterated. + type: :class:`~discord.AppCommandType` + The type of commands to iterate over. Defaults to :attr:`~discord.AppCommandType.chat_input`, + i.e. slash commands. + + Yields + --------- + Union[:class:`ContextMenu`, :class:`Command`, :class:`Group`] + The application commands from the tree. + """ + + if type is AppCommandType.chat_input: + if guild is None: + for cmd in self._global_commands.values(): + yield cmd + if isinstance(cmd, Group): + yield from cmd.walk_commands() + else: + try: + commands = self._guild_commands[guild.id] + except KeyError: + return + else: + for cmd in commands.values(): + yield cmd + if isinstance(cmd, Group): + yield from cmd.walk_commands() + else: + guild_id = None if guild is None else guild.id + value = type.value + for (_, g, t), command in self._context_menus.items(): + if g == guild_id and t == value: + yield command + + def _get_all_commands( + self, *, guild: Optional[Snowflake] = None + ) -> List[Union[Command[Any, ..., Any], Group, ContextMenu]]: + if guild is None: + base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(self._global_commands.values()) + base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g is None) + return base + else: + try: + commands = self._guild_commands[guild.id] + except KeyError: + guild_id = guild.id + return [cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id] + else: + base: List[Union[Command[Any, ..., Any], Group, ContextMenu]] = list(commands.values()) + guild_id = guild.id + base.extend(cmd for ((_, g, _), cmd) in self._context_menus.items() if g == guild_id) + return base + + def _remove_with_module(self, name: str) -> None: + remove: List[Any] = [] + for key, cmd in self._context_menus.items(): + if cmd.module is not None and _is_submodule(name, cmd.module): + remove.append(key) + + for key in remove: + del self._context_menus[key] + + remove = [] + for key, cmd in self._global_commands.items(): + if cmd.module is not None and _is_submodule(name, cmd.module): + remove.append(key) + + for key in remove: + del self._global_commands[key] + + for mapping in self._guild_commands.values(): + remove = [] + for key, cmd in mapping.items(): + if cmd.module is not None and _is_submodule(name, cmd.module): + remove.append(key) + + for key in remove: + del mapping[key] + + async def on_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None: + """|coro| + + A callback that is called when any command raises an :exc:`AppCommandError`. + + The default implementation logs the exception using the library logger + if the command does not have any error handlers attached to it. + + To get the command that failed, :attr:`discord.Interaction.command` should + be used. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction that is being handled. + error: :exc:`AppCommandError` + The exception that was raised. + """ + + command = interaction.command + if command is not None: + if command._has_any_error_handlers(): + return + + _log.error('Ignoring exception in command %r', command.name, exc_info=error) + else: + _log.error('Ignoring exception in command tree', exc_info=error) + + def error(self, coro: ErrorFunc[ClientT]) -> ErrorFunc[ClientT]: + """A decorator that registers a coroutine as a local error handler. + + This must match the signature of the :meth:`on_error` callback. + + The error passed will be derived from :exc:`AppCommandError`. + + Parameters + ----------- + coro: :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------- + TypeError + The coroutine passed is not actually a coroutine or does + not match the signature. + """ + + if not inspect.iscoroutinefunction(coro): + raise TypeError('The error handler must be a coroutine.') + + params = inspect.signature(coro).parameters + if len(params) != 2: + raise TypeError('error handler must have 2 parameters') + + self.on_error = coro # type: ignore + return coro + + def command( + self, + *, + name: Union[str, locale_str] = MISSING, + description: Union[str, locale_str] = MISSING, + nsfw: bool = False, + guild: Optional[Snowflake] = MISSING, + guilds: Sequence[Snowflake] = MISSING, + auto_locale_strings: bool = True, + extras: Dict[Any, Any] = MISSING, + ) -> Callable[[CommandCallback[Group, P, T]], Command[Group, P, T]]: + """A decorator that creates an application command from a regular function directly under this tree. + + Parameters + ------------ + name: Union[:class:`str`, :class:`locale_str`] + The name of the application command. If not given, it defaults to a lower-case + version of the callback name. + description: Union[:class:`str`, :class:`locale_str`] + The description of the application command. This shows up in the UI to describe + the application command. If not given, it defaults to the first line of the docstring + of the callback shortened to 100 characters. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. Defaults to ``False``. + + Due to a Discord limitation, this does not work on subcommands. + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to add the command to. If not given or ``None`` then it + becomes a global command instead. + + .. note :: + + Due to a Discord limitation, this keyword argument cannot be used in conjunction with + contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types + (e.g. :func:`.app_commands.allowed_installs`). + + guilds: List[:class:`~discord.abc.Snowflake`] + The list of guilds to add the command to. This cannot be mixed + with the ``guild`` parameter. If no guilds are given at all + then it becomes a global command instead. + + .. note :: + + Due to a Discord limitation, this keyword argument cannot be used in conjunction with + contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types + (e.g. :func:`.app_commands.allowed_installs`). + + auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`locale_str` rather than :class:`str`. This could + avoid some repetition and be more ergonomic for certain defaults such + as default command names, command descriptions, and parameter names. + Defaults to ``True``. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + """ + + def decorator(func: CommandCallback[Group, P, T]) -> Command[Group, P, T]: + if not inspect.iscoroutinefunction(func): + raise TypeError('command function must be a coroutine function') + + if description is MISSING: + if func.__doc__ is None: + desc = '…' + else: + desc = _shorten(func.__doc__) + else: + desc = description + + command = Command( + name=name if name is not MISSING else func.__name__, + description=desc, + callback=func, + nsfw=nsfw, + parent=None, + auto_locale_strings=auto_locale_strings, + extras=extras, + ) + self.add_command(command, guild=guild, guilds=guilds) + return command + + return decorator + + def context_menu( + self, + *, + name: Union[str, locale_str] = MISSING, + nsfw: bool = False, + guild: Optional[Snowflake] = MISSING, + guilds: Sequence[Snowflake] = MISSING, + auto_locale_strings: bool = True, + extras: Dict[Any, Any] = MISSING, + ) -> Callable[[ContextMenuCallback], ContextMenu]: + """A decorator that creates an application command context menu from a regular function directly under this tree. + + This function must have a signature of :class:`~discord.Interaction` as its first parameter + and taking either a :class:`~discord.Member`, :class:`~discord.User`, or :class:`~discord.Message`, + or a :obj:`typing.Union` of ``Member`` and ``User`` as its second parameter. + + Examples + --------- + + .. code-block:: python3 + + @app_commands.context_menu() + async def react(interaction: discord.Interaction, message: discord.Message): + await interaction.response.send_message('Very cool message!', ephemeral=True) + + @app_commands.context_menu() + async def ban(interaction: discord.Interaction, user: discord.Member): + await interaction.response.send_message(f'Should I actually ban {user}...', ephemeral=True) + + Parameters + ------------ + name: Union[:class:`str`, :class:`locale_str`] + The name of the context menu command. If not given, it defaults to a title-case + version of the callback name. Note that unlike regular slash commands this can + have spaces and upper case characters in the name. + nsfw: :class:`bool` + Whether the command is NSFW and should only work in NSFW channels. Defaults to ``False``. + + Due to a Discord limitation, this does not work on subcommands. + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to add the command to. If not given or ``None`` then it + becomes a global command instead. + + .. note :: + + Due to a Discord limitation, this keyword argument cannot be used in conjunction with + contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types + (e.g. :func:`.app_commands.allowed_installs`). + + guilds: List[:class:`~discord.abc.Snowflake`] + The list of guilds to add the command to. This cannot be mixed + with the ``guild`` parameter. If no guilds are given at all + then it becomes a global command instead. + + .. note :: + + Due to a Discord limitation, this keyword argument cannot be used in conjunction with + contexts (e.g. :func:`.app_commands.allowed_contexts`) or installation types + (e.g. :func:`.app_commands.allowed_installs`). + + auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`locale_str` rather than :class:`str`. This could + avoid some repetition and be more ergonomic for certain defaults such + as default command names, command descriptions, and parameter names. + Defaults to ``True``. + extras: :class:`dict` + A dictionary that can be used to store extraneous data. + The library will not touch any values or keys within this dictionary. + """ + + def decorator(func: ContextMenuCallback) -> ContextMenu: + if not inspect.iscoroutinefunction(func): + raise TypeError('context menu function must be a coroutine function') + + actual_name = func.__name__.title() if name is MISSING else name + context_menu = ContextMenu( + name=actual_name, + nsfw=nsfw, + callback=func, + auto_locale_strings=auto_locale_strings, + extras=extras, + ) + self.add_command(context_menu, guild=guild, guilds=guilds) + return context_menu + + return decorator + + @property + def translator(self) -> Optional[Translator]: + """Optional[:class:`Translator`]: The translator, if any, responsible for handling translation of commands. + + To change the translator, use :meth:`set_translator`. + """ + return self._state._translator + + async def set_translator(self, translator: Optional[Translator]) -> None: + """|coro| + + Sets the translator to use for translating commands. + + If a translator was previously set, it will be unloaded using its + :meth:`Translator.unload` method. + + When a translator is set, it will be loaded using its :meth:`Translator.load` method. + + Parameters + ------------ + translator: Optional[:class:`Translator`] + The translator to use. If ``None`` then the translator is just removed and unloaded. + + Raises + ------- + TypeError + The translator was not ``None`` or a :class:`Translator` instance. + """ + + if translator is not None and not isinstance(translator, Translator): + raise TypeError(f'expected None or Translator instance, received {translator.__class__.__name__} instead') + + old_translator = self._state._translator + if old_translator is not None: + await old_translator.unload() + + if translator is None: + self._state._translator = None + else: + await translator.load() + self._state._translator = translator + + async def sync(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]: + """|coro| + + Syncs the application commands to Discord. + + This also runs the translator to get the translated strings necessary for + feeding back into Discord. + + This must be called for the application commands to show up. + + Parameters + ----------- + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to sync the commands to. If ``None`` then it + syncs all global commands instead. + + Raises + ------- + HTTPException + Syncing the commands failed. + CommandSyncFailure + Syncing the commands failed due to a user related error, typically because + the command has invalid data. This is equivalent to an HTTP status code of + 400. + Forbidden + The client does not have the ``applications.commands`` scope in the guild. + MissingApplicationID + The client does not have an application ID. + TranslationError + An error occurred while translating the commands. + + Returns + -------- + List[:class:`AppCommand`] + The application's commands that got synced. + """ + + if self.client.application_id is None: + raise MissingApplicationID + + commands = self._get_all_commands(guild=guild) + + translator = self.translator + if translator: + payload = [await command.get_translated_payload(self, translator) for command in commands] + else: + payload = [command.to_dict(self) for command in commands] + + try: + if guild is None: + data = await self._http.bulk_upsert_global_commands(self.client.application_id, payload=payload) + else: + data = await self._http.bulk_upsert_guild_commands(self.client.application_id, guild.id, payload=payload) + except HTTPException as e: + if e.status == 400 and e.code == 50035: + raise CommandSyncFailure(e, commands) from None + raise + + return [AppCommand(data=d, state=self._state) for d in data] + + async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None: + command = interaction.command + interaction.command_failed = True + try: + if isinstance(command, Command): + await command._invoke_error_handlers(interaction, error) + finally: + await self.on_error(interaction, error) + + def _from_interaction(self, interaction: Interaction[ClientT]) -> None: + async def wrapper(): + try: + await self._call(interaction) + except AppCommandError as e: + await self._dispatch_error(interaction, e) + + self.client.loop.create_task(wrapper(), name='CommandTree-invoker') + + def _get_context_menu(self, data: ApplicationCommandInteractionData) -> Optional[ContextMenu]: + name = data['name'] + guild_id = _get_as_snowflake(data, 'guild_id') + t = data.get('type', 1) + cmd = self._context_menus.get((name, guild_id, t)) + if cmd is None and self.fallback_to_global: + return self._context_menus.get((name, None, t)) + return cmd + + def _get_app_command_options( + self, data: ApplicationCommandInteractionData + ) -> Tuple[Command[Any, ..., Any], List[ApplicationCommandInteractionDataOption]]: + parents: List[str] = [] + name = data['name'] + + command_guild_id = _get_as_snowflake(data, 'guild_id') + if command_guild_id: + try: + guild_commands = self._guild_commands[command_guild_id] + except KeyError: + command = None if not self.fallback_to_global else self._global_commands.get(name) + else: + command = guild_commands.get(name) + if command is None and self.fallback_to_global: + command = self._global_commands.get(name) + else: + command = self._global_commands.get(name) + + # If it's not found at this point then it's not gonna be found at any point + if command is None: + raise CommandNotFound(name, parents) + + # This could be done recursively but it'd be a bother due to the state needed + # to be tracked above like the parents, the actual command type, and the + # resulting options we care about + searching = True + options: List[ApplicationCommandInteractionDataOption] = data.get('options', []) + while searching: + for option in options: + # Find subcommands + if option.get('type', 0) in (1, 2): + parents.append(name) + name = option['name'] + command = command._get_internal_command(name) + if command is None: + raise CommandNotFound(name, parents) + options = option.get('options', []) + break + else: + searching = False + break + else: + break + + if isinstance(command, Group): + # Right now, groups can't be invoked. This is a Discord limitation in how they + # do slash commands. So if we're here and we have a Group rather than a Command instance + # then something in the code is out of date from the data that Discord has. + raise CommandSignatureMismatch(command) + + return (command, options) + + async def _call_context_menu( + self, interaction: Interaction[ClientT], data: ApplicationCommandInteractionData, type: int + ) -> None: + name = data['name'] + guild_id = _get_as_snowflake(data, 'guild_id') + ctx_menu = self._context_menus.get((name, guild_id, type)) + if ctx_menu is None and self.fallback_to_global: + ctx_menu = self._context_menus.get((name, None, type)) + + # Pre-fill the cached slot to prevent re-computation + interaction._cs_command = ctx_menu + + if ctx_menu is None: + raise CommandNotFound(name, [], AppCommandType(type)) + + resolved = Namespace._get_resolved_items(interaction, data.get('resolved', {})) + + # This is annotated as str | int but realistically this will always be str + target_id: Optional[Union[str, int]] = data.get('target_id') + # Right now, the only types are message and user + # Therefore, there's no conflict with snowflakes + + # This will always work at runtime + key = ResolveKey.any_with(target_id) # type: ignore + value = resolved.get(key) + if ctx_menu.type.value != type: + raise CommandSignatureMismatch(ctx_menu) + + if value is None: + raise AppCommandError('This should not happen if Discord sent well-formed data.') + + # I assume I don't have to type check here. + try: + await ctx_menu._invoke(interaction, value) + except AppCommandError as e: + if ctx_menu.on_error is not None: + await ctx_menu.on_error(interaction, e) + await self.on_error(interaction, e) + else: + self.client.dispatch('app_command_completion', interaction, ctx_menu) + + async def interaction_check(self, interaction: Interaction[ClientT], /) -> bool: + """|coro| + + A global check to determine if an :class:`~discord.Interaction` should + be processed by the tree. + + The default implementation returns True (all interactions are processed), + but can be overridden if custom behaviour is desired. + """ + return True + + async def _call(self, interaction: Interaction[ClientT]) -> None: + if not await self.interaction_check(interaction): + interaction.command_failed = True + return + + data: ApplicationCommandInteractionData = interaction.data # type: ignore + type = data.get('type', 1) + if type != 1: + # Context menu command... + await self._call_context_menu(interaction, data, type) + return + + command, options = self._get_app_command_options(data) + + # Pre-fill the cached slot to prevent re-computation + interaction._cs_command = command + + # At this point options refers to the arguments of the command + # and command refers to the class type we care about + namespace = Namespace(interaction, data.get('resolved', {}), options) + + # Same pre-fill as above + interaction._cs_namespace = namespace + + # Auto complete handles the namespace differently... so at this point this is where we decide where that is. + if interaction.type is InteractionType.autocomplete: + focused = next((opt['name'] for opt in options if opt.get('focused')), None) + if focused is None: + raise AppCommandError('This should not happen, but there is no focused element. This is a Discord bug.') + + try: + await command._invoke_autocomplete(interaction, focused, namespace) + except Exception: + # Suppress exception since it can't be handled anyway. + _log.exception('Ignoring exception in autocomplete for %r', command.qualified_name) + + return + + try: + await command._invoke_with_namespace(interaction, namespace) + except AppCommandError as e: + interaction.command_failed = True + await command._invoke_error_handlers(interaction, e) + await self.on_error(interaction, e) + else: + if not interaction.command_failed: + self.client.dispatch('app_command_completion', interaction, command) diff --git a/discord/appinfo.py b/discord/appinfo.py index f38f40b91423..9dd70f7efc73 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,8 +22,36 @@ DEALINGS IN THE SOFTWARE. """ -from .user import User +from __future__ import annotations + +from typing import List, TYPE_CHECKING, Literal, Optional + +from . import utils from .asset import Asset +from .flags import ApplicationFlags +from .permissions import Permissions +from .utils import MISSING + +if TYPE_CHECKING: + from typing import Dict, Any + + from .guild import Guild + from .types.appinfo import ( + AppInfo as AppInfoPayload, + PartialAppInfo as PartialAppInfoPayload, + Team as TeamPayload, + InstallParams as InstallParamsPayload, + AppIntegrationTypeConfig as AppIntegrationTypeConfigPayload, + ) + from .user import User + from .state import ConnectionState + +__all__ = ( + 'AppInfo', + 'PartialAppInfo', + 'AppInstallParams', + 'IntegrationTypeConfig', +) class AppInfo: @@ -40,9 +66,12 @@ class AppInfo: The application name. owner: :class:`User` The application owner. - icon: Optional[:class:`str`] - The icon hash, if it exists. - description: Optional[:class:`str`] + team: Optional[:class:`Team`] + The application's team. + + .. versionadded:: 1.3 + + description: :class:`str` The application description. bot_public: :class:`bool` Whether the bot can be invited by anyone or if it is locked @@ -52,27 +81,565 @@ class AppInfo: grant flow to join. rpc_origins: Optional[List[:class:`str`]] A list of RPC origin URLs, if RPC is enabled. + + verify_key: :class:`str` + The hex encoded key for verification in interactions and the + GameSDK's :ddocs:`GetTicket `. + + .. versionadded:: 1.3 + + guild_id: Optional[:class:`int`] + If this application is a game sold on Discord, + this field will be the guild to which it has been linked to. + + .. versionadded:: 1.3 + + primary_sku_id: Optional[:class:`int`] + If this application is a game sold on Discord, + this field will be the id of the "Game SKU" that is created, + if it exists. + + .. versionadded:: 1.3 + + slug: Optional[:class:`str`] + If this application is a game sold on Discord, + this field will be the URL slug that links to the store page. + + .. versionadded:: 1.3 + + terms_of_service_url: Optional[:class:`str`] + The application's terms of service URL, if set. + + .. versionadded:: 2.0 + + privacy_policy_url: Optional[:class:`str`] + The application's privacy policy URL, if set. + + .. versionadded:: 2.0 + + tags: List[:class:`str`] + The list of tags describing the functionality of the application. + + .. versionadded:: 2.0 + + custom_install_url: List[:class:`str`] + The custom authorization URL for the application, if enabled. + + .. versionadded:: 2.0 + + install_params: Optional[:class:`AppInstallParams`] + The settings for custom authorization URL of application, if enabled. + + .. versionadded:: 2.0 + role_connections_verification_url: Optional[:class:`str`] + The application's connection verification URL which will render the application as + a verification method in the guild's role verification configuration. + + .. versionadded:: 2.2 + interactions_endpoint_url: Optional[:class:`str`] + The interactions endpoint url of the application to receive interactions over this endpoint rather than + over the gateway, if configured. + + .. versionadded:: 2.4 + redirect_uris: List[:class:`str`] + A list of authentication redirect URIs. + + .. versionadded:: 2.4 + approximate_guild_count: :class:`int` + The approximate count of the guilds the bot was added to. + + .. versionadded:: 2.4 + approximate_user_install_count: Optional[:class:`int`] + The approximate count of the user-level installations the bot has. + + .. versionadded:: 2.5 """ - __slots__ = ('_state', 'description', 'id', 'name', 'rpc_origins', - 'bot_public', 'bot_require_code_grant', 'owner', 'icon') - def __init__(self, state, data): - self._state = state + __slots__ = ( + '_state', + 'description', + 'id', + 'name', + 'rpc_origins', + 'bot_public', + 'bot_require_code_grant', + 'owner', + '_icon', + 'verify_key', + 'team', + 'guild_id', + 'primary_sku_id', + 'slug', + '_cover_image', + '_flags', + 'terms_of_service_url', + 'privacy_policy_url', + 'tags', + 'custom_install_url', + 'install_params', + 'role_connections_verification_url', + 'interactions_endpoint_url', + 'redirect_uris', + 'approximate_guild_count', + 'approximate_user_install_count', + '_integration_types_config', + ) + + def __init__(self, state: ConnectionState, data: AppInfoPayload): + from .team import Team + + self._state: ConnectionState = state + self.id: int = int(data['id']) + self.name: str = data['name'] + self.description: str = data['description'] + self._icon: Optional[str] = data['icon'] + self.rpc_origins: Optional[List[str]] = data.get('rpc_origins') + self.bot_public: bool = data['bot_public'] + self.bot_require_code_grant: bool = data['bot_require_code_grant'] + self.owner: User = state.create_user(data['owner']) + + team: Optional[TeamPayload] = data.get('team') + self.team: Optional[Team] = Team(state, team) if team else None + + self.verify_key: str = data['verify_key'] - self.id = int(data['id']) - self.name = data['name'] - self.description = data['description'] - self.icon = data['icon'] - self.rpc_origins = data['rpc_origins'] - self.bot_public = data['bot_public'] - self.bot_require_code_grant = data['bot_require_code_grant'] - self.owner = User(state=self._state, data=data['owner']) + self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') - def __repr__(self): - return '<{0.__class__.__name__} id={0.id} name={0.name!r} description={0.description!r} public={0.bot_public} ' \ - 'owner={0.owner!r}>'.format(self) + self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id') + self.slug: Optional[str] = data.get('slug') + self._flags: int = data.get('flags', 0) + self._cover_image: Optional[str] = data.get('cover_image') + self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url') + self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url') + self.tags: List[str] = data.get('tags', []) + self.custom_install_url: Optional[str] = data.get('custom_install_url') + self.role_connections_verification_url: Optional[str] = data.get('role_connections_verification_url') + + params = data.get('install_params') + self.install_params: Optional[AppInstallParams] = AppInstallParams(params) if params else None + self.interactions_endpoint_url: Optional[str] = data.get('interactions_endpoint_url') + self.redirect_uris: List[str] = data.get('redirect_uris', []) + self.approximate_guild_count: int = data.get('approximate_guild_count', 0) + self.approximate_user_install_count: Optional[int] = data.get('approximate_user_install_count') + self._integration_types_config: Dict[Literal['0', '1'], AppIntegrationTypeConfigPayload] = data.get( + 'integration_types_config', {} + ) + + def __repr__(self) -> str: + return ( + f'<{self.__class__.__name__} id={self.id} name={self.name!r} ' + f'description={self.description!r} public={self.bot_public} ' + f'owner={self.owner!r}>' + ) + + @property + def icon(self) -> Optional[Asset]: + """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" + if self._icon is None: + return None + return Asset._from_icon(self._state, self.id, self._icon, path='app') + + @property + def cover_image(self) -> Optional[Asset]: + """Optional[:class:`.Asset`]: Retrieves the cover image on a store embed, if any. + + This is only available if the application is a game sold on Discord. + """ + if self._cover_image is None: + return None + return Asset._from_cover_image(self._state, self.id, self._cover_image) + + @property + def guild(self) -> Optional[Guild]: + """Optional[:class:`Guild`]: If this application is a game sold on Discord, + this field will be the guild to which it has been linked + + .. versionadded:: 1.3 + """ + return self._state._get_guild(self.guild_id) + + @property + def flags(self) -> ApplicationFlags: + """:class:`ApplicationFlags`: The application's flags. + + .. versionadded:: 2.0 + """ + return ApplicationFlags._from_value(self._flags) + + @property + def guild_integration_config(self) -> Optional[IntegrationTypeConfig]: + """Optional[:class:`IntegrationTypeConfig`]: The default settings for the + application's installation context in a guild. + + .. versionadded:: 2.5 + """ + if not self._integration_types_config: + return None + + try: + return IntegrationTypeConfig(self._integration_types_config['0']) + except KeyError: + return None @property - def icon_url(self): - """:class:`.Asset`: Retrieves the application's icon asset.""" - return Asset._from_icon(self._state, self, 'app') + def user_integration_config(self) -> Optional[IntegrationTypeConfig]: + """Optional[:class:`IntegrationTypeConfig`]: The default settings for the + application's installation context as a user. + + .. versionadded:: 2.5 + """ + if not self._integration_types_config: + return None + + try: + return IntegrationTypeConfig(self._integration_types_config['1']) + except KeyError: + return None + + async def edit( + self, + *, + reason: Optional[str] = MISSING, + custom_install_url: Optional[str] = MISSING, + description: Optional[str] = MISSING, + role_connections_verification_url: Optional[str] = MISSING, + install_params_scopes: Optional[List[str]] = MISSING, + install_params_permissions: Optional[Permissions] = MISSING, + flags: Optional[ApplicationFlags] = MISSING, + icon: Optional[bytes] = MISSING, + cover_image: Optional[bytes] = MISSING, + interactions_endpoint_url: Optional[str] = MISSING, + tags: Optional[List[str]] = MISSING, + guild_install_scopes: Optional[List[str]] = MISSING, + guild_install_permissions: Optional[Permissions] = MISSING, + user_install_scopes: Optional[List[str]] = MISSING, + user_install_permissions: Optional[Permissions] = MISSING, + ) -> AppInfo: + r"""|coro| + + Edits the application info. + + .. versionadded:: 2.4 + + Parameters + ---------- + custom_install_url: Optional[:class:`str`] + The new custom authorization URL for the application. Can be ``None`` to remove the URL. + description: Optional[:class:`str`] + The new application description. Can be ``None`` to remove the description. + role_connections_verification_url: Optional[:class:`str`] + The new application’s connection verification URL which will render the application + as a verification method in the guild’s role verification configuration. Can be ``None`` to remove the URL. + install_params_scopes: Optional[List[:class:`str`]] + The new list of :ddocs:`OAuth2 scopes ` of + the :attr:`~install_params`. Can be ``None`` to remove the scopes. + install_params_permissions: Optional[:class:`Permissions`] + The new permissions of the :attr:`~install_params`. Can be ``None`` to remove the permissions. + flags: Optional[:class:`ApplicationFlags`] + The new application’s flags. Only limited intent flags (:attr:`~ApplicationFlags.gateway_presence_limited`, + :attr:`~ApplicationFlags.gateway_guild_members_limited`, :attr:`~ApplicationFlags.gateway_message_content_limited`) + can be edited. Can be ``None`` to remove the flags. + + .. warning:: + + Editing the limited intent flags leads to the termination of the bot. + + icon: Optional[:class:`bytes`] + The new application’s icon as a :term:`py:bytes-like object`. Can be ``None`` to remove the icon. + cover_image: Optional[:class:`bytes`] + The new application’s cover image as a :term:`py:bytes-like object` on a store embed. + The cover image is only available if the application is a game sold on Discord. + Can be ``None`` to remove the image. + interactions_endpoint_url: Optional[:class:`str`] + The new interactions endpoint url of the application to receive interactions over this endpoint rather than + over the gateway. Can be ``None`` to remove the URL. + tags: Optional[List[:class:`str`]] + The new list of tags describing the functionality of the application. Can be ``None`` to remove the tags. + guild_install_scopes: Optional[List[:class:`str`]] + The new list of :ddocs:`OAuth2 scopes ` of + the default guild installation context. Can be ``None`` to remove the scopes. + + .. versionadded: 2.5 + guild_install_permissions: Optional[:class:`Permissions`] + The new permissions of the default guild installation context. Can be ``None`` to remove the permissions. + + .. versionadded: 2.5 + user_install_scopes: Optional[List[:class:`str`]] + The new list of :ddocs:`OAuth2 scopes ` of + the default user installation context. Can be ``None`` to remove the scopes. + + .. versionadded: 2.5 + user_install_permissions: Optional[:class:`Permissions`] + The new permissions of the default user installation context. Can be ``None`` to remove the permissions. + + .. versionadded: 2.5 + reason: Optional[:class:`str`] + The reason for editing the application. Shows up on the audit log. + + Raises + ------- + HTTPException + Editing the application failed + ValueError + The image format passed in to ``icon`` or ``cover_image`` is invalid. This is also raised + when ``install_params_scopes`` and ``install_params_permissions`` are incompatible with each other, + or when ``guild_install_scopes`` and ``guild_install_permissions`` are incompatible with each other. + + Returns + ------- + :class:`AppInfo` + The newly updated application info. + """ + payload: Dict[str, Any] = {} + + if custom_install_url is not MISSING: + payload['custom_install_url'] = custom_install_url + + if description is not MISSING: + payload['description'] = description + + if role_connections_verification_url is not MISSING: + payload['role_connections_verification_url'] = role_connections_verification_url + + if install_params_scopes is not MISSING: + install_params: Optional[Dict[str, Any]] = {} + if install_params_scopes is None: + install_params = None + else: + if 'bot' not in install_params_scopes and install_params_permissions is not MISSING: + raise ValueError("'bot' must be in install_params_scopes if install_params_permissions is set") + + install_params['scopes'] = install_params_scopes + + if install_params_permissions is MISSING: + install_params['permissions'] = 0 + else: + if install_params_permissions is None: + install_params['permissions'] = 0 + else: + install_params['permissions'] = install_params_permissions.value + + payload['install_params'] = install_params + + else: + if install_params_permissions is not MISSING: + raise ValueError('install_params_scopes must be set if install_params_permissions is set') + + if flags is not MISSING: + if flags is None: + payload['flags'] = flags + else: + payload['flags'] = flags.value + + if icon is not MISSING: + if icon is None: + payload['icon'] = icon + else: + payload['icon'] = utils._bytes_to_base64_data(icon) + + if cover_image is not MISSING: + if cover_image is None: + payload['cover_image'] = cover_image + else: + payload['cover_image'] = utils._bytes_to_base64_data(cover_image) + + if interactions_endpoint_url is not MISSING: + payload['interactions_endpoint_url'] = interactions_endpoint_url + + if tags is not MISSING: + payload['tags'] = tags + + integration_types_config: Dict[str, Any] = {} + if guild_install_scopes is not MISSING or guild_install_permissions is not MISSING: + guild_install_params: Optional[Dict[str, Any]] = {} + if guild_install_scopes in (None, MISSING): + guild_install_scopes = [] + + if 'bot' not in guild_install_scopes and guild_install_permissions is not MISSING: + raise ValueError("'bot' must be in guild_install_scopes if guild_install_permissions is set") + + if guild_install_permissions in (None, MISSING): + guild_install_params['permissions'] = 0 + else: + guild_install_params['permissions'] = guild_install_permissions.value + + guild_install_params['scopes'] = guild_install_scopes + + integration_types_config['0'] = {'oauth2_install_params': guild_install_params or None} + else: + if guild_install_permissions is not MISSING: + raise ValueError('guild_install_scopes must be set if guild_install_permissions is set') + + if user_install_scopes is not MISSING or user_install_permissions is not MISSING: + user_install_params: Optional[Dict[str, Any]] = {} + if user_install_scopes in (None, MISSING): + user_install_scopes = [] + + if 'bot' not in user_install_scopes and user_install_permissions is not MISSING: + raise ValueError("'bot' must be in user_install_scopes if user_install_permissions is set") + + if user_install_permissions in (None, MISSING): + user_install_params['permissions'] = 0 + else: + user_install_params['permissions'] = user_install_permissions.value + + user_install_params['scopes'] = user_install_scopes + + integration_types_config['1'] = {'oauth2_install_params': user_install_params or None} + else: + if user_install_permissions is not MISSING: + raise ValueError('user_install_scopes must be set if user_install_permissions is set') + + if integration_types_config: + payload['integration_types_config'] = integration_types_config + + data = await self._state.http.edit_application_info(reason=reason, payload=payload) + return AppInfo(data=data, state=self._state) + + +class PartialAppInfo: + """Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite` + + .. versionadded:: 2.0 + + Attributes + ------------- + id: :class:`int` + The application ID. + name: :class:`str` + The application name. + description: :class:`str` + The application description. + rpc_origins: Optional[List[:class:`str`]] + A list of RPC origin URLs, if RPC is enabled. + verify_key: :class:`str` + The hex encoded key for verification in interactions and the + GameSDK's :ddocs:`GetTicket `. + terms_of_service_url: Optional[:class:`str`] + The application's terms of service URL, if set. + privacy_policy_url: Optional[:class:`str`] + The application's privacy policy URL, if set. + approximate_guild_count: :class:`int` + The approximate count of the guilds the bot was added to. + + .. versionadded:: 2.3 + redirect_uris: List[:class:`str`] + A list of authentication redirect URIs. + + .. versionadded:: 2.3 + interactions_endpoint_url: Optional[:class:`str`] + The interactions endpoint url of the application to receive interactions over this endpoint rather than + over the gateway, if configured. + + .. versionadded:: 2.3 + role_connections_verification_url: Optional[:class:`str`] + The application's connection verification URL which will render the application as + a verification method in the guild's role verification configuration. + + .. versionadded:: 2.3 + """ + + __slots__ = ( + '_state', + 'id', + 'name', + 'description', + 'rpc_origins', + 'verify_key', + 'terms_of_service_url', + 'privacy_policy_url', + '_icon', + '_flags', + '_cover_image', + 'approximate_guild_count', + 'redirect_uris', + 'interactions_endpoint_url', + 'role_connections_verification_url', + ) + + def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload): + self._state: ConnectionState = state + self.id: int = int(data['id']) + self.name: str = data['name'] + self._icon: Optional[str] = data.get('icon') + self._flags: int = data.get('flags', 0) + self._cover_image: Optional[str] = data.get('cover_image') + self.description: str = data['description'] + self.rpc_origins: Optional[List[str]] = data.get('rpc_origins') + self.verify_key: str = data['verify_key'] + self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url') + self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url') + self.approximate_guild_count: int = data.get('approximate_guild_count', 0) + self.redirect_uris: List[str] = data.get('redirect_uris', []) + self.interactions_endpoint_url: Optional[str] = data.get('interactions_endpoint_url') + self.role_connections_verification_url: Optional[str] = data.get('role_connections_verification_url') + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>' + + @property + def icon(self) -> Optional[Asset]: + """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" + if self._icon is None: + return None + return Asset._from_icon(self._state, self.id, self._icon, path='app') + + @property + def cover_image(self) -> Optional[Asset]: + """Optional[:class:`.Asset`]: Retrieves the cover image of the application's default rich presence. + + This is only available if the application is a game sold on Discord. + + .. versionadded:: 2.3 + """ + if self._cover_image is None: + return None + return Asset._from_cover_image(self._state, self.id, self._cover_image) + + @property + def flags(self) -> ApplicationFlags: + """:class:`ApplicationFlags`: The application's flags. + + .. versionadded:: 2.0 + """ + return ApplicationFlags._from_value(self._flags) + + +class AppInstallParams: + """Represents the settings for custom authorization URL of an application. + + .. versionadded:: 2.0 + + Attributes + ---------- + scopes: List[:class:`str`] + The list of :ddocs:`OAuth2 scopes ` + to add the application to a guild with. + permissions: :class:`Permissions` + The permissions to give to application in the guild. + """ + + __slots__ = ('scopes', 'permissions') + + def __init__(self, data: InstallParamsPayload) -> None: + self.scopes: List[str] = data.get('scopes', []) + self.permissions: Permissions = Permissions(int(data['permissions'])) + + +class IntegrationTypeConfig: + """Represents the default settings for the application's installation context. + + .. versionadded:: 2.5 + + Attributes + ---------- + oauth2_install_params: Optional[:class:`AppInstallParams`] + The install params for this installation context's default in-app authorization link. + """ + + def __init__(self, data: AppIntegrationTypeConfigPayload) -> None: + self.oauth2_install_params: Optional[AppInstallParams] = None + try: + self.oauth2_install_params = AppInstallParams(data['oauth2_install_params']) # type: ignore # EAFP + except KeyError: + pass diff --git a/discord/asset.py b/discord/asset.py index af5a480012f0..41bcba3cf186 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,15 +22,159 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import io +import os +from typing import Any, Literal, Optional, TYPE_CHECKING, Tuple, Union from .errors import DiscordException -from .errors import InvalidArgument from . import utils +from .file import File + +import yarl + +# fmt: off +__all__ = ( + 'Asset', +) +# fmt: on + +if TYPE_CHECKING: + from typing_extensions import Self + + from .state import ConnectionState + from .webhook.async_ import _WebhookState + + _State = Union[ConnectionState, _WebhookState] + + ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] + ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif'] + +VALID_STATIC_FORMATS = frozenset({'jpeg', 'jpg', 'webp', 'png'}) +VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {'gif'} + + +MISSING = utils.MISSING + + +class AssetMixin: + __slots__ = () + url: str + _state: Optional[Any] + + async def read(self) -> bytes: + """|coro| + + Retrieves the content of this asset as a :class:`bytes` object. + + Raises + ------ + DiscordException + There was no internal connection state. + HTTPException + Downloading the asset failed. + NotFound + The asset was deleted. + + Returns + ------- + :class:`bytes` + The content of the asset. + """ + if self._state is None: + raise DiscordException('Invalid state (no ConnectionState provided)') + + return await self._state.http.get_from_cdn(self.url) + + async def save(self, fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], *, seek_begin: bool = True) -> int: + """|coro| + + Saves this asset into a file-like object. + + Parameters + ---------- + fp: Union[:class:`io.BufferedIOBase`, :class:`os.PathLike`] + The file-like object to save this asset to or the filename + to use. If a filename is passed then a file is created with that + filename and used instead. + seek_begin: :class:`bool` + Whether to seek to the beginning of the file after saving is + successfully done. + + Raises + ------ + DiscordException + There was no internal connection state. + HTTPException + Downloading the asset failed. + NotFound + The asset was deleted. + + Returns + -------- + :class:`int` + The number of bytes written. + """ + + data = await self.read() + if isinstance(fp, io.BufferedIOBase): + written = fp.write(data) + if seek_begin: + fp.seek(0) + return written + else: + with open(fp, 'wb') as f: + return f.write(data) + + async def to_file( + self, + *, + filename: Optional[str] = MISSING, + description: Optional[str] = None, + spoiler: bool = False, + ) -> File: + """|coro| + + Converts the asset into a :class:`File` suitable for sending via + :meth:`abc.Messageable.send`. -VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) -VALID_AVATAR_FORMATS = VALID_STATIC_FORMATS | {"gif"} + .. versionadded:: 2.0 + + Parameters + ----------- + filename: Optional[:class:`str`] + The filename of the file. If not provided, then the filename from + the asset's URL is used. + description: Optional[:class:`str`] + The description for the file. + spoiler: :class:`bool` + Whether the file is a spoiler. + + Raises + ------ + DiscordException + The asset does not have an associated state. + ValueError + The asset is a unicode emoji. + TypeError + The asset is a sticker with lottie type. + HTTPException + Downloading the asset failed. + NotFound + The asset was deleted. + + Returns + ------- + :class:`File` + The asset as a file suitable for sending. + """ -class Asset: + data = await self.read() + file_filename = filename if filename is not MISSING else yarl.URL(self.url).name + return File(io.BytesIO(data), filename=file_filename, description=description, spoiler=spoiler) + + +class Asset(AssetMixin): """Represents a CDN asset on Discord. .. container:: operations @@ -45,10 +187,6 @@ class Asset: Returns the length of the CDN asset's URL. - .. describe:: bool(x) - - Checks if the Asset has a URL. - .. describe:: x == y Checks if the asset is equal to another asset. @@ -61,165 +199,357 @@ class Asset: Returns the hash of the asset. """ - __slots__ = ('_state', '_url') - def __init__(self, state, url=None): - self._state = state - self._url = url + __slots__: Tuple[str, ...] = ( + '_state', + '_url', + '_animated', + '_key', + ) + + BASE = 'https://cdn.discordapp.com' + + def __init__(self, state: _State, *, url: str, key: str, animated: bool = False) -> None: + self._state: _State = state + self._url: str = url + self._animated: bool = animated + self._key: str = key @classmethod - def _from_avatar(cls, state, user, *, format=None, static_format='webp', size=1024): - if not utils.valid_icon_size(size): - raise InvalidArgument("size must be a power of 2 between 16 and 4096") - if format is not None and format not in VALID_AVATAR_FORMATS: - raise InvalidArgument("format must be None or one of {}".format(VALID_AVATAR_FORMATS)) - if format == "gif" and not user.is_avatar_animated(): - raise InvalidArgument("non animated avatars do not support gif format") - if static_format not in VALID_STATIC_FORMATS: - raise InvalidArgument("static_format must be one of {}".format(VALID_STATIC_FORMATS)) + def _from_default_avatar(cls, state: _State, index: int) -> Self: + return cls( + state, + url=f'{cls.BASE}/embed/avatars/{index}.png', + key=str(index), + animated=False, + ) - if user.avatar is None: - return user.default_avatar_url + @classmethod + def _from_avatar(cls, state: _State, user_id: int, avatar: str) -> Self: + animated = avatar.startswith('a_') + format = 'gif' if animated else 'png' + return cls( + state, + url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024', + key=avatar, + animated=animated, + ) - if format is None: - format = 'gif' if user.is_avatar_animated() else static_format + @classmethod + def _from_guild_avatar(cls, state: _State, guild_id: int, member_id: int, avatar: str) -> Self: + animated = avatar.startswith('a_') + format = 'gif' if animated else 'png' + return cls( + state, + url=f'{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024', + key=avatar, + animated=animated, + ) - return cls(state, 'https://cdn.discordapp.com/avatars/{0.id}/{0.avatar}.{1}?size={2}'.format(user, format, size)) + @classmethod + def _from_guild_banner(cls, state: _State, guild_id: int, member_id: int, banner: str) -> Self: + animated = banner.startswith('a_') + format = 'gif' if animated else 'png' + return cls( + state, + url=f'{cls.BASE}/guilds/{guild_id}/users/{member_id}/banners/{banner}.{format}?size=1024', + key=banner, + animated=animated, + ) @classmethod - def _from_icon(cls, state, object, path): - if object.icon is None: - return cls(state) + def _from_avatar_decoration(cls, state: _State, avatar_decoration: str) -> Self: + return cls( + state, + url=f'{cls.BASE}/avatar-decoration-presets/{avatar_decoration}.png?size=96', + key=avatar_decoration, + animated=True, + ) - url = 'https://cdn.discordapp.com/{0}-icons/{1.id}/{1.icon}.jpg'.format(path, object) - return cls(state, url) + @classmethod + def _from_icon(cls, state: _State, object_id: int, icon_hash: str, path: str) -> Self: + return cls( + state, + url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024', + key=icon_hash, + animated=False, + ) @classmethod - def _from_guild_image(cls, state, id, hash, key, *, format='webp', size=1024): - if not utils.valid_icon_size(size): - raise InvalidArgument("size must be a power of 2 between 16 and 4096") - if format not in VALID_STATIC_FORMATS: - raise InvalidArgument("format must be one of {}".format(VALID_STATIC_FORMATS)) + def _from_app_icon( + cls, state: _State, object_id: int, icon_hash: str, asset_type: Literal['icon', 'cover_image'] + ) -> Self: + return cls( + state, + url=f'{cls.BASE}/app-icons/{object_id}/{asset_type}.png?size=1024', + key=icon_hash, + animated=False, + ) - if hash is None: - return cls(state) + @classmethod + def _from_cover_image(cls, state: _State, object_id: int, cover_image_hash: str) -> Self: + return cls( + state, + url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024', + key=cover_image_hash, + animated=False, + ) - url = 'https://cdn.discordapp.com/{key}/{0}/{1}.{2}?size={3}' - return cls(state, url.format(id, hash, format, size, key=key)) + @classmethod + def _from_scheduled_event_cover_image(cls, state: _State, scheduled_event_id: int, cover_image_hash: str) -> Self: + return cls( + state, + url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024', + key=cover_image_hash, + animated=False, + ) @classmethod - def _from_guild_icon(cls, state, guild, *, format=None, static_format='webp', size=1024): - if not utils.valid_icon_size(size): - raise InvalidArgument("size must be a power of 2 between 16 and 4096") - if format is not None and format not in VALID_AVATAR_FORMATS: - raise InvalidArgument("format must be one of {}".format(VALID_AVATAR_FORMATS)) - if format == "gif" and not guild.is_icon_animated(): - raise InvalidArgument("non animated guild icons do not support gif format") - if static_format not in VALID_STATIC_FORMATS: - raise InvalidArgument("static_format must be one of {}".format(VALID_STATIC_FORMATS)) + def _from_guild_image(cls, state: _State, guild_id: int, image: str, path: str) -> Self: + animated = image.startswith('a_') + format = 'gif' if animated else 'png' + return cls( + state, + url=f'{cls.BASE}/{path}/{guild_id}/{image}.{format}?size=1024', + key=image, + animated=animated, + ) + + @classmethod + def _from_guild_icon(cls, state: _State, guild_id: int, icon_hash: str) -> Self: + animated = icon_hash.startswith('a_') + format = 'gif' if animated else 'png' + return cls( + state, + url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024', + key=icon_hash, + animated=animated, + ) + + @classmethod + def _from_sticker_banner(cls, state: _State, banner: int) -> Self: + return cls( + state, + url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png', + key=str(banner), + animated=False, + ) + + @classmethod + def _from_user_banner(cls, state: _State, user_id: int, banner_hash: str) -> Self: + animated = banner_hash.startswith('a_') + format = 'gif' if animated else 'png' + return cls( + state, + url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512', + key=banner_hash, + animated=animated, + ) - if guild.icon is None: - return cls(state) + @classmethod + def _from_primary_guild(cls, state: _State, guild_id: int, icon_hash: str) -> Self: + return cls( + state, + url=f'{cls.BASE}/guild-tag-badges/{guild_id}/{icon_hash}.png?size=64', + key=icon_hash, + animated=False, + ) - if format is None: - format = 'gif' if guild.is_icon_animated() else static_format + @classmethod + def _from_user_collectible(cls, state: _State, asset: str, animated: bool = False) -> Self: + name = 'static.png' if not animated else 'asset.webm' + return cls( + state, + url=f'{cls.BASE}/assets/collectibles/{asset}{name}', + key=asset, + animated=animated, + ) + + def __str__(self) -> str: + return self._url + + def __len__(self) -> int: + return len(self._url) + + def __repr__(self) -> str: + shorten = self._url.replace(self.BASE, '') + return f'' + + def __eq__(self, other: object) -> bool: + return isinstance(other, Asset) and self._url == other._url - return cls(state, 'https://cdn.discordapp.com/icons/{0.id}/{0.icon}.{1}?size={2}'.format(guild, format, size)) + def __hash__(self) -> int: + return hash(self._url) + @property + def url(self) -> str: + """:class:`str`: Returns the underlying URL of the asset.""" + return self._url - def __str__(self): - return self._url if self._url is not None else '' + @property + def key(self) -> str: + """:class:`str`: Returns the identifying key of the asset.""" + return self._key - def __len__(self): - if self._url: - return len(self._url) - return 0 + def is_animated(self) -> bool: + """:class:`bool`: Returns whether the asset is animated.""" + return self._animated - def __bool__(self): - return self._url is not None + def replace( + self, + *, + size: int = MISSING, + format: ValidAssetFormatTypes = MISSING, + static_format: ValidStaticFormatTypes = MISSING, + ) -> Self: + """Returns a new asset with the passed components replaced. - def __repr__(self): - return ''.format(self) - def __eq__(self, other): - return isinstance(other, Asset) and self._url == other._url + .. versionchanged:: 2.0 + ``static_format`` is now preferred over ``format`` + if both are present and the asset is not animated. - def __ne__(self, other): - return not self.__eq__(other) + .. versionchanged:: 2.0 + This function will now raise :exc:`ValueError` instead of + ``InvalidArgument``. - def __hash__(self): - return hash(self._url) + Parameters + ----------- + size: :class:`int` + The new size of the asset. + format: :class:`str` + The new format to change it to. Must be either + 'webp', 'jpeg', 'jpg', 'png', or 'gif' if it's animated. + static_format: :class:`str` + The new format to change it to if the asset isn't animated. + Must be either 'webp', 'jpeg', 'jpg', or 'png'. - async def read(self): - """|coro| + Raises + ------- + ValueError + An invalid size or format was passed. - Retrieves the content of this asset as a :class:`bytes` object. + Returns + -------- + :class:`Asset` + The newly updated asset. + """ + url = yarl.URL(self._url) + path, _ = os.path.splitext(url.path) + + if format is not MISSING: + if self._animated: + if format not in VALID_ASSET_FORMATS: + raise ValueError(f'format must be one of {VALID_ASSET_FORMATS}') + else: + if static_format is MISSING and format not in VALID_STATIC_FORMATS: + raise ValueError(f'format must be one of {VALID_STATIC_FORMATS}') + url = url.with_path(f'{path}.{format}') + + if static_format is not MISSING and not self._animated: + if static_format not in VALID_STATIC_FORMATS: + raise ValueError(f'static_format must be one of {VALID_STATIC_FORMATS}') + url = url.with_path(f'{path}.{static_format}') + + if size is not MISSING: + if not utils.valid_icon_size(size): + raise ValueError('size must be a power of 2 between 16 and 4096') + url = url.with_query(size=size) + else: + url = url.with_query(url.raw_query_string) - .. warning:: + url = str(url) + return self.__class__(state=self._state, url=url, key=self._key, animated=self._animated) - :class:`PartialEmoji` won't have a connection state if user created, - and a URL won't be present if a custom image isn't associated with - the asset, e.g. a guild with no custom icon. + def with_size(self, size: int, /) -> Self: + """Returns a new asset with the specified size. - .. versionadded:: 1.1.0 + .. versionchanged:: 2.0 + This function will now raise :exc:`ValueError` instead of + ``InvalidArgument``. Parameters - ----------- - fp: Union[:class:`io.BufferedIOBase`, :class:`os.PathLike`] - Same as in :meth:`Attachment.save`. - seek_begin: :class:`bool` - Same as in :meth:`Attachment.save`. + ------------ + size: :class:`int` + The new size of the asset. Raises - ------ - DiscordException - There was no valid URL or internal connection state. - HTTPException - Downloading the asset failed. - NotFound - The asset was deleted. + ------- + ValueError + The asset had an invalid size. Returns + -------- + :class:`Asset` + The new updated asset. + """ + if not utils.valid_icon_size(size): + raise ValueError('size must be a power of 2 between 16 and 4096') + + url = str(yarl.URL(self._url).with_query(size=size)) + return self.__class__(state=self._state, url=url, key=self._key, animated=self._animated) + + def with_format(self, format: ValidAssetFormatTypes, /) -> Self: + """Returns a new asset with the specified format. + + .. versionchanged:: 2.0 + This function will now raise :exc:`ValueError` instead of + ``InvalidArgument``. + + Parameters + ------------ + format: :class:`str` + The new format of the asset. + + Raises ------- - :class:`bytes` - The content of the asset. + ValueError + The asset had an invalid format. + + Returns + -------- + :class:`Asset` + The new updated asset. """ - if not self._url: - raise DiscordException('Invalid asset (no URL provided)') - if self._state is None: - raise DiscordException('Invalid state (no ConnectionState provided)') + if self._animated: + if format not in VALID_ASSET_FORMATS: + raise ValueError(f'format must be one of {VALID_ASSET_FORMATS}') + else: + if format not in VALID_STATIC_FORMATS: + raise ValueError(f'format must be one of {VALID_STATIC_FORMATS}') - return await self._state.http.get_from_cdn(self._url) + url = yarl.URL(self._url) + path, _ = os.path.splitext(url.path) + url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string)) + return self.__class__(state=self._state, url=url, key=self._key, animated=self._animated) - async def save(self, fp, *, seek_begin=True): - """|coro| + def with_static_format(self, format: ValidStaticFormatTypes, /) -> Self: + """Returns a new asset with the specified static format. - Saves this asset into a file-like object. + This only changes the format if the underlying asset is + not animated. Otherwise, the asset is not changed. + + .. versionchanged:: 2.0 + This function will now raise :exc:`ValueError` instead of + ``InvalidArgument``. Parameters - ---------- - fp: Union[BinaryIO, :class:`os.PathLike`] - Same as in :meth:`Attachment.save`. - seek_begin: :class:`bool` - Same as in :meth:`Attachment.save`. + ------------ + format: :class:`str` + The new static format of the asset. Raises - ------ - Same as :meth:`read`. + ------- + ValueError + The asset had an invalid format. Returns -------- - :class:`int` - The number of bytes written. + :class:`Asset` + The new updated asset. """ - data = await self.read() - if isinstance(fp, io.IOBase) and fp.writable(): - written = fp.write(data) - if seek_begin: - fp.seek(0) - return written - else: - with open(fp, 'wb') as f: - return f.write(data) + if self._animated: + return self + return self.with_format(format) diff --git a/discord/audit_logs.py b/discord/audit_logs.py index b997e91040de..e56f0fb3d252 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,122 +22,423 @@ DEALINGS IN THE SOFTWARE. """ -from . import utils, enums -from .object import Object -from .permissions import PermissionOverwrite, Permissions +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Mapping, Generator, List, Optional, Tuple, Type, TypeVar, Union + +from . import enums, flags, utils +from .asset import Asset from .colour import Colour from .invite import Invite +from .mixins import Hashable +from .object import Object +from .permissions import PermissionOverwrite, Permissions +from .automod import AutoModTrigger, AutoModRuleAction, AutoModRule +from .role import Role +from .emoji import Emoji +from .partial_emoji import PartialEmoji +from .member import Member +from .scheduled_event import ScheduledEvent +from .stage_instance import StageInstance +from .sticker import GuildSticker +from .threads import Thread +from .integrations import PartialIntegration +from .channel import ForumChannel, StageChannel, ForumTag +from .onboarding import OnboardingPrompt, OnboardingPromptOption + +__all__ = ( + 'AuditLogDiff', + 'AuditLogChanges', + 'AuditLogEntry', +) + + +if TYPE_CHECKING: + import datetime + + from . import abc + from .guild import Guild + from .state import ConnectionState + from .types.audit_log import ( + AuditLogChange as AuditLogChangePayload, + AuditLogEntry as AuditLogEntryPayload, + _AuditLogChange_TriggerMetadata as AuditLogChangeTriggerMetadataPayload, + ) + from .types.channel import ( + PermissionOverwrite as PermissionOverwritePayload, + ForumTag as ForumTagPayload, + DefaultReaction as DefaultReactionPayload, + ) + from .types.invite import Invite as InvitePayload + from .types.role import Role as RolePayload, RoleColours + from .types.snowflake import Snowflake + from .types.command import ApplicationCommandPermissions + from .types.automod import AutoModerationAction + from .types.onboarding import Prompt as PromptPayload, PromptOption as PromptOptionPayload + from .user import User + from .app_commands import AppCommand + from .webhook import Webhook + + TargetType = Union[ + Guild, + abc.GuildChannel, + Member, + User, + Role, + Invite, + Emoji, + StageInstance, + GuildSticker, + Thread, + Object, + PartialIntegration, + AutoModRule, + ScheduledEvent, + Webhook, + AppCommand, + None, + ] + + +def _transform_timestamp(entry: AuditLogEntry, data: Optional[str]) -> Optional[datetime.datetime]: + return utils.parse_time(data) + + +def _transform_color(entry: AuditLogEntry, data: int) -> Colour: + return Colour(data) -def _transform_verification_level(entry, data): - return enums.try_enum(enums.VerificationLevel, data) -def _transform_default_notifications(entry, data): - return enums.try_enum(enums.NotificationLevel, data) +def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int: + return int(data) -def _transform_explicit_content_filter(entry, data): - return enums.try_enum(enums.ContentFilter, data) -def _transform_permissions(entry, data): - return Permissions(data) +def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]: + if data is None: + return None + return entry.guild.get_channel(int(data)) or Object(id=data) -def _transform_color(entry, data): - return Colour(data) -def _transform_snowflake(entry, data): - return int(data) +def _transform_channels_or_threads( + entry: AuditLogEntry, data: List[Snowflake] +) -> List[Union[abc.GuildChannel, Thread, Object]]: + return [entry.guild.get_channel_or_thread(int(data)) or Object(id=data) for data in data] -def _transform_channel(entry, data): + +def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]: if data is None: return None - channel = entry.guild.get_channel(int(data)) or Object(id=data) - return channel + return entry._get_member(int(data)) -def _transform_owner_id(entry, data): + +def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]: if data is None: return None - return entry._get_member(int(data)) + return entry._state._get_guild(int(data)) + + +def _transform_roles(entry: AuditLogEntry, data: List[Snowflake]) -> List[Union[Role, Object]]: + return [entry.guild.get_role(int(role_id)) or Object(role_id, type=Role) for role_id in data] -def _transform_inviter_id(entry, data): + +def _transform_applied_forum_tags(entry: AuditLogEntry, data: List[Snowflake]) -> List[Union[ForumTag, Object]]: + thread = entry.target + if isinstance(thread, Thread) and isinstance(thread.parent, ForumChannel): + return [thread.parent.get_tag(tag_id) or Object(id=tag_id, type=ForumTag) for tag_id in map(int, data)] + return [Object(id=tag_id, type=ForumTag) for tag_id in data] + + +def _transform_overloaded_flags(entry: AuditLogEntry, data: int) -> Union[int, flags.ChannelFlags, flags.InviteFlags]: + # The `flags` key is definitely overloaded. Right now it's for channels, threads and invites but + # I am aware of `member.flags` and `user.flags` existing. However, this does not impact audit logs + # at the moment but better safe than sorry. + channel_audit_log_types = ( + enums.AuditLogAction.channel_create, + enums.AuditLogAction.channel_update, + enums.AuditLogAction.channel_delete, + enums.AuditLogAction.thread_create, + enums.AuditLogAction.thread_update, + enums.AuditLogAction.thread_delete, + ) + invite_audit_log_types = ( + enums.AuditLogAction.invite_create, + enums.AuditLogAction.invite_update, + enums.AuditLogAction.invite_delete, + ) + + if entry.action in channel_audit_log_types: + return flags.ChannelFlags._from_value(data) + elif entry.action in invite_audit_log_types: + return flags.InviteFlags._from_value(data) + return data + + +def _transform_forum_tags(entry: AuditLogEntry, data: List[ForumTagPayload]) -> List[ForumTag]: + return [ForumTag.from_data(state=entry._state, data=d) for d in data] + + +def _transform_default_reaction(entry: AuditLogEntry, data: DefaultReactionPayload) -> Optional[PartialEmoji]: if data is None: return None - return entry._get_member(int(data)) -def _transform_overwrites(entry, data): + emoji_name = data.get('emoji_name') or '' + emoji_id = utils._get_as_snowflake(data, 'emoji_id') or None # Coerce 0 -> None + return PartialEmoji.with_state(state=entry._state, name=emoji_name, id=emoji_id) + + +def _transform_overwrites( + entry: AuditLogEntry, data: List[PermissionOverwritePayload] +) -> List[Tuple[Object, PermissionOverwrite]]: overwrites = [] for elem in data: - allow = Permissions(elem['allow']) - deny = Permissions(elem['deny']) + allow = Permissions(int(elem['allow'])) + deny = Permissions(int(elem['deny'])) ow = PermissionOverwrite.from_pair(allow, deny) ow_type = elem['type'] ow_id = int(elem['id']) - if ow_type == 'role': + target = None + if ow_type == '0': target = entry.guild.get_role(ow_id) - else: + elif ow_type == '1': target = entry._get_member(ow_id) if target is None: - target = Object(id=ow_id) + target = Object(id=ow_id, type=Role if ow_type == '0' else Member) overwrites.append((target, ow)) return overwrites + +def _transform_icon(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: + if data is None: + return None + if entry.action is enums.AuditLogAction.guild_update: + return Asset._from_guild_icon(entry._state, entry.guild.id, data) + else: + return Asset._from_icon(entry._state, entry._target_id, data, path='role') # type: ignore # target_id won't be None in this case + + +def _transform_avatar(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: + if data is None: + return None + return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore # target_id won't be None in this case + + +def _transform_cover_image(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: + if data is None: + return None + return Asset._from_scheduled_event_cover_image(entry._state, entry._target_id, data) # type: ignore # target_id won't be None in this case + + +def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]: + def _transform(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: + if data is None: + return None + return Asset._from_guild_image(entry._state, entry.guild.id, data, path=path) + + return _transform + + +def _transform_automod_actions(entry: AuditLogEntry, data: List[AutoModerationAction]) -> List[AutoModRuleAction]: + return [AutoModRuleAction.from_data(action) for action in data] + + +def _transform_default_emoji(entry: AuditLogEntry, data: str) -> PartialEmoji: + return PartialEmoji(name=data) + + +def _transform_onboarding_prompts(entry: AuditLogEntry, data: List[PromptPayload]) -> List[OnboardingPrompt]: + return [OnboardingPrompt.from_dict(data=prompt, state=entry._state, guild=entry.guild) for prompt in data] + + +def _transform_onboarding_prompt_options( + entry: AuditLogEntry, data: List[PromptOptionPayload] +) -> List[OnboardingPromptOption]: + return [OnboardingPromptOption.from_dict(data=option, state=entry._state, guild=entry.guild) for option in data] + + +E = TypeVar('E', bound=enums.Enum) + + +def _enum_transformer(enum: Type[E]) -> Callable[[AuditLogEntry, int], E]: + def _transform(entry: AuditLogEntry, data: int) -> E: + return enums.try_enum(enum, data) + + return _transform + + +F = TypeVar('F', bound=flags.BaseFlags) + + +def _flag_transformer(cls: Type[F]) -> Callable[[AuditLogEntry, Union[int, str]], F]: + def _transform(entry: AuditLogEntry, data: Union[int, str]) -> F: + return cls._from_value(int(data)) + + return _transform + + +def _transform_type( + entry: AuditLogEntry, data: Union[int, str] +) -> Union[enums.ChannelType, enums.StickerType, enums.WebhookType, str, enums.OnboardingPromptType]: + if entry.action.name.startswith('sticker_'): + return enums.try_enum(enums.StickerType, data) + elif entry.action.name.startswith('integration_'): + return data # type: ignore # integration type is str + elif entry.action.name.startswith('webhook_'): + return enums.try_enum(enums.WebhookType, data) + elif entry.action.name.startswith('onboarding_prompt_'): + return enums.try_enum(enums.OnboardingPromptType, data) + else: + return enums.try_enum(enums.ChannelType, data) + + class AuditLogDiff: - def __len__(self): + def __len__(self) -> int: return len(self.__dict__) - def __iter__(self): - return iter(self.__dict__.items()) + def __iter__(self) -> Generator[Tuple[str, Any], None, None]: + yield from self.__dict__.items() - def __repr__(self): + def __repr__(self) -> str: values = ' '.join('%s=%r' % item for item in self.__dict__.items()) - return '' % values + return f'' + + if TYPE_CHECKING: + + def __getattr__(self, item: str) -> Any: ... + + def __setattr__(self, key: str, value: Any) -> Any: ... + + +Transformer = Callable[['AuditLogEntry', Any], Any] + class AuditLogChanges: - TRANSFORMERS = { - 'verification_level': (None, _transform_verification_level), - 'explicit_content_filter': (None, _transform_explicit_content_filter), - 'allow': (None, _transform_permissions), - 'deny': (None, _transform_permissions), - 'permissions': (None, _transform_permissions), - 'id': (None, _transform_snowflake), - 'color': ('colour', _transform_color), - 'owner_id': ('owner', _transform_owner_id), - 'inviter_id': ('inviter', _transform_inviter_id), - 'channel_id': ('channel', _transform_channel), - 'afk_channel_id': ('afk_channel', _transform_channel), - 'system_channel_id': ('system_channel', _transform_channel), - 'widget_channel_id': ('widget_channel', _transform_channel), - 'permission_overwrites': ('overwrites', _transform_overwrites), - 'splash_hash': ('splash', None), - 'icon_hash': ('icon', None), - 'avatar_hash': ('avatar', None), - 'rate_limit_per_user': ('slowmode_delay', None), - 'default_message_notifications': ('default_notifications', _transform_default_notifications), + # fmt: off + TRANSFORMERS: ClassVar[Mapping[str, Tuple[Optional[str], Optional[Transformer]]]] = { + 'verification_level': (None, _enum_transformer(enums.VerificationLevel)), + 'explicit_content_filter': (None, _enum_transformer(enums.ContentFilter)), + 'allow': (None, _flag_transformer(Permissions)), + 'deny': (None, _flag_transformer(Permissions)), + 'permissions': (None, _flag_transformer(Permissions)), + 'id': (None, _transform_snowflake), + 'color': ('colour', _transform_color), + 'owner_id': ('owner', _transform_member_id), + 'inviter_id': ('inviter', _transform_member_id), + 'channel_id': ('channel', _transform_channel), + 'afk_channel_id': ('afk_channel', _transform_channel), + 'system_channel_id': ('system_channel', _transform_channel), + 'system_channel_flags': (None, _flag_transformer(flags.SystemChannelFlags)), + 'widget_channel_id': ('widget_channel', _transform_channel), + 'rules_channel_id': ('rules_channel', _transform_channel), + 'public_updates_channel_id': ('public_updates_channel', _transform_channel), + 'permission_overwrites': ('overwrites', _transform_overwrites), + 'splash_hash': ('splash', _guild_hash_transformer('splashes')), + 'banner_hash': ('banner', _guild_hash_transformer('banners')), + 'discovery_splash_hash': ('discovery_splash', _guild_hash_transformer('discovery-splashes')), + 'icon_hash': ('icon', _transform_icon), + 'avatar_hash': ('avatar', _transform_avatar), + 'rate_limit_per_user': ('slowmode_delay', None), + 'default_thread_rate_limit_per_user': ('default_thread_slowmode_delay', None), + 'guild_id': ('guild', _transform_guild_id), + 'tags': ('emoji', None), + 'default_message_notifications': ('default_notifications', _enum_transformer(enums.NotificationLevel)), + 'video_quality_mode': (None, _enum_transformer(enums.VideoQualityMode)), + 'privacy_level': (None, _enum_transformer(enums.PrivacyLevel)), + 'format_type': (None, _enum_transformer(enums.StickerFormatType)), + 'type': (None, _transform_type), + 'communication_disabled_until': ('timed_out_until', _transform_timestamp), + 'expire_behavior': (None, _enum_transformer(enums.ExpireBehaviour)), + 'mfa_level': (None, _enum_transformer(enums.MFALevel)), + 'status': (None, _enum_transformer(enums.EventStatus)), + 'entity_type': (None, _enum_transformer(enums.EntityType)), + 'preferred_locale': (None, _enum_transformer(enums.Locale)), + 'image_hash': ('cover_image', _transform_cover_image), + 'trigger_type': (None, _enum_transformer(enums.AutoModRuleTriggerType)), + 'event_type': (None, _enum_transformer(enums.AutoModRuleEventType)), + 'actions': (None, _transform_automod_actions), + 'exempt_channels': (None, _transform_channels_or_threads), + 'exempt_roles': (None, _transform_roles), + 'applied_tags': (None, _transform_applied_forum_tags), + 'available_tags': (None, _transform_forum_tags), + 'flags': (None, _transform_overloaded_flags), + 'default_reaction_emoji': (None, _transform_default_reaction), + 'emoji_name': ('emoji', _transform_default_emoji), + 'user_id': ('user', _transform_member_id), + 'options': (None, _transform_onboarding_prompt_options), + 'prompts': (None, _transform_onboarding_prompts), + 'default_channel_ids': ('default_channels', _transform_channels_or_threads), + 'mode': (None, _enum_transformer(enums.OnboardingMode)), } - - def __init__(self, entry, data): - self.before = AuditLogDiff() - self.after = AuditLogDiff() + # fmt: on + + def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]): + self.before: AuditLogDiff = AuditLogDiff() + self.after: AuditLogDiff = AuditLogDiff() + # special case entire process since each + # element in data is a different target + # key is the target id + if entry.action is enums.AuditLogAction.app_command_permission_update: + self.before.app_command_permissions = [] + self.after.app_command_permissions = [] + + for elem in data: + self._handle_app_command_permissions( + self.before, + entry, + elem.get('old_value'), # type: ignore # value will be an ApplicationCommandPermissions if present + ) + + self._handle_app_command_permissions( + self.after, + entry, + elem.get('new_value'), # type: ignore # value will be an ApplicationCommandPermissions if present + ) + return for elem in data: attr = elem['key'] # special cases for role add/remove if attr == '$add': - self._handle_role(self.before, self.after, entry, elem['new_value']) + self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore # new_value is a list of roles in this case continue elif attr == '$remove': - self._handle_role(self.after, self.before, entry, elem['new_value']) + self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore # new_value is a list of roles in this case continue - transformer = self.TRANSFORMERS.get(attr) - if transformer: - key, transformer = transformer + # special case for automod trigger + if attr == 'trigger_metadata': + # given full metadata dict + self._handle_trigger_metadata(entry, elem, data) # type: ignore # should be trigger metadata + continue + elif entry.action is enums.AuditLogAction.automod_rule_update and attr.startswith('$'): + # on update, some trigger attributes are keys and formatted as $(add/remove)_{attribute} + action, _, trigger_attr = attr.partition('_') + # new_value should be a list of added/removed strings for keyword_filter, regex_patterns, or allow_list + if action == '$add': + self._handle_trigger_attr_update(self.before, self.after, entry, trigger_attr, elem['new_value']) # type: ignore + elif action == '$remove': + self._handle_trigger_attr_update(self.after, self.before, entry, trigger_attr, elem['new_value']) # type: ignore + continue + + # special case for colors to set secondary and tertiary colos/colour attributes + if attr == 'colors': + self._handle_colours(self.before, elem.get('old_value')) # type: ignore # should be a RoleColours dict + self._handle_colours(self.after, elem.get('new_value')) # type: ignore # should be a RoleColours dict + continue + + try: + key, transformer = self.TRANSFORMERS[attr] + except (ValueError, KeyError): + transformer = None + else: if key: attr = key + transformer: Optional[Transformer] + try: before = elem['old_value'] except KeyError: @@ -164,43 +463,214 @@ def __init__(self, entry, data): if hasattr(self.after, 'colour'): self.after.color = self.after.colour self.before.color = self.before.colour + if hasattr(self.after, 'expire_behavior'): + self.after.expire_behaviour = self.after.expire_behavior + self.before.expire_behaviour = self.before.expire_behavior - def __repr__(self): - return '' % (self.before, self.after) + def __repr__(self) -> str: + return f'' - def _handle_role(self, first, second, entry, elem): + def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None: if not hasattr(first, 'roles'): setattr(first, 'roles', []) data = [] - g = entry.guild + g: Guild = entry.guild for e in elem: role_id = int(e['id']) role = g.get_role(role_id) if role is None: - role = Object(id=role_id) - role.name = e['name'] + role = Object(id=role_id, type=Role) + role.name = e['name'] # type: ignore # Object doesn't usually have name data.append(role) setattr(second, 'roles', data) -class AuditLogEntry: + def _handle_app_command_permissions( + self, + diff: AuditLogDiff, + entry: AuditLogEntry, + data: Optional[ApplicationCommandPermissions], + ): + if data is None: + return + + # avoid circular import + from discord.app_commands import AppCommandPermissions + + state = entry._state + guild = entry.guild + diff.app_command_permissions.append(AppCommandPermissions(data=data, guild=guild, state=state)) + + def _handle_trigger_metadata( + self, + entry: AuditLogEntry, + data: AuditLogChangeTriggerMetadataPayload, + full_data: List[AuditLogChangePayload], + ): + trigger_value: Optional[int] = None + trigger_type: Optional[enums.AutoModRuleTriggerType] = None + + # try to get trigger type from before or after + trigger_type = getattr(self.before, 'trigger_type', getattr(self.after, 'trigger_type', None)) + + if trigger_type is None: + if isinstance(entry.target, AutoModRule): + # Trigger type cannot be changed, so it should be the same before and after updates. + # Avoids checking which keys are in data to guess trigger type + trigger_value = entry.target.trigger.type.value + else: + # found a trigger type from before or after + trigger_value = trigger_type.value + + if trigger_value is None: + # try to find trigger type in the full list of changes + _elem = utils.find(lambda elem: elem['key'] == 'trigger_type', full_data) + if _elem is not None: + trigger_value = _elem.get('old_value', _elem.get('new_value')) # type: ignore # trigger type values should be int + + if trigger_value is None: + # try to infer trigger_type from the keys in old or new value + combined = (data.get('old_value') or {}).keys() | (data.get('new_value') or {}).keys() + if not combined: + trigger_value = enums.AutoModRuleTriggerType.spam.value + elif 'presets' in combined: + trigger_value = enums.AutoModRuleTriggerType.keyword_preset.value + elif 'keyword_filter' in combined or 'regex_patterns' in combined: + trigger_value = enums.AutoModRuleTriggerType.keyword.value + elif 'mention_total_limit' in combined or 'mention_raid_protection_enabled' in combined: + trigger_value = enums.AutoModRuleTriggerType.mention_spam.value + else: + # some unknown type + trigger_value = -1 + + self.before.trigger = AutoModTrigger.from_data(trigger_value, data.get('old_value')) + self.after.trigger = AutoModTrigger.from_data(trigger_value, data.get('new_value')) + + def _handle_trigger_attr_update( + self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, attr: str, data: List[str] + ): + self._create_trigger(first, entry) + trigger = self._create_trigger(second, entry) + try: + # guard unexpecte non list attributes or non iterable data + getattr(trigger, attr).extend(data) + except (AttributeError, TypeError): + pass + + def _handle_colours(self, diff: AuditLogDiff, colours: Optional[RoleColours]): + if colours is not None: + # handle colours to multiple colour attributes + colour = Colour(colours['primary_color']) + secondary_colour = colours['secondary_color'] + tertiary_colour = colours['tertiary_color'] + else: + colour = None + secondary_colour = None + tertiary_colour = None + + diff.color = diff.colour = colour + diff.secondary_color = diff.secondary_colour = Colour(secondary_colour) if secondary_colour is not None else None + diff.tertiary_color = diff.tertiary_colour = Colour(tertiary_colour) if tertiary_colour is not None else None + + def _create_trigger(self, diff: AuditLogDiff, entry: AuditLogEntry) -> AutoModTrigger: + # check if trigger has already been created + if not hasattr(diff, 'trigger'): + # create a trigger + if isinstance(entry.target, AutoModRule): + # get trigger type from the automod rule + trigger_type = entry.target.trigger.type + else: + # unknown trigger type + trigger_type = enums.try_enum(enums.AutoModRuleTriggerType, -1) + + diff.trigger = AutoModTrigger(type=trigger_type) + return diff.trigger + + +class _AuditLogProxy: + def __init__(self, **kwargs: Any) -> None: + for k, v in kwargs.items(): + setattr(self, k, v) + + +class _AuditLogProxyMemberPrune(_AuditLogProxy): + delete_member_days: int + members_removed: int + + +class _AuditLogProxyMemberMoveOrMessageDelete(_AuditLogProxy): + channel: Union[abc.GuildChannel, Thread] + count: int + + +class _AuditLogProxyMemberDisconnect(_AuditLogProxy): + count: int + + +class _AuditLogProxyPinAction(_AuditLogProxy): + channel: Union[abc.GuildChannel, Thread] + message_id: int + + +class _AuditLogProxyStageInstanceAction(_AuditLogProxy): + channel: abc.GuildChannel + + +class _AuditLogProxyMessageBulkDelete(_AuditLogProxy): + count: int + + +class _AuditLogProxyAutoModAction(_AuditLogProxy): + automod_rule_name: str + automod_rule_trigger_type: str + channel: Optional[Union[abc.GuildChannel, Thread]] + + +class _AuditLogProxyMemberKickOrMemberRoleUpdate(_AuditLogProxy): + integration_type: Optional[str] + + +class AuditLogEntry(Hashable): r"""Represents an Audit Log entry. You retrieve these via :meth:`Guild.audit_logs`. + .. container:: operations + + .. describe:: x == y + + Checks if two entries are equal. + + .. describe:: x != y + + Checks if two entries are not equal. + + .. describe:: hash(x) + + Returns the entry's hash. + + .. versionchanged:: 1.7 + Audit log entries are now comparable and hashable. + Attributes ----------- action: :class:`AuditLogAction` The action that was done. - user: :class:`abc.User` + user: Optional[:class:`abc.User`] The user who initiated this action. Usually a :class:`Member`\, unless gone then it's a :class:`User`. + user_id: Optional[:class:`int`] + The user ID who initiated this action. + + .. versionadded:: 2.2 id: :class:`int` The entry ID. + guild: :class:`Guild` + The guild that this entry belongs to. target: Any The target that got changed. The exact type of this depends on the action being done. @@ -213,43 +683,120 @@ class AuditLogEntry: which actions have this field filled out. """ - def __init__(self, *, users, data, guild): - self._state = guild._state - self.guild = guild - self._users = users + def __init__( + self, + *, + users: Mapping[int, User], + integrations: Mapping[int, PartialIntegration], + app_commands: Mapping[int, AppCommand], + automod_rules: Mapping[int, AutoModRule], + webhooks: Mapping[int, Webhook], + data: AuditLogEntryPayload, + guild: Guild, + ): + self._state: ConnectionState = guild._state + self.guild: Guild = guild + self._users: Mapping[int, User] = users + self._integrations: Mapping[int, PartialIntegration] = integrations + self._app_commands: Mapping[int, AppCommand] = app_commands + self._automod_rules: Mapping[int, AutoModRule] = automod_rules + self._webhooks: Mapping[int, Webhook] = webhooks self._from_data(data) - def _from_data(self, data): - self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) - self.id = int(data['id']) + def _from_data(self, data: AuditLogEntryPayload) -> None: + self.action: enums.AuditLogAction = enums.try_enum(enums.AuditLogAction, data['action_type']) + self.id: int = int(data['id']) # this key is technically not usually present - self.reason = data.get('reason') - self.extra = data.get('options') - - if self.extra: + self.reason: Optional[str] = data.get('reason') + extra = data.get('options') + + # fmt: off + self.extra: Union[ + _AuditLogProxyMemberPrune, + _AuditLogProxyMemberMoveOrMessageDelete, + _AuditLogProxyMemberDisconnect, + _AuditLogProxyPinAction, + _AuditLogProxyStageInstanceAction, + _AuditLogProxyMessageBulkDelete, + _AuditLogProxyAutoModAction, + _AuditLogProxyMemberKickOrMemberRoleUpdate, + Member, User, None, PartialIntegration, + Role, Object + ] = None + # fmt: on + + if isinstance(self.action, enums.AuditLogAction) and extra: if self.action is enums.AuditLogAction.member_prune: # member prune has two keys with useful information - self.extra = type('_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()})() - elif self.action is enums.AuditLogAction.message_delete: - channel_id = int(self.extra['channel_id']) - elems = { - 'count': int(self.extra['count']), - 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id) - } - self.extra = type('_AuditLogProxy', (), elems)() + self.extra = _AuditLogProxyMemberPrune( + delete_member_days=int(extra['delete_member_days']), + members_removed=int(extra['members_removed']), + ) + elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete: + channel_id = int(extra['channel_id']) + self.extra = _AuditLogProxyMemberMoveOrMessageDelete( + count=int(extra['count']), + channel=self.guild.get_channel_or_thread(channel_id) or Object(id=channel_id), + ) + elif self.action is enums.AuditLogAction.member_disconnect: + # The member disconnect action has a dict with some information + self.extra = _AuditLogProxyMemberDisconnect(count=int(extra['count'])) + elif self.action is enums.AuditLogAction.message_bulk_delete: + # The bulk message delete action has the number of messages deleted + self.extra = _AuditLogProxyMessageBulkDelete(count=int(extra['count'])) + elif self.action in (enums.AuditLogAction.kick, enums.AuditLogAction.member_role_update): + # The member kick action has a dict with some information + integration_type = extra.get('integration_type') + self.extra = _AuditLogProxyMemberKickOrMemberRoleUpdate(integration_type=integration_type) + elif self.action.name.endswith('pin'): + # the pin actions have a dict with some information + channel_id = int(extra['channel_id']) + self.extra = _AuditLogProxyPinAction( + channel=self.guild.get_channel_or_thread(channel_id) or Object(id=channel_id), + message_id=int(extra['message_id']), + ) + elif ( + self.action is enums.AuditLogAction.automod_block_message + or self.action is enums.AuditLogAction.automod_flag_message + or self.action is enums.AuditLogAction.automod_timeout_member + or self.action is enums.AuditLogAction.automod_quarantine_user + ): + channel_id = utils._get_as_snowflake(extra, 'channel_id') + channel = None + + # May be an empty string instead of None due to a Discord issue + if channel_id: + channel = self.guild.get_channel_or_thread(channel_id) or Object(id=channel_id) + + self.extra = _AuditLogProxyAutoModAction( + automod_rule_name=extra['auto_moderation_rule_name'], + automod_rule_trigger_type=enums.try_enum( + enums.AutoModRuleTriggerType, int(extra['auto_moderation_rule_trigger_type']) + ), + channel=channel, + ) + elif self.action.name.startswith('overwrite_'): # the overwrite_ actions have a dict with some information - instance_id = int(self.extra['id']) - the_type = self.extra.get('type') - if the_type == 'member': + instance_id = int(extra['id']) + the_type = extra.get('type') + if the_type == '1': self.extra = self._get_member(instance_id) - else: + elif the_type == '0': role = self.guild.get_role(instance_id) if role is None: - role = Object(id=instance_id) - role.name = self.extra.get('role_name') + role = Object(id=instance_id, type=Role) + role.name = extra.get('role_name') # type: ignore # Object doesn't usually have name self.extra = role + elif self.action.name.startswith('stage_instance'): + channel_id = int(extra['channel_id']) + self.extra = _AuditLogProxyStageInstanceAction( + channel=self.guild.get_channel(channel_id) or Object(id=channel_id, type=StageChannel) + ) + elif self.action.name.startswith('app_command'): + app_id = int(extra['application_id']) + self.extra = self._get_integration_by_app_id(app_id) or Object(app_id, type=PartialIntegration) # this key is not present when the above is present, typically. # It's a list of { new_value: a, old_value: b, key: c } @@ -258,93 +805,182 @@ def _from_data(self, data): # into meaningful data when requested self._changes = data.get('changes', []) - self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) + self.user_id: Optional[int] = utils._get_as_snowflake(data, 'user_id') + self.user: Optional[Union[User, Member]] = self._get_member(self.user_id) self._target_id = utils._get_as_snowflake(data, 'target_id') - def _get_member(self, user_id): + def _get_member(self, user_id: Optional[int]) -> Union[Member, User, None]: + if user_id is None: + return None + return self.guild.get_member(user_id) or self._users.get(user_id) - def __repr__(self): - return ''.format(self) + def _get_integration(self, integration_id: Optional[int]) -> Optional[PartialIntegration]: + if integration_id is None: + return None + + return self._integrations.get(integration_id) + + def _get_integration_by_app_id(self, application_id: Optional[int]) -> Optional[PartialIntegration]: + if application_id is None: + return None + + # get PartialIntegration by application id + return utils.get(self._integrations.values(), application_id=application_id) + + def _get_app_command(self, app_command_id: Optional[int]) -> Optional[AppCommand]: + if app_command_id is None: + return None + + return self._app_commands.get(app_command_id) + + def __repr__(self) -> str: + return f'' @utils.cached_property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: Returns the entry's creation time in UTC.""" return utils.snowflake_time(self.id) @utils.cached_property - def target(self): + def target(self) -> TargetType: + if self.action.target_type is None: + return None + try: converter = getattr(self, '_convert_target_' + self.action.target_type) except AttributeError: + if self._target_id is None: + return None return Object(id=self._target_id) else: return converter(self._target_id) @utils.cached_property - def category(self): + def category(self) -> Optional[enums.AuditLogActionCategory]: """Optional[:class:`AuditLogActionCategory`]: The category of the action, if applicable.""" return self.action.category @utils.cached_property - def changes(self): + def changes(self) -> AuditLogChanges: """:class:`AuditLogChanges`: The list of changes this entry has.""" obj = AuditLogChanges(self, self._changes) del self._changes return obj @utils.cached_property - def before(self): + def before(self) -> AuditLogDiff: """:class:`AuditLogDiff`: The target's prior state.""" return self.changes.before @utils.cached_property - def after(self): + def after(self) -> AuditLogDiff: """:class:`AuditLogDiff`: The target's subsequent state.""" return self.changes.after - def _convert_target_guild(self, target_id): + def _convert_target_guild(self, target_id: int) -> Guild: return self.guild - def _convert_target_channel(self, target_id): - ch = self.guild.get_channel(target_id) - if ch is None: - return Object(id=target_id) - return ch + def _convert_target_channel(self, target_id: int) -> Union[abc.GuildChannel, Object]: + return self.guild.get_channel(target_id) or Object(id=target_id) + + def _convert_target_user(self, target_id: Optional[int]) -> Optional[Union[Member, User, Object]]: + # For some reason the member_disconnect and member_move action types + # do not have a non-null target_id so safeguard against that + if target_id is None: + return None - def _convert_target_user(self, target_id): - return self._get_member(target_id) + return self._get_member(target_id) or Object(id=target_id, type=Member) - def _convert_target_role(self, target_id): - role = self.guild.get_role(target_id) - if role is None: - return Object(id=target_id) - return role + def _convert_target_role(self, target_id: int) -> Union[Role, Object]: + return self.guild.get_role(target_id) or Object(id=target_id, type=Role) - def _convert_target_invite(self, target_id): + def _convert_target_invite(self, target_id: None) -> Invite: # invites have target_id set to null # so figure out which change has the full invite data changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after - fake_payload = { + fake_payload: InvitePayload = { 'max_age': changeset.max_age, 'max_uses': changeset.max_uses, 'code': changeset.code, 'temporary': changeset.temporary, - 'channel': changeset.channel, 'uses': changeset.uses, - 'guild': self.guild, + 'channel': None, # type: ignore # the channel is passed to the Invite constructor directly } - obj = Invite(state=self._state, data=fake_payload) + obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) try: obj.inviter = changeset.inviter except AttributeError: pass return obj - def _convert_target_emoji(self, target_id): - return self._state.get_emoji(target_id) or Object(id=target_id) + def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]: + return self._state.get_emoji(target_id) or Object(id=target_id, type=Emoji) + + def _convert_target_message(self, target_id: Optional[int]) -> Optional[Union[Member, User, Object]]: + # The message_pin and message_unpin action types do not have a + # non-null target_id so safeguard against that + + if target_id is None: + return None + + return self._get_member(target_id) or Object(id=target_id, type=Member) + + def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]: + return self.guild.get_stage_instance(target_id) or Object(id=target_id, type=StageInstance) + + def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]: + return self._state.get_sticker(target_id) or Object(id=target_id, type=GuildSticker) + + def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]: + return self.guild.get_thread(target_id) or Object(id=target_id, type=Thread) + + def _convert_target_guild_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, Object]: + return self.guild.get_scheduled_event(target_id) or Object(id=target_id, type=ScheduledEvent) + + def _convert_target_integration(self, target_id: int) -> Union[PartialIntegration, Object]: + return self._get_integration(target_id) or Object(target_id, type=PartialIntegration) + + def _convert_target_app_command(self, target_id: int) -> Union[AppCommand, Object]: + target = self._get_app_command(target_id) + if not target: + # circular import + from .app_commands import AppCommand + + target = Object(target_id, type=AppCommand) + + return target + + def _convert_target_integration_or_app_command(self, target_id: int) -> Union[PartialIntegration, AppCommand, Object]: + target = self._get_integration_by_app_id(target_id) or self._get_app_command(target_id) + if not target: + try: + # circular import + from .app_commands import AppCommand + + # get application id from extras + # if it matches target id, type should be integration + target_app = self.extra + # extra should be an Object or PartialIntegration + app_id = target_app.application_id if isinstance(target_app, PartialIntegration) else target_app.id # type: ignore + type = PartialIntegration if target_id == app_id else AppCommand + except AttributeError: + return Object(target_id) + else: + return Object(target_id, type=type) + + return target + + def _convert_target_auto_moderation(self, target_id: int) -> Union[AutoModRule, Object]: + return self._automod_rules.get(target_id) or Object(target_id, type=AutoModRule) + + def _convert_target_webhook(self, target_id: int) -> Union[Webhook, Object]: + # circular import + from .webhook import Webhook + + return self._webhooks.get(target_id) or Object(target_id, type=Webhook) - def _convert_target_message(self, target_id): - return self._get_member(target_id) + def _convert_target_onboarding_prompt(self, target_id: int) -> Object: + return Object(target_id, type=OnboardingPrompt) diff --git a/discord/automod.py b/discord/automod.py new file mode 100644 index 000000000000..5441d9467103 --- /dev/null +++ b/discord/automod.py @@ -0,0 +1,666 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations +import datetime + +from typing import TYPE_CHECKING, Any, Dict, Optional, List, Set, Union, Sequence, overload, Literal + +from .enums import AutoModRuleTriggerType, AutoModRuleActionType, AutoModRuleEventType, try_enum +from .flags import AutoModPresets +from . import utils +from .utils import MISSING, cached_slot_property + +if TYPE_CHECKING: + from typing_extensions import Self + from .abc import Snowflake, GuildChannel + from .threads import Thread + from .guild import Guild + from .member import Member + from .state import ConnectionState + from .types.automod import ( + AutoModerationRule as AutoModerationRulePayload, + AutoModerationTriggerMetadata as AutoModerationTriggerMetadataPayload, + AutoModerationAction as AutoModerationActionPayload, + AutoModerationActionExecution as AutoModerationActionExecutionPayload, + ) + from .role import Role + +__all__ = ( + 'AutoModRuleAction', + 'AutoModTrigger', + 'AutoModRule', + 'AutoModAction', +) + + +class AutoModRuleAction: + """Represents an auto moderation's rule action. + + .. note:: + Only one of ``channel_id``, ``duration``, or ``custom_message`` can be used. + + .. versionadded:: 2.0 + + Attributes + ----------- + type: :class:`AutoModRuleActionType` + The type of action to take. + Defaults to :attr:`~AutoModRuleActionType.block_message`. + channel_id: Optional[:class:`int`] + The ID of the channel or thread to send the alert message to, if any. + Passing this sets :attr:`type` to :attr:`~AutoModRuleActionType.send_alert_message`. + duration: Optional[:class:`datetime.timedelta`] + The duration of the timeout to apply, if any. + Has a maximum of 28 days. + Passing this sets :attr:`type` to :attr:`~AutoModRuleActionType.timeout`. + custom_message: Optional[:class:`str`] + A custom message which will be shown to a user when their message is blocked. + Passing this sets :attr:`type` to :attr:`~AutoModRuleActionType.block_message`. + + .. versionadded:: 2.2 + """ + + __slots__ = ('type', 'channel_id', 'duration', 'custom_message') + + @overload + def __init__(self, *, channel_id: int = ...) -> None: ... + + @overload + def __init__(self, *, type: Literal[AutoModRuleActionType.send_alert_message], channel_id: int = ...) -> None: ... + + @overload + def __init__(self, *, duration: datetime.timedelta = ...) -> None: ... + + @overload + def __init__(self, *, type: Literal[AutoModRuleActionType.timeout], duration: datetime.timedelta = ...) -> None: ... + + @overload + def __init__(self, *, custom_message: str = ...) -> None: ... + + @overload + def __init__(self, *, type: Literal[AutoModRuleActionType.block_message]) -> None: ... + + @overload + def __init__( + self, *, type: Literal[AutoModRuleActionType.block_message], custom_message: Optional[str] = ... + ) -> None: ... + + @overload + def __init__( + self, + *, + type: Optional[AutoModRuleActionType] = ..., + channel_id: Optional[int] = ..., + duration: Optional[datetime.timedelta] = ..., + custom_message: Optional[str] = ..., + ) -> None: ... + + def __init__( + self, + *, + type: Optional[AutoModRuleActionType] = None, + channel_id: Optional[int] = None, + duration: Optional[datetime.timedelta] = None, + custom_message: Optional[str] = None, + ) -> None: + if sum(v is None for v in (channel_id, duration, custom_message)) < 2: + raise ValueError('Only one of channel_id, duration, or custom_message can be passed.') + + self.type: AutoModRuleActionType + self.channel_id: Optional[int] = None + self.duration: Optional[datetime.timedelta] = None + self.custom_message: Optional[str] = None + + if type is not None: + self.type = type + elif channel_id is not None: + self.type = AutoModRuleActionType.send_alert_message + elif duration is not None: + self.type = AutoModRuleActionType.timeout + else: + self.type = AutoModRuleActionType.block_message + + if self.type is AutoModRuleActionType.send_alert_message: + if channel_id is None: + raise ValueError('channel_id cannot be None if type is send_alert_message') + self.channel_id = channel_id + + if self.type is AutoModRuleActionType.timeout: + if duration is None: + raise ValueError('duration cannot be None set if type is timeout') + self.duration = duration + + if self.type is AutoModRuleActionType.block_message: + self.custom_message = custom_message + + def __repr__(self) -> str: + return f'' + + @classmethod + def from_data(cls, data: AutoModerationActionPayload) -> Self: + if data['type'] == AutoModRuleActionType.timeout.value: + duration_seconds = data['metadata']['duration_seconds'] + return cls(duration=datetime.timedelta(seconds=duration_seconds)) + elif data['type'] == AutoModRuleActionType.send_alert_message.value: + channel_id = int(data['metadata']['channel_id']) + return cls(channel_id=channel_id) + elif data['type'] == AutoModRuleActionType.block_message.value: + custom_message = data.get('metadata', {}).get('custom_message') + return cls(type=AutoModRuleActionType.block_message, custom_message=custom_message) + + return cls(type=AutoModRuleActionType.block_member_interactions) + + def to_dict(self) -> Dict[str, Any]: + ret = {'type': self.type.value, 'metadata': {}} + if self.type is AutoModRuleActionType.block_message and self.custom_message is not None: + ret['metadata'] = {'custom_message': self.custom_message} + elif self.type is AutoModRuleActionType.timeout: + ret['metadata'] = {'duration_seconds': int(self.duration.total_seconds())} # type: ignore # duration cannot be None here + elif self.type is AutoModRuleActionType.send_alert_message: + ret['metadata'] = {'channel_id': str(self.channel_id)} + return ret + + +class AutoModTrigger: + r"""Represents a trigger for an auto moderation rule. + + The following table illustrates relevant attributes for each :class:`AutoModRuleTriggerType`: + + +-----------------------------------------------+------------------------------------------------+ + | Type | Attributes | + +===============================================+================================================+ + | :attr:`AutoModRuleTriggerType.keyword` | :attr:`keyword_filter`, :attr:`regex_patterns`,| + | | :attr:`allow_list` | + +-----------------------------------------------+------------------------------------------------+ + | :attr:`AutoModRuleTriggerType.spam` | | + +-----------------------------------------------+------------------------------------------------+ + | :attr:`AutoModRuleTriggerType.keyword_preset` | :attr:`presets`\, :attr:`allow_list` | + +-----------------------------------------------+------------------------------------------------+ + | :attr:`AutoModRuleTriggerType.mention_spam` | :attr:`mention_limit`, | + | | :attr:`mention_raid_protection` | + +-----------------------------------------------+------------------------------------------------+ + | :attr:`AutoModRuleTriggerType.member_profile` | :attr:`keyword_filter`, :attr:`regex_patterns`,| + | | :attr:`allow_list` | + +-----------------------------------------------+------------------------------------------------+ + + .. versionadded:: 2.0 + + Attributes + ----------- + type: :class:`AutoModRuleTriggerType` + The type of trigger. + keyword_filter: List[:class:`str`] + The list of strings that will trigger the filter. + Maximum of 1000. Keywords can only be up to 60 characters in length. + + This could be combined with :attr:`regex_patterns`. + regex_patterns: List[:class:`str`] + The regex pattern that will trigger the filter. The syntax is based off of + `Rust's regex syntax `_. + Maximum of 10. Regex strings can only be up to 260 characters in length. + + This could be combined with :attr:`keyword_filter` and/or :attr:`allow_list` + + .. versionadded:: 2.1 + presets: :class:`AutoModPresets` + The presets used with the preset keyword filter. + allow_list: List[:class:`str`] + The list of words that are exempt from the commonly flagged words. Maximum of 100. + Keywords can only be up to 60 characters in length. + mention_limit: :class:`int` + The total number of user and role mentions a message can contain. + Has a maximum of 50. + mention_raid_protection: :class:`bool` + Whether mention raid protection is enabled or not. + + .. versionadded:: 2.4 + """ + + __slots__ = ( + 'type', + 'keyword_filter', + 'presets', + 'allow_list', + 'mention_limit', + 'regex_patterns', + 'mention_raid_protection', + ) + + def __init__( + self, + *, + type: Optional[AutoModRuleTriggerType] = None, + keyword_filter: Optional[List[str]] = None, + presets: Optional[AutoModPresets] = None, + allow_list: Optional[List[str]] = None, + mention_limit: Optional[int] = None, + regex_patterns: Optional[List[str]] = None, + mention_raid_protection: Optional[bool] = None, + ) -> None: + unique_args = (keyword_filter or regex_patterns, presets, mention_limit or mention_raid_protection) + if type is None and sum(arg is not None for arg in unique_args) > 1: + raise ValueError( + 'Please pass only one of keyword_filter/regex_patterns, presets, or mention_limit/mention_raid_protection.' + ) + + if type is not None: + self.type = type + elif keyword_filter is not None or regex_patterns is not None: + self.type = AutoModRuleTriggerType.keyword + elif presets is not None: + self.type = AutoModRuleTriggerType.keyword_preset + elif mention_limit is not None or mention_raid_protection is not None: + self.type = AutoModRuleTriggerType.mention_spam + else: + raise ValueError( + 'Please pass the trigger type explicitly if not using keyword_filter, regex_patterns, presets, mention_limit, or mention_raid_protection.' + ) + + self.keyword_filter: List[str] = keyword_filter if keyword_filter is not None else [] + self.presets: AutoModPresets = presets if presets is not None else AutoModPresets() + self.allow_list: List[str] = allow_list if allow_list is not None else [] + self.mention_limit: int = mention_limit if mention_limit is not None else 0 + self.mention_raid_protection: bool = mention_raid_protection if mention_raid_protection is not None else False + self.regex_patterns: List[str] = regex_patterns if regex_patterns is not None else [] + + def __repr__(self) -> str: + data = self.to_metadata_dict() + if data: + joined = ' '.join(f'{k}={v!r}' for k, v in data.items()) + return f'' + + return f'' + + @classmethod + def from_data(cls, type: int, data: Optional[AutoModerationTriggerMetadataPayload]) -> Self: + type_ = try_enum(AutoModRuleTriggerType, type) + if data is None: + return cls(type=type_) + elif type_ in (AutoModRuleTriggerType.keyword, AutoModRuleTriggerType.member_profile): + return cls( + type=type_, + keyword_filter=data.get('keyword_filter'), + regex_patterns=data.get('regex_patterns'), + allow_list=data.get('allow_list'), + ) + elif type_ is AutoModRuleTriggerType.keyword_preset: + return cls( + type=type_, presets=AutoModPresets._from_value(data.get('presets', [])), allow_list=data.get('allow_list') + ) + elif type_ is AutoModRuleTriggerType.mention_spam: + return cls( + type=type_, + mention_limit=data.get('mention_total_limit'), + mention_raid_protection=data.get('mention_raid_protection_enabled'), + ) + else: + return cls(type=type_) + + def to_metadata_dict(self) -> Optional[Dict[str, Any]]: + if self.type in (AutoModRuleTriggerType.keyword, AutoModRuleTriggerType.member_profile): + return { + 'keyword_filter': self.keyword_filter, + 'regex_patterns': self.regex_patterns, + 'allow_list': self.allow_list, + } + elif self.type is AutoModRuleTriggerType.keyword_preset: + return {'presets': self.presets.to_array(), 'allow_list': self.allow_list} + elif self.type is AutoModRuleTriggerType.mention_spam: + return { + 'mention_total_limit': self.mention_limit, + 'mention_raid_protection_enabled': self.mention_raid_protection, + } + + +class AutoModRule: + """Represents an auto moderation rule. + + .. versionadded:: 2.0 + + Attributes + ----------- + id: :class:`int` + The ID of the rule. + guild: :class:`Guild` + The guild the rule is for. + name: :class:`str` + The name of the rule. + creator_id: :class:`int` + The ID of the user that created the rule. + trigger: :class:`AutoModTrigger` + The rule's trigger. + enabled: :class:`bool` + Whether the rule is enabled. + exempt_role_ids: Set[:class:`int`] + The IDs of the roles that are exempt from the rule. + exempt_channel_ids: Set[:class:`int`] + The IDs of the channels that are exempt from the rule. + event_type: :class:`AutoModRuleEventType` + The type of event that will trigger the the rule. + """ + + __slots__ = ( + '_state', + '_cs_exempt_roles', + '_cs_exempt_channels', + '_cs_actions', + 'id', + 'guild', + 'name', + 'creator_id', + 'event_type', + 'trigger', + 'enabled', + 'exempt_role_ids', + 'exempt_channel_ids', + '_actions', + ) + + def __init__(self, *, data: AutoModerationRulePayload, guild: Guild, state: ConnectionState) -> None: + self._state: ConnectionState = state + self.guild: Guild = guild + self.id: int = int(data['id']) + self.name: str = data['name'] + self.creator_id = int(data['creator_id']) + self.event_type: AutoModRuleEventType = try_enum(AutoModRuleEventType, data['event_type']) + self.trigger: AutoModTrigger = AutoModTrigger.from_data(data['trigger_type'], data=data.get('trigger_metadata')) + self.enabled: bool = data['enabled'] + self.exempt_role_ids: Set[int] = {int(role_id) for role_id in data['exempt_roles']} + self.exempt_channel_ids: Set[int] = {int(channel_id) for channel_id in data['exempt_channels']} + self._actions: List[AutoModerationActionPayload] = data['actions'] + + def __repr__(self) -> str: + return f'' + + def to_dict(self) -> AutoModerationRulePayload: + ret: AutoModerationRulePayload = { + 'id': str(self.id), + 'guild_id': str(self.guild.id), + 'name': self.name, + 'creator_id': str(self.creator_id), + 'event_type': self.event_type.value, + 'trigger_type': self.trigger.type.value, + 'trigger_metadata': self.trigger.to_metadata_dict(), + 'actions': [action.to_dict() for action in self.actions], + 'enabled': self.enabled, + 'exempt_roles': [str(role_id) for role_id in self.exempt_role_ids], + 'exempt_channels': [str(channel_id) for channel_id in self.exempt_channel_ids], + } # type: ignore # trigger types break the flow here. + + return ret + + @property + def creator(self) -> Optional[Member]: + """Optional[:class:`Member`]: The member that created this rule.""" + return self.guild.get_member(self.creator_id) + + @cached_slot_property('_cs_exempt_roles') + def exempt_roles(self) -> List[Role]: + """List[:class:`Role`]: The roles that are exempt from this rule.""" + result = [] + get_role = self.guild.get_role + for role_id in self.exempt_role_ids: + role = get_role(role_id) + if role is not None: + result.append(role) + + return utils._unique(result) + + @cached_slot_property('_cs_exempt_channels') + def exempt_channels(self) -> List[Union[GuildChannel, Thread]]: + """List[Union[:class:`abc.GuildChannel`, :class:`Thread`]]: The channels that are exempt from this rule.""" + it = filter(None, map(self.guild._resolve_channel, self.exempt_channel_ids)) + return utils._unique(it) + + @cached_slot_property('_cs_actions') + def actions(self) -> List[AutoModRuleAction]: + """List[:class:`AutoModRuleAction`]: The actions that are taken when this rule is triggered.""" + return [AutoModRuleAction.from_data(action) for action in self._actions] + + def is_exempt(self, obj: Snowflake, /) -> bool: + """Check if an object is exempt from the automod rule. + + Parameters + ----------- + obj: :class:`abc.Snowflake` + The role, channel, or thread to check. + + Returns + -------- + :class:`bool` + Whether the object is exempt from the automod rule. + """ + return obj.id in self.exempt_channel_ids or obj.id in self.exempt_role_ids + + async def edit( + self, + *, + name: str = MISSING, + event_type: AutoModRuleEventType = MISSING, + actions: List[AutoModRuleAction] = MISSING, + trigger: AutoModTrigger = MISSING, + enabled: bool = MISSING, + exempt_roles: Sequence[Snowflake] = MISSING, + exempt_channels: Sequence[Snowflake] = MISSING, + reason: str = MISSING, + ) -> Self: + """|coro| + + Edits this auto moderation rule. + + You must have :attr:`Permissions.manage_guild` to edit rules. + + Parameters + ----------- + name: :class:`str` + The new name to change to. + event_type: :class:`AutoModRuleEventType` + The new event type to change to. + actions: List[:class:`AutoModRuleAction`] + The new rule actions to update. + trigger: :class:`AutoModTrigger` + The new trigger to update. + You can only change the trigger metadata, not the type. + enabled: :class:`bool` + Whether the rule should be enabled or not. + exempt_roles: Sequence[:class:`abc.Snowflake`] + The new roles to exempt from the rule. + exempt_channels: Sequence[:class:`abc.Snowflake`] + The new channels to exempt from the rule. + reason: :class:`str` + The reason for updating this rule. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permission to edit this rule. + HTTPException + Editing the rule failed. + + Returns + -------- + :class:`AutoModRule` + The updated auto moderation rule. + """ + payload = {} + if actions is not MISSING: + payload['actions'] = [action.to_dict() for action in actions] + + if name is not MISSING: + payload['name'] = name + + if event_type is not MISSING: + payload['event_type'] = event_type.value + + if trigger is not MISSING: + trigger_metadata = trigger.to_metadata_dict() + if trigger_metadata is not None: + payload['trigger_metadata'] = trigger_metadata + + if enabled is not MISSING: + payload['enabled'] = enabled + + if exempt_roles is not MISSING: + payload['exempt_roles'] = [x.id for x in exempt_roles] + + if exempt_channels is not MISSING: + payload['exempt_channels'] = [x.id for x in exempt_channels] + + data = await self._state.http.edit_auto_moderation_rule( + self.guild.id, + self.id, + reason=reason, + **payload, + ) + + return self.__class__(data=data, guild=self.guild, state=self._state) + + async def delete(self, *, reason: str = MISSING) -> None: + """|coro| + + Deletes the auto moderation rule. + + You must have :attr:`Permissions.manage_guild` to delete rules. + + Parameters + ----------- + reason: :class:`str` + The reason for deleting this rule. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to delete the rule. + HTTPException + Deleting the rule failed. + """ + await self._state.http.delete_auto_moderation_rule(self.guild.id, self.id, reason=reason) + + +class AutoModAction: + """Represents an action that was taken as the result of a moderation rule. + + .. versionadded:: 2.0 + + Attributes + ----------- + action: :class:`AutoModRuleAction` + The action that was taken. + message_id: Optional[:class:`int`] + The message ID that triggered the action. This is only available if the + action is done on an edited message. + rule_id: :class:`int` + The ID of the rule that was triggered. + rule_trigger_type: :class:`AutoModRuleTriggerType` + The trigger type of the rule that was triggered. + guild_id: :class:`int` + The ID of the guild where the rule was triggered. + user_id: :class:`int` + The ID of the user that triggered the rule. + channel_id: :class:`int` + The ID of the channel where the rule was triggered. + alert_system_message_id: Optional[:class:`int`] + The ID of the system message that was sent to the predefined alert channel. + content: :class:`str` + The content of the message that triggered the rule. + Requires the :attr:`Intents.message_content` or it will always return an empty string. + matched_keyword: Optional[:class:`str`] + The matched keyword from the triggering message. + matched_content: Optional[:class:`str`] + The matched content from the triggering message. + Requires the :attr:`Intents.message_content` or it will always return ``None``. + """ + + __slots__ = ( + '_state', + 'action', + 'rule_id', + 'rule_trigger_type', + 'guild_id', + 'user_id', + 'channel_id', + 'message_id', + 'alert_system_message_id', + 'content', + 'matched_keyword', + 'matched_content', + ) + + def __init__(self, *, data: AutoModerationActionExecutionPayload, state: ConnectionState) -> None: + self._state: ConnectionState = state + self.message_id: Optional[int] = utils._get_as_snowflake(data, 'message_id') + self.action: AutoModRuleAction = AutoModRuleAction.from_data(data['action']) + self.rule_id: int = int(data['rule_id']) + self.rule_trigger_type: AutoModRuleTriggerType = try_enum(AutoModRuleTriggerType, data['rule_trigger_type']) + self.guild_id: int = int(data['guild_id']) + self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id') + self.user_id: int = int(data['user_id']) + self.alert_system_message_id: Optional[int] = utils._get_as_snowflake(data, 'alert_system_message_id') + self.content: str = data.get('content', '') + self.matched_keyword: Optional[str] = data['matched_keyword'] + self.matched_content: Optional[str] = data.get('matched_content') + + def __repr__(self) -> str: + return f'' + + @property + def guild(self) -> Guild: + """:class:`Guild`: The guild this action was taken in.""" + return self._state._get_or_create_unavailable_guild(self.guild_id) + + @property + def channel(self) -> Optional[Union[GuildChannel, Thread]]: + """Optional[Union[:class:`abc.GuildChannel`, :class:`Thread`]]: The channel this action was taken in.""" + if self.channel_id: + return self.guild.get_channel_or_thread(self.channel_id) + return None + + @property + def member(self) -> Optional[Member]: + """Optional[:class:`Member`]: The member this action was taken against /who triggered this rule.""" + return self.guild.get_member(self.user_id) + + async def fetch_rule(self) -> AutoModRule: + """|coro| + + Fetch the rule whose action was taken. + + You must have :attr:`Permissions.manage_guild` to do this. + + Raises + ------- + Forbidden + You do not have permissions to view the rule. + HTTPException + Fetching the rule failed. + + Returns + -------- + :class:`AutoModRule` + The rule that was executed. + """ + + data = await self._state.http.get_auto_moderation_rule(self.guild.id, self.rule_id) + return AutoModRule(data=data, guild=self.guild, state=self._state) diff --git a/discord/backoff.py b/discord/backoff.py index 3ebc019e7d71..f40142a9acfc 100644 --- a/discord/backoff.py +++ b/discord/backoff.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,10 +22,23 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + + import time import random +from typing import Callable, Generic, Literal, TypeVar, overload, Union + +T = TypeVar('T', bool, Literal[True], Literal[False]) + +# fmt: off +__all__ = ( + 'ExponentialBackoff', +) +# fmt: on -class ExponentialBackoff: + +class ExponentialBackoff(Generic[T]): """An implementation of the exponential backoff algorithm Provides a convenient interface to implement an exponential backoff @@ -49,21 +60,30 @@ class ExponentialBackoff: number in between may be returned. """ - def __init__(self, base=1, *, integral=False): - self._base = base + def __init__(self, base: int = 1, *, integral: T = False): + self._base: int = base - self._exp = 0 - self._max = 10 - self._reset_time = base * 2 ** 11 - self._last_invocation = time.monotonic() + self._exp: int = 0 + self._max: int = 10 + self._reset_time: int = base * 2**11 + self._last_invocation: float = time.monotonic() # Use our own random instance to avoid messing with global one rand = random.Random() rand.seed() - self._randfunc = rand.randrange if integral else rand.uniform + self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform + + @overload + def delay(self: ExponentialBackoff[Literal[False]]) -> float: ... + + @overload + def delay(self: ExponentialBackoff[Literal[True]]) -> int: ... + + @overload + def delay(self: ExponentialBackoff[bool]) -> Union[int, float]: ... - def delay(self): + def delay(self) -> Union[int, float]: """Compute the next delay Returns the next delay to wait according to the exponential @@ -82,4 +102,4 @@ def delay(self): self._exp = 0 self._exp = min(self._exp + 1, self._max) - return self._randfunc(0, self._base * 2 ** self._exp) + return self._randfunc(0, self._base * 2**self._exp) diff --git a/discord/bin/libopus-0.x64.dll b/discord/bin/libopus-0.x64.dll index 2832418655d9..74a8e3554ff0 100644 Binary files a/discord/bin/libopus-0.x64.dll and b/discord/bin/libopus-0.x64.dll differ diff --git a/discord/bin/libopus-0.x86.dll b/discord/bin/libopus-0.x86.dll index b291dfc677d4..ee71317fa629 100644 Binary files a/discord/bin/libopus-0.x86.dll and b/discord/bin/libopus-0.x86.dll differ diff --git a/discord/calls.py b/discord/calls.py deleted file mode 100644 index 8ab04145f418..000000000000 --- a/discord/calls.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -The MIT License (MIT) - -Copyright (c) 2015-2019 Rapptz - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -import datetime - -from . import utils -from .enums import VoiceRegion, try_enum -from .member import VoiceState - -class CallMessage: - """Represents a group call message from Discord. - - This is only received in cases where the message type is equivalent to - :attr:`MessageType.call`. - - Attributes - ----------- - ended_timestamp: Optional[:class:`datetime.datetime`] - A naive UTC datetime object that represents the time that the call has ended. - participants: List[:class:`User`] - The list of users that are participating in this call. - message: :class:`Message` - The message associated with this call message. - """ - - def __init__(self, message, **kwargs): - self.message = message - self.ended_timestamp = utils.parse_time(kwargs.get('ended_timestamp')) - self.participants = kwargs.get('participants') - - @property - def call_ended(self): - """:class:`bool`: Indicates if the call has ended.""" - return self.ended_timestamp is not None - - @property - def channel(self): - r""":class:`GroupChannel`\: The private channel associated with this message.""" - return self.message.channel - - @property - def duration(self): - """Queries the duration of the call. - - If the call has not ended then the current duration will - be returned. - - Returns - --------- - :class:`datetime.timedelta` - The timedelta object representing the duration. - """ - if self.ended_timestamp is None: - return datetime.datetime.utcnow() - self.message.created_at - else: - return self.ended_timestamp - self.message.created_at - -class GroupCall: - """Represents the actual group call from Discord. - - This is accompanied with a :class:`CallMessage` denoting the information. - - Attributes - ----------- - call: :class:`CallMessage` - The call message associated with this group call. - unavailable: :class:`bool` - Denotes if this group call is unavailable. - ringing: List[:class:`User`] - A list of users that are currently being rung to join the call. - region: :class:`VoiceRegion` - The guild region the group call is being hosted on. - """ - - def __init__(self, **kwargs): - self.call = kwargs.get('call') - self.unavailable = kwargs.get('unavailable') - self._voice_states = {} - - for state in kwargs.get('voice_states', []): - self._update_voice_state(state) - - self._update(**kwargs) - - def _update(self, **kwargs): - self.region = try_enum(VoiceRegion, kwargs.get('region')) - lookup = {u.id: u for u in self.call.channel.recipients} - me = self.call.channel.me - lookup[me.id] = me - self.ringing = list(filter(None, map(lookup.get, kwargs.get('ringing', [])))) - - def _update_voice_state(self, data): - user_id = int(data['user_id']) - # left the voice channel? - if data['channel_id'] is None: - self._voice_states.pop(user_id, None) - else: - self._voice_states[user_id] = VoiceState(data=data, channel=self.channel) - - @property - def connected(self): - """List[:class:`User`]: A property that returns all users that are currently in this call.""" - ret = [u for u in self.channel.recipients if self.voice_state_for(u) is not None] - me = self.channel.me - if self.voice_state_for(me) is not None: - ret.append(me) - - return ret - - @property - def channel(self): - r""":class:`GroupChannel`\: Returns the channel the group call is in.""" - return self.call.channel - - def voice_state_for(self, user): - """Retrieves the :class:`VoiceState` for a specified :class:`User`. - - If the :class:`User` has no voice state then this function returns - ``None``. - - Parameters - ------------ - user: :class:`User` - The user to retrieve the voice state for. - - Returns - -------- - Optional[:class:`VoiceState`] - The voice state associated with this user. - """ - - return self._voice_states.get(user.id) diff --git a/discord/channel.py b/discord/channel.py index ae0eb4c37b25..17a1c0fb2d67 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,31 +22,263 @@ DEALINGS IN THE SOFTWARE. """ -import time -import asyncio +from __future__ import annotations + +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterable, + List, + Literal, + Mapping, + NamedTuple, + Optional, + TYPE_CHECKING, + Sequence, + Tuple, + TypeVar, + TypedDict, + Union, + overload, +) +import datetime import discord.abc -from .permissions import Permissions -from .enums import ChannelType, try_enum +from .scheduled_event import ScheduledEvent +from .permissions import PermissionOverwrite, Permissions +from .enums import ( + ChannelType, + ForumLayoutType, + ForumOrderType, + PrivacyLevel, + try_enum, + VideoQualityMode, + EntityType, + VoiceChannelEffectAnimationType, +) from .mixins import Hashable from . import utils +from .utils import MISSING from .asset import Asset -from .errors import ClientException, NoMoreItems -from .webhook import Webhook +from .errors import ClientException +from .stage_instance import StageInstance +from .threads import Thread +from .partial_emoji import _EmojiTag, PartialEmoji +from .flags import ChannelFlags, MessageFlags +from .http import handle_message_parameters +from .object import Object +from .soundboard import BaseSoundboardSound, SoundboardDefaultSound __all__ = ( 'TextChannel', 'VoiceChannel', + 'StageChannel', 'DMChannel', 'CategoryChannel', - 'StoreChannel', + 'ForumTag', + 'ForumChannel', 'GroupChannel', - '_channel_factory', + 'PartialMessageable', + 'VoiceChannelEffect', + 'VoiceChannelSoundEffect', ) -async def _single_delete_strategy(messages): - for m in messages: - await m.delete() +if TYPE_CHECKING: + from typing_extensions import Self, Unpack + + from .types.threads import ThreadArchiveDuration + from .role import Role + from .member import Member, VoiceState + from .abc import Snowflake, SnowflakeTime + from .embeds import Embed + from .message import Message, PartialMessage, EmojiInputType + from .mentions import AllowedMentions + from .webhook import Webhook + from .state import ConnectionState + from .sticker import GuildSticker, StickerItem + from .file import File + from .user import ClientUser, User, BaseUser + from .guild import Guild, GuildChannel as GuildChannelType + from .ui.view import BaseView, View, LayoutView + from .types.channel import ( + TextChannel as TextChannelPayload, + NewsChannel as NewsChannelPayload, + VoiceChannel as VoiceChannelPayload, + StageChannel as StageChannelPayload, + DMChannel as DMChannelPayload, + CategoryChannel as CategoryChannelPayload, + GroupDMChannel as GroupChannelPayload, + ForumChannel as ForumChannelPayload, + MediaChannel as MediaChannelPayload, + ForumTag as ForumTagPayload, + VoiceChannelEffect as VoiceChannelEffectPayload, + ) + from .types.snowflake import SnowflakeList + from .types.soundboard import BaseSoundboardSound as BaseSoundboardSoundPayload + from .soundboard import SoundboardSound + + OverwriteKeyT = TypeVar('OverwriteKeyT', Role, BaseUser, Object, Union[Role, Member, Object]) + + class _BaseCreateChannelOptions(TypedDict, total=False): + reason: Optional[str] + position: int + + class _CreateTextChannelOptions(_BaseCreateChannelOptions, total=False): + topic: str + slowmode_delay: int + nsfw: bool + overwrites: Mapping[Union[Role, Member, Object], PermissionOverwrite] + default_auto_archive_duration: int + default_thread_slowmode_delay: int + + class _CreateVoiceChannelOptions(_BaseCreateChannelOptions, total=False): + bitrate: int + user_limit: int + rtc_region: Optional[str] + video_quality_mode: VideoQualityMode + overwrites: Mapping[Union[Role, Member, Object], PermissionOverwrite] + + class _CreateStageChannelOptions(_CreateVoiceChannelOptions, total=False): + bitrate: int + user_limit: int + rtc_region: Optional[str] + video_quality_mode: VideoQualityMode + overwrites: Mapping[Union[Role, Member, Object], PermissionOverwrite] + + class _CreateForumChannelOptions(_CreateTextChannelOptions, total=False): + topic: str + slowmode_delay: int + nsfw: bool + overwrites: Mapping[Union[Role, Member, Object], PermissionOverwrite] + default_auto_archive_duration: int + default_thread_slowmode_delay: int + default_sort_order: ForumOrderType + default_reaction_emoji: EmojiInputType + default_layout: ForumLayoutType + available_tags: Sequence[ForumTag] + + +class ThreadWithMessage(NamedTuple): + thread: Thread + message: Message + + +class VoiceChannelEffectAnimation(NamedTuple): + id: int + type: VoiceChannelEffectAnimationType + + +class VoiceChannelSoundEffect(BaseSoundboardSound): + """Represents a Discord voice channel sound effect. + + .. versionadded:: 2.5 + + .. container:: operations + + .. describe:: x == y + + Checks if two sound effects are equal. + + .. describe:: x != y + + Checks if two sound effects are not equal. + + .. describe:: hash(x) + + Returns the sound effect's hash. + + Attributes + ------------ + id: :class:`int` + The ID of the sound. + volume: :class:`float` + The volume of the sound as floating point percentage (e.g. ``1.0`` for 100%). + """ + + __slots__ = ('_state',) + + def __init__(self, *, state: ConnectionState, id: int, volume: float): + data: BaseSoundboardSoundPayload = { + 'sound_id': id, + 'volume': volume, + } + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} id={self.id} volume={self.volume}>' + + @property + def created_at(self) -> Optional[datetime.datetime]: + """Optional[:class:`datetime.datetime`]: Returns the snowflake's creation time in UTC. + Returns ``None`` if it's a default sound.""" + if self.is_default(): + return None + else: + return utils.snowflake_time(self.id) + + def is_default(self) -> bool: + """:class:`bool`: Whether it's a default sound or not.""" + # if it's smaller than the Discord Epoch it cannot be a snowflake + return self.id < utils.DISCORD_EPOCH + + +class VoiceChannelEffect: + """Represents a Discord voice channel effect. + + .. versionadded:: 2.5 + + Attributes + ------------ + channel: :class:`VoiceChannel` + The channel in which the effect is sent. + user: Optional[:class:`Member`] + The user who sent the effect. ``None`` if not found in cache. + animation: Optional[:class:`VoiceChannelEffectAnimation`] + The animation the effect has. Returns ``None`` if the effect has no animation. + emoji: Optional[:class:`PartialEmoji`] + The emoji of the effect. + sound: Optional[:class:`VoiceChannelSoundEffect`] + The sound of the effect. Returns ``None`` if it's an emoji effect. + """ + + __slots__ = ('channel', 'user', 'animation', 'emoji', 'sound') + + def __init__(self, *, state: ConnectionState, data: VoiceChannelEffectPayload, guild: Guild): + self.channel: VoiceChannel = guild.get_channel(int(data['channel_id'])) # type: ignore # will always be a VoiceChannel + self.user: Optional[Member] = guild.get_member(int(data['user_id'])) + self.animation: Optional[VoiceChannelEffectAnimation] = None + + animation_id = data.get('animation_id') + if animation_id is not None: + animation_type = try_enum(VoiceChannelEffectAnimationType, data['animation_type']) # type: ignore # cannot be None here + self.animation = VoiceChannelEffectAnimation(id=animation_id, type=animation_type) + + emoji = data.get('emoji') + self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None + self.sound: Optional[VoiceChannelSoundEffect] = None + + sound_id: Optional[int] = utils._get_as_snowflake(data, 'sound_id') + if sound_id is not None: + sound_volume = data.get('sound_volume') or 0.0 + self.sound = VoiceChannelSoundEffect(state=state, id=sound_id, volume=sound_volume) + + def __repr__(self) -> str: + attrs = [ + ('channel', self.channel), + ('user', self.user), + ('animation', self.animation), + ('emoji', self.emoji), + ('sound', self.sound), + ] + inner = ' '.join('%s=%r' % t for t in attrs) + return f'<{self.__class__.__name__} {inner}>' + + def is_sound(self) -> bool: + """:class:`bool`: Whether the effect is a sound or not.""" + return self.sound is not None + class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """Represents a Discord guild text channel. @@ -79,10 +309,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): The guild the channel belongs to. id: :class:`int` The channel ID. - category_id: :class:`int` - The category channel ID this channel belongs to. + category_id: Optional[:class:`int`] + The category channel ID this channel belongs to, if applicable. topic: Optional[:class:`str`] - The channel's topic. None if it doesn't exist. + The channel's topic. ``None`` if it doesn't exist. position: :class:`int` The position in the channel list. This is a number that starts at 0. e.g. the top channel is position 0. @@ -91,83 +321,123 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): *not* point to an existing or valid message. slowmode_delay: :class:`int` The number of seconds a member must wait between sending messages - in this channel. A value of `0` denotes that it is disabled. + in this channel. A value of ``0`` denotes that it is disabled. Bots and users with :attr:`~Permissions.manage_channels` or :attr:`~Permissions.manage_messages` bypass slowmode. - """ + nsfw: :class:`bool` + If the channel is marked as "not safe for work" or "age restricted". + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for threads created in this channel. - __slots__ = ('name', 'id', 'guild', 'topic', '_state', 'nsfw', - 'category_id', 'position', 'slowmode_delay', '_overwrites', - '_type', 'last_message_id') + .. versionadded:: 2.0 + default_thread_slowmode_delay: :class:`int` + The default slowmode delay in seconds for threads created in this channel. - def __init__(self, *, state, guild, data): - self._state = state - self.id = int(data['id']) - self._type = data['type'] + .. versionadded:: 2.3 + """ + + __slots__ = ( + 'name', + 'id', + 'guild', + 'topic', + '_state', + 'nsfw', + 'category_id', + 'position', + 'slowmode_delay', + '_overwrites', + '_type', + 'last_message_id', + 'default_auto_archive_duration', + 'default_thread_slowmode_delay', + ) + + def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[TextChannelPayload, NewsChannelPayload]): + self._state: ConnectionState = state + self.id: int = int(data['id']) + self._type: Literal[0, 5] = data['type'] self._update(guild, data) - def __repr__(self): + def __repr__(self) -> str: attrs = [ ('id', self.id), ('name', self.name), ('position', self.position), ('nsfw', self.nsfw), ('news', self.is_news()), - ('category_id', self.category_id) + ('category_id', self.category_id), ] - return '<%s %s>' % (self.__class__.__name__, ' '.join('%s=%r' % t for t in attrs)) - - def _update(self, guild, data): - self.guild = guild - self.name = data['name'] - self.category_id = utils._get_as_snowflake(data, 'parent_id') - self.topic = data.get('topic') - self.position = data['position'] - self.nsfw = data.get('nsfw', False) + joined = ' '.join('%s=%r' % t for t in attrs) + return f'<{self.__class__.__name__} {joined}>' + + def _update(self, guild: Guild, data: Union[TextChannelPayload, NewsChannelPayload]) -> None: + self.guild: Guild = guild + self.name: str = data['name'] + self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.topic: Optional[str] = data.get('topic') + self.position: int = data['position'] + self.nsfw: bool = data.get('nsfw', False) # Does this need coercion into `int`? No idea yet. - self.slowmode_delay = data.get('rate_limit_per_user', 0) - self._type = data.get('type', self._type) - self.last_message_id = utils._get_as_snowflake(data, 'last_message_id') + self.slowmode_delay: int = data.get('rate_limit_per_user', 0) + self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440) + self.default_thread_slowmode_delay: int = data.get('default_thread_rate_limit_per_user', 0) + self._type: Literal[0, 5] = data.get('type', self._type) + self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') self._fill_overwrites(data) - async def _get_channel(self): + async def _get_channel(self) -> Self: return self @property - def type(self): + def type(self) -> Literal[ChannelType.text, ChannelType.news]: """:class:`ChannelType`: The channel's Discord type.""" - return try_enum(ChannelType, self._type) + if self._type == 0: + return ChannelType.text + return ChannelType.news @property - def _sorting_bucket(self): + def _sorting_bucket(self) -> int: return ChannelType.text.value - def permissions_for(self, member): - base = super().permissions_for(member) + @property + def _scheduled_event_entity_type(self) -> Optional[EntityType]: + return None + + @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: + base = super().permissions_for(obj) + self._apply_implicit_permissions(base) # text channels do not have voice related permissions denied = Permissions.voice() base.value &= ~denied.value return base - permissions_for.__doc__ = discord.abc.GuildChannel.permissions_for.__doc__ - @property - def members(self): + def members(self) -> List[Member]: """List[:class:`Member`]: Returns all members that can see this channel.""" return [m for m in self.guild.members if self.permissions_for(m).read_messages] - def is_nsfw(self): - """Checks if the channel is NSFW.""" + @property + def threads(self) -> List[Thread]: + """List[:class:`Thread`]: Returns all the threads that you can see. + + .. versionadded:: 2.0 + """ + return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id] + + def is_nsfw(self) -> bool: + """:class:`bool`: Checks if the channel is NSFW.""" return self.nsfw - def is_news(self): - """Checks if the channel is a news channel.""" + def is_news(self) -> bool: + """:class:`bool`: Checks if the channel is a news channel.""" return self._type == ChannelType.news.value @property - def last_message(self): - """Fetches the last message from this channel in cache. + def last_message(self) -> Optional[Message]: + """Retrieves the last message from this channel in cache. The message might not be valid or point to an existing message. @@ -186,13 +456,49 @@ def last_message(self): """ return self._state._get_message(self.last_message_id) if self.last_message_id else None - async def edit(self, *, reason=None, **options): + @overload + async def edit(self) -> Optional[TextChannel]: ... + + @overload + async def edit(self, *, position: int, reason: Optional[str] = ...) -> None: ... + + @overload + async def edit( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + topic: str = ..., + position: int = ..., + nsfw: bool = ..., + sync_permissions: bool = ..., + category: Optional[CategoryChannel] = ..., + slowmode_delay: int = ..., + default_auto_archive_duration: ThreadArchiveDuration = ..., + default_thread_slowmode_delay: int = ..., + type: ChannelType = ..., + overwrites: Mapping[OverwriteKeyT, PermissionOverwrite] = ..., + ) -> TextChannel: ... + + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[TextChannel]: """|coro| Edits the channel. - You must have the :attr:`~Permissions.manage_channels` permission to - use this. + You must have :attr:`~Permissions.manage_channels` to do this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 1.4 + The ``type`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` or + :exc:`ValueError` instead of ``InvalidArgument``. Parameters ---------- @@ -212,31 +518,72 @@ async def edit(self, *, reason=None, **options): category. slowmode_delay: :class:`int` Specifies the slowmode rate limit for user in this channel, in seconds. - A value of `0` disables slowmode. The maximum value possible is `21600`. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + type: :class:`ChannelType` + Change the type of this text channel. Currently, only conversion between + :attr:`ChannelType.text` and :attr:`ChannelType.news` is supported. This + is only available to guilds that contain ``NEWS`` in :attr:`Guild.features`. reason: Optional[:class:`str`] The reason for editing this channel. Shows up on the audit log. - + overwrites: :class:`Mapping` + A :class:`Mapping` of target (either a role or a member) to + :class:`PermissionOverwrite` to apply to the channel. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for threads created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + + .. versionadded:: 2.0 + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for threads created in this channel. + + .. versionadded:: 2.3 Raises ------ - InvalidArgument - If position is less than 0 or greater than the number of channels. + ValueError + The new ``position`` is less than 0 or greater than the number of channels. + TypeError + The permission overwrite information is not in proper form. Forbidden You do not have permissions to edit the channel. HTTPException Editing the channel failed. + + Returns + -------- + Optional[:class:`.TextChannel`] + The newly edited text channel. If the edit was only positional + then ``None`` is returned instead. """ - await self._edit(options, reason=reason) - async def clone(self, *, name=None, reason=None): - return await self._clone_impl({ + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, + *, + name: Optional[str] = None, + category: Optional[CategoryChannel] = None, + reason: Optional[str] = None, + ) -> TextChannel: + base: Dict[Any, Any] = { 'topic': self.topic, 'nsfw': self.nsfw, - 'rate_limit_per_user': self.slowmode_delay - }, name=name, reason=reason) - - clone.__doc__ = discord.abc.GuildChannel.clone.__doc__ - - async def delete_messages(self, messages): + 'default_auto_archive_duration': self.default_auto_archive_duration, + 'default_thread_rate_limit_per_user': self.default_thread_slowmode_delay, + } + if not self.is_news(): + base['rate_limit_per_user'] = self.slowmode_delay + return await self._clone_impl( + base, + name=name, + category=category, + reason=reason, + ) + + async def delete_messages(self, messages: Iterable[Snowflake], *, reason: Optional[str] = None) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -249,23 +596,29 @@ async def delete_messages(self, messages): You cannot bulk delete more than 100 messages or messages that are older than 14 days old. - You must have the :attr:`~Permissions.manage_messages` permission to - use this. + You must have :attr:`~Permissions.manage_messages` to do this. + + .. versionchanged:: 2.0 + + ``messages`` parameter is now positional-only. - Usable only by bot accounts. + The ``reason`` keyword-only parameter was added. Parameters ----------- messages: Iterable[:class:`abc.Snowflake`] An iterable of messages denoting which ones to bulk delete. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. Raises ------ ClientException The number of messages to delete was more than 100. Forbidden - You do not have proper permissions to delete the messages or - you're not using a bot account. + You do not have proper permissions to delete the messages. + NotFound + If single delete, then the message was already deleted. HTTPException Deleting the messages failed. """ @@ -273,34 +626,45 @@ async def delete_messages(self, messages): messages = list(messages) if len(messages) == 0: - return # do nothing + return # do nothing if len(messages) == 1: - message_id = messages[0].id + message_id: int = messages[0].id await self._state.http.delete_message(self.id, message_id) return if len(messages) > 100: raise ClientException('Can only bulk delete messages up to 100 messages') - message_ids = [m.id for m in messages] - await self._state.http.delete_messages(self.id, message_ids) - - async def purge(self, *, limit=100, check=None, before=None, after=None, around=None, oldest_first=False, bulk=True): + message_ids: SnowflakeList = [m.id for m in messages] + await self._state.http.delete_messages(self.id, message_ids, reason=reason) + + async def purge( + self, + *, + limit: Optional[int] = 100, + check: Callable[[Message], bool] = MISSING, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + around: Optional[SnowflakeTime] = None, + oldest_first: Optional[bool] = None, + bulk: bool = True, + reason: Optional[str] = None, + ) -> List[Message]: """|coro| Purges a list of messages that meet the criteria given by the predicate ``check``. If a ``check`` is not provided then all messages are deleted without discrimination. - You must have the :attr:`~Permissions.manage_messages` permission to - delete messages even if they are your own (unless you are a user - account). The :attr:`~Permissions.read_message_history` permission is + You must have :attr:`~Permissions.manage_messages` to + delete messages even if they are your own. + Having :attr:`~Permissions.read_message_history` is also needed to retrieve message history. - Internally, this employs a different number of strategies depending - on the conditions met such as if a bulk delete is possible or if - the account is a user bot or not. + .. versionchanged:: 2.0 + + The ``reason`` keyword-only parameter was added. Examples --------- @@ -311,7 +675,7 @@ def is_me(m): return m.author == client.user deleted = await channel.purge(limit=100, check=is_me) - await channel.send('Deleted {} message(s)'.format(len(deleted))) + await channel.send(f'Deleted {len(deleted)} message(s)') Parameters ----------- @@ -329,11 +693,12 @@ def is_me(m): Same as ``around`` in :meth:`history`. oldest_first: Optional[:class:`bool`] Same as ``oldest_first`` in :meth:`history`. - bulk: class:`bool` + bulk: :class:`bool` If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will - fall back to single delete if current account is a user bot, or if messages are - older than two weeks. + fall back to single delete if messages are older than two weeks. + reason: Optional[:class:`str`] + The reason for purging the messages. Shows up on the audit log. Raises ------- @@ -347,60 +712,24 @@ def is_me(m): List[:class:`.Message`] The list of messages that were deleted. """ - - if check is None: - check = lambda m: True - - iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around) - ret = [] - count = 0 - - minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 - strategy = self.delete_messages if self._state.is_bot and bulk else _single_delete_strategy - - while True: - try: - msg = await iterator.next() - except NoMoreItems: - # no more messages to poll - if count >= 2: - # more than 2 messages -> bulk delete - to_delete = ret[-count:] - await strategy(to_delete) - elif count == 1: - # delete a single message - await ret[-1].delete() - - return ret - else: - if count == 100: - # we've reached a full 'queue' - to_delete = ret[-100:] - await strategy(to_delete) - count = 0 - await asyncio.sleep(1) - - if check(msg): - if msg.id < minimum_time: - # older than 14 days old - if count == 1: - await ret[-1].delete() - elif count >= 2: - to_delete = ret[-count:] - await strategy(to_delete) - - count = 0 - strategy = _single_delete_strategy - - count += 1 - ret.append(msg) - - async def webhooks(self): + return await discord.abc._purge_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + bulk=bulk, + reason=reason, + ) + + async def webhooks(self) -> List[Webhook]: """|coro| Gets the list of webhooks from this channel. - Requires :attr:`~.Permissions.manage_webhooks` permissions. + You must have :attr:`~.Permissions.manage_webhooks` to do this. Raises ------- @@ -413,17 +742,19 @@ async def webhooks(self): The webhooks for this channel. """ + from .webhook import Webhook + data = await self._state.http.channel_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def create_webhook(self, *, name, avatar=None, reason=None): + async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook: """|coro| Creates a webhook for this channel. - Requires :attr:`~.Permissions.manage_webhooks` permissions. + You must have :attr:`~.Permissions.manage_webhooks` to do this. - .. versionchanged:: 1.1.0 + .. versionchanged:: 1.1 Added the ``reason`` keyword-only parameter. Parameters @@ -449,245 +780,1356 @@ async def create_webhook(self, *, name, avatar=None, reason=None): The created webhook. """ + from .webhook import Webhook + if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) + avatar = utils._bytes_to_base64_data(avatar) # type: ignore # Silence reassignment error data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) -class VoiceChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): - """Represents a Discord guild voice channel. + async def follow(self, *, destination: TextChannel, reason: Optional[str] = None) -> Webhook: + """|coro| - .. container:: operations + Follows a channel using a webhook. - .. describe:: x == y + Only news channels can be followed. - Checks if two channels are equal. + .. note:: - .. describe:: x != y + The webhook returned will not provide a token to do webhook + actions, as Discord does not provide it. - Checks if two channels are not equal. + .. versionadded:: 1.3 - .. describe:: hash(x) + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` instead of + ``InvalidArgument``. - Returns the channel's hash. + Parameters + ----------- + destination: :class:`TextChannel` + The channel you would like to follow from. + reason: Optional[:class:`str`] + The reason for following the channel. Shows up on the destination guild's audit log. - .. describe:: str(x) + .. versionadded:: 1.4 - Returns the channel's name. + Raises + ------- + HTTPException + Following the channel failed. + Forbidden + You do not have the permissions to create a webhook. + ClientException + The channel is not a news channel. + TypeError + The destination channel is not a text channel. - Attributes - ----------- - name: :class:`str` - The channel name. - guild: :class:`Guild` - The guild the channel belongs to. - id: :class:`int` - The channel ID. - category_id: :class:`int` - The category channel ID this channel belongs to. - position: :class:`int` - The position in the channel list. This is a number that starts at 0. e.g. the - top channel is position 0. - bitrate: :class:`int` - The channel's preferred audio bitrate in bits per second. - user_limit: :class:`int` - The channel's limit for number of members that can be in a voice channel. - """ + Returns + -------- + :class:`Webhook` + The created webhook. + """ - __slots__ = ('name', 'id', 'guild', 'bitrate', 'user_limit', - '_state', 'position', '_overwrites', 'category_id') + if not self.is_news(): + raise ClientException('The channel must be a news channel.') - def __init__(self, *, state, guild, data): - self._state = state - self.id = int(data['id']) - self._update(guild, data) + if not isinstance(destination, TextChannel): + raise TypeError(f'Expected TextChannel received {destination.__class__.__name__}') - def __repr__(self): - attrs = [ - ('id', self.id), - ('name', self.name), - ('position', self.position), - ('bitrate', self.bitrate), - ('user_limit', self.user_limit), - ('category_id', self.category_id) - ] - return '<%s %s>' % (self.__class__.__name__, ' '.join('%s=%r' % t for t in attrs)) + from .webhook import Webhook - def _get_voice_client_key(self): - return self.guild.id, 'guild_id' + data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason) + return Webhook._as_follower(data, channel=destination, user=self._state.user) - def _get_voice_state_pair(self): - return self.guild.id, self.id + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. - @property - def type(self): - """:class:`ChannelType`: The channel's Discord type.""" - return ChannelType.voice + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. - def _update(self, guild, data): - self.guild = guild - self.name = data['name'] - self.category_id = utils._get_as_snowflake(data, 'parent_id') - self.position = data['position'] - self.bitrate = data.get('bitrate') - self.user_limit = data.get('user_limit') - self._fill_overwrites(data) + .. versionadded:: 1.6 - @property - def _sorting_bucket(self): - return ChannelType.voice.value + .. versionchanged:: 2.0 - @property - def members(self): - """List[:class:`Member`]: Returns all members that are currently inside this voice channel.""" - ret = [] - for user_id, state in self.guild._voice_states.items(): - if state.channel.id == self.id: - member = self.guild.get_member(user_id) - if member is not None: - ret.append(member) - return ret + ``message_id`` parameter is now positional-only. - def permissions_for(self, member): - base = super().permissions_for(member) + Parameters + ------------ + message_id: :class:`int` + The message ID to create a partial message for. - # voice channels cannot be edited by people who can't connect to them - # It also implicitly denies all other voice perms - if not base.connect: - denied = Permissions.voice() - denied.update(manage_channels=True, manage_roles=True) - base.value &= ~denied.value - return base + Returns + --------- + :class:`PartialMessage` + The partial message. + """ - permissions_for.__doc__ = discord.abc.GuildChannel.permissions_for.__doc__ + from .message import PartialMessage - async def clone(self, *, name=None, reason=None): - return await self._clone_impl({ - 'bitrate': self.bitrate, - 'user_limit': self.user_limit - }, name=name, reason=reason) + return PartialMessage(channel=self, id=message_id) - clone.__doc__ = discord.abc.GuildChannel.clone.__doc__ + def get_thread(self, thread_id: int, /) -> Optional[Thread]: + """Returns a thread with the given ID. - async def edit(self, *, reason=None, **options): - """|coro| + .. note:: - Edits the channel. + This does not always retrieve archived threads, as they are not retained in the internal + cache. Use :func:`Guild.fetch_channel` instead. - You must have the :attr:`~Permissions.manage_channels` permission to - use this. + .. versionadded:: 2.0 Parameters - ---------- - name: :class:`str` - The new channel's name. - bitrate: :class:`int` - The new channel's bitrate. - user_limit: :class:`int` - The new channel's user limit. - position: :class:`int` - The new channel's position. - sync_permissions: :class:`bool` - Whether to sync permissions with the channel's new or pre-existing - category. Defaults to ``False``. - category: Optional[:class:`CategoryChannel`] - The new category for this channel. Can be ``None`` to remove the - category. - reason: Optional[:class:`str`] - The reason for editing this channel. Shows up on the audit log. + ----------- + thread_id: :class:`int` + The ID to search for. - Raises - ------ - Forbidden - You do not have permissions to edit the channel. - HTTPException - Editing the channel failed. + Returns + -------- + Optional[:class:`Thread`] + The returned thread or ``None`` if not found. """ + return self.guild.get_thread(thread_id) + + async def create_thread( + self, + *, + name: str, + message: Optional[Snowflake] = None, + auto_archive_duration: ThreadArchiveDuration = MISSING, + type: Optional[ChannelType] = None, + reason: Optional[str] = None, + invitable: bool = True, + slowmode_delay: Optional[int] = None, + ) -> Thread: + """|coro| - await self._edit(options, reason=reason) + Creates a thread in this text channel. -class CategoryChannel(discord.abc.GuildChannel, Hashable): - """Represents a Discord channel category. + To create a public thread, you must have :attr:`~discord.Permissions.create_public_threads`. + For a private thread, :attr:`~discord.Permissions.create_private_threads` is needed instead. - These are useful to group channels to logical compartments. + .. versionadded:: 2.0 - .. container:: operations + Parameters + ----------- + name: :class:`str` + The name of the thread. + message: Optional[:class:`abc.Snowflake`] + A snowflake representing the message to create the thread with. + If ``None`` is passed then a private thread is created. + Defaults to ``None``. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically hidden from the channel list. + If not provided, the channel's default auto archive duration is used. + + Must be one of ``60``, ``1440``, ``4320``, or ``10080``, if provided. + type: Optional[:class:`ChannelType`] + The type of thread to create. If a ``message`` is passed then this parameter + is ignored, as a thread created with a message is always a public thread. + By default this creates a private thread if this is ``None``. + reason: :class:`str` + The reason for creating a new thread. Shows up on the audit log. + invitable: :class:`bool` + Whether non-moderators can add users to the thread. Only applicable to private threads. + Defaults to ``True``. + slowmode_delay: Optional[:class:`int`] + Specifies the slowmode rate limit for user in this channel, in seconds. + The maximum value possible is ``21600``. By default no slowmode rate limit + if this is ``None``. - .. describe:: x == y + Raises + ------- + Forbidden + You do not have permissions to create a thread. + HTTPException + Starting the thread failed. - Checks if two channels are equal. + Returns + -------- + :class:`Thread` + The created thread + """ - .. describe:: x != y + if type is None: + type = ChannelType.private_thread + + if message is None: + data = await self._state.http.start_thread_without_message( + self.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + type=type.value, # type: ignore # we're assuming that the user is passing a valid variant + reason=reason, + invitable=invitable, + rate_limit_per_user=slowmode_delay, + ) + else: + data = await self._state.http.start_thread_with_message( + self.id, + message.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + reason=reason, + rate_limit_per_user=slowmode_delay, + ) + + return Thread(guild=self.guild, state=self._state, data=data) + + async def archived_threads( + self, + *, + private: bool = False, + joined: bool = False, + limit: Optional[int] = 100, + before: Optional[Union[Snowflake, datetime.datetime]] = None, + ) -> AsyncIterator[Thread]: + """Returns an :term:`asynchronous iterator` that iterates over all archived threads in this text channel, + in order of decreasing ID for joined threads, and decreasing :attr:`Thread.archive_timestamp` otherwise. + + You must have :attr:`~Permissions.read_message_history` to do this. If iterating over private threads + then :attr:`~Permissions.manage_threads` is also required. + + .. versionadded:: 2.0 - Checks if two channels are not equal. + Parameters + ----------- + limit: Optional[:class:`bool`] + The number of threads to retrieve. + If ``None``, retrieves every archived thread in the channel. Note, however, + that this would make it a slow operation. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve archived channels before the given date or ID. + private: :class:`bool` + Whether to retrieve private archived threads. + joined: :class:`bool` + Whether to retrieve private archived threads that you've joined. + You cannot set ``joined`` to ``True`` and ``private`` to ``False``. - .. describe:: hash(x) + Raises + ------ + Forbidden + You do not have permissions to get archived threads. + HTTPException + The request to get the archived threads failed. + ValueError + ``joined`` was set to ``True`` and ``private`` was set to ``False``. You cannot retrieve public archived + threads that you have joined. - Returns the category's hash. + Yields + ------- + :class:`Thread` + The archived threads. + """ + if joined and not private: + raise ValueError('Cannot retrieve joined public archived threads') - .. describe:: str(x) + before_timestamp = None - Returns the category's name. + if isinstance(before, datetime.datetime): + if joined: + before_timestamp = str(utils.time_snowflake(before, high=False)) + else: + before_timestamp = before.isoformat() + elif before is not None: + if joined: + before_timestamp = str(before.id) + else: + before_timestamp = utils.snowflake_time(before.id).isoformat() - Attributes - ----------- - name: :class:`str` - The category name. - guild: :class:`Guild` - The guild the category belongs to. - id: :class:`int` - The category channel ID. - position: :class:`int` - The position in the category list. This is a number that starts at 0. e.g. the - top category is position 0. - """ + update_before = lambda data: data['thread_metadata']['archive_timestamp'] + endpoint = self.guild._state.http.get_public_archived_threads - __slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id') + if joined: + update_before = lambda data: data['id'] + endpoint = self.guild._state.http.get_joined_private_archived_threads + elif private: + endpoint = self.guild._state.http.get_private_archived_threads - def __init__(self, *, state, guild, data): - self._state = state - self.id = int(data['id']) + while True: + retrieve = 100 + if limit is not None: + if limit <= 0: + return + retrieve = max(2, min(retrieve, limit)) + + data = await endpoint(self.id, before=before_timestamp, limit=retrieve) + + threads = data.get('threads', []) + for raw_thread in threads: + yield Thread(guild=self.guild, state=self.guild._state, data=raw_thread) + # Currently the API doesn't let you request less than 2 threads. + # Bail out early if we had to retrieve more than what the limit was. + if limit is not None: + limit -= 1 + if limit <= 0: + return + + if not data.get('has_more', False): + return + + before_timestamp = update_before(threads[-1]) + + +class VocalGuildChannel(discord.abc.Messageable, discord.abc.Connectable, discord.abc.GuildChannel, Hashable): + __slots__ = ( + 'name', + 'id', + 'guild', + 'nsfw', + 'bitrate', + 'user_limit', + '_state', + 'position', + 'slowmode_delay', + '_overwrites', + 'category_id', + 'rtc_region', + 'video_quality_mode', + 'last_message_id', + ) + + def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]): + self._state: ConnectionState = state + self.id: int = int(data['id']) self._update(guild, data) - def __repr__(self): - return ''.format(self) + async def _get_channel(self) -> Self: + return self - def _update(self, guild, data): - self.guild = guild - self.name = data['name'] - self.category_id = utils._get_as_snowflake(data, 'parent_id') - self.nsfw = data.get('nsfw', False) - self.position = data['position'] + def _get_voice_client_key(self) -> Tuple[int, str]: + return self.guild.id, 'guild_id' + + def _get_voice_state_pair(self) -> Tuple[int, int]: + return self.guild.id, self.id + + def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: + self.guild: Guild = guild + self.name: str = data['name'] + self.nsfw: bool = data.get('nsfw', False) + self.rtc_region: Optional[str] = data.get('rtc_region') + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) + self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') + self.position: int = data['position'] + self.slowmode_delay = data.get('rate_limit_per_user', 0) + self.bitrate: int = data['bitrate'] + self.user_limit: int = data['user_limit'] self._fill_overwrites(data) @property - def _sorting_bucket(self): - return ChannelType.category.value + def _sorting_bucket(self) -> int: + return ChannelType.voice.value - @property - def type(self): - """:class:`ChannelType`: The channel's Discord type.""" - return ChannelType.category + def is_nsfw(self) -> bool: + """:class:`bool`: Checks if the channel is NSFW. - def is_nsfw(self): - """Checks if the category is NSFW.""" + .. versionadded:: 2.0 + """ return self.nsfw - async def clone(self, *, name=None, reason=None): - return await self._clone_impl({ - 'nsfw': self.nsfw - }, name=name, reason=reason) - - clone.__doc__ = discord.abc.GuildChannel.clone.__doc__ - - async def edit(self, *, reason=None, **options): - """|coro| + @property + def members(self) -> List[Member]: + """List[:class:`Member`]: Returns all members that are currently inside this voice channel.""" + ret = [] + for user_id, state in self.guild._voice_states.items(): + if state.channel and state.channel.id == self.id: + member = self.guild.get_member(user_id) + if member is not None: + ret.append(member) + return ret + + @property + def voice_states(self) -> Dict[int, VoiceState]: + """Returns a mapping of member IDs who have voice states in this channel. + + .. versionadded:: 1.3 + + .. note:: + + This function is intentionally low level to replace :attr:`members` + when the member cache is unavailable. + + Returns + -------- + Mapping[:class:`int`, :class:`VoiceState`] + The mapping of member ID to a voice state. + """ + # fmt: off + return { + key: value + for key, value in self.guild._voice_states.items() + if value.channel and value.channel.id == self.id + } + # fmt: on + + @property + def scheduled_events(self) -> List[ScheduledEvent]: + """List[:class:`ScheduledEvent`]: Returns all scheduled events for this channel. + + .. versionadded:: 2.0 + """ + return [event for event in self.guild.scheduled_events if event.channel_id == self.id] + + @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: + base = super().permissions_for(obj) + self._apply_implicit_permissions(base) + + # voice channels cannot be edited by people who can't connect to them + # It also implicitly denies all other voice perms + if not base.connect: + denied = Permissions.voice() + denied.update(manage_channels=True, manage_roles=True) + base.value &= ~denied.value + return base + + @property + def last_message(self) -> Optional[Message]: + """Retrieves the last message from this channel in cache. + + The message might not be valid or point to an existing message. + + .. versionadded:: 2.0 + + .. admonition:: Reliable Fetching + :class: helpful + + For a slightly more reliable method of fetching the + last message, consider using either :meth:`history` + or :meth:`fetch_message` with the :attr:`last_message_id` + attribute. + + Returns + --------- + Optional[:class:`Message`] + The last message in this channel or ``None`` if not found. + """ + return self._state._get_message(self.last_message_id) if self.last_message_id else None + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 2.0 + + Parameters + ------------ + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + --------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) # type: ignore # VocalGuildChannel is an impl detail + + async def delete_messages(self, messages: Iterable[Snowflake], *, reason: Optional[str] = None) -> None: + """|coro| + + Deletes a list of messages. This is similar to :meth:`Message.delete` + except it bulk deletes multiple messages. + + As a special case, if the number of messages is 0, then nothing + is done. If the number of messages is 1 then single message + delete is done. If it's more than two, then bulk delete is used. + + You cannot bulk delete more than 100 messages or messages that + are older than 14 days old. + + You must have :attr:`~Permissions.manage_messages` to do this. + + .. versionadded:: 2.0 + + Parameters + ----------- + messages: Iterable[:class:`abc.Snowflake`] + An iterable of messages denoting which ones to bulk delete. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Raises + ------ + ClientException + The number of messages to delete was more than 100. + Forbidden + You do not have proper permissions to delete the messages. + NotFound + If single delete, then the message was already deleted. + HTTPException + Deleting the messages failed. + """ + if not isinstance(messages, (list, tuple)): + messages = list(messages) + + if len(messages) == 0: + return # do nothing + + if len(messages) == 1: + message_id: int = messages[0].id + await self._state.http.delete_message(self.id, message_id) + return + + if len(messages) > 100: + raise ClientException('Can only bulk delete messages up to 100 messages') + + message_ids: SnowflakeList = [m.id for m in messages] + await self._state.http.delete_messages(self.id, message_ids, reason=reason) + + async def purge( + self, + *, + limit: Optional[int] = 100, + check: Callable[[Message], bool] = MISSING, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + around: Optional[SnowflakeTime] = None, + oldest_first: Optional[bool] = None, + bulk: bool = True, + reason: Optional[str] = None, + ) -> List[Message]: + """|coro| + + Purges a list of messages that meet the criteria given by the predicate + ``check``. If a ``check`` is not provided then all messages are deleted + without discrimination. + + You must have :attr:`~Permissions.manage_messages` to + delete messages even if they are your own. + Having :attr:`~Permissions.read_message_history` is + also needed to retrieve message history. + + .. versionadded:: 2.0 + + Examples + --------- + + Deleting bot's messages :: + + def is_me(m): + return m.author == client.user + + deleted = await channel.purge(limit=100, check=is_me) + await channel.send(f'Deleted {len(deleted)} message(s)') + + Parameters + ----------- + limit: Optional[:class:`int`] + The number of messages to search through. This is not the number + of messages that will be deleted, though it can be. + check: Callable[[:class:`Message`], :class:`bool`] + The function used to check if a message should be deleted. + It must take a :class:`Message` as its sole parameter. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``before`` in :meth:`history`. + after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``after`` in :meth:`history`. + around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``around`` in :meth:`history`. + oldest_first: Optional[:class:`bool`] + Same as ``oldest_first`` in :meth:`history`. + bulk: :class:`bool` + If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting + a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will + fall back to single delete if messages are older than two weeks. + reason: Optional[:class:`str`] + The reason for purging the messages. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have proper permissions to do the actions required. + HTTPException + Purging the messages failed. + + Returns + -------- + List[:class:`.Message`] + The list of messages that were deleted. + """ + + return await discord.abc._purge_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + bulk=bulk, + reason=reason, + ) + + async def webhooks(self) -> List[Webhook]: + """|coro| + + Gets the list of webhooks from this channel. + + You must have :attr:`~.Permissions.manage_webhooks` to do this. + + .. versionadded:: 2.0 + + Raises + ------- + Forbidden + You don't have permissions to get the webhooks. + + Returns + -------- + List[:class:`Webhook`] + The webhooks for this channel. + """ + + from .webhook import Webhook + + data = await self._state.http.channel_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] + + async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook: + """|coro| + + Creates a webhook for this channel. + + You must have :attr:`~.Permissions.manage_webhooks` to do this. + + .. versionadded:: 2.0 + + Parameters + ------------- + name: :class:`str` + The webhook's name. + avatar: Optional[:class:`bytes`] + A :term:`py:bytes-like object` representing the webhook's default avatar. + This operates similarly to :meth:`~ClientUser.edit`. + reason: Optional[:class:`str`] + The reason for creating this webhook. Shows up in the audit logs. + + Raises + ------- + HTTPException + Creating the webhook failed. + Forbidden + You do not have permissions to create a webhook. + + Returns + -------- + :class:`Webhook` + The created webhook. + """ + + from .webhook import Webhook + + if avatar is not None: + avatar = utils._bytes_to_base64_data(avatar) # type: ignore # Silence reassignment error + + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) + return Webhook.from_state(data, state=self._state) + + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, *, name: Optional[str] = None, category: Optional[CategoryChannel] = None, reason: Optional[str] = None + ) -> Self: + base = { + 'bitrate': self.bitrate, + 'user_limit': self.user_limit, + 'rate_limit_per_user': self.slowmode_delay, + 'nsfw': self.nsfw, + 'video_quality_mode': self.video_quality_mode.value, + } + if self.rtc_region: + base['rtc_region'] = self.rtc_region + + return await self._clone_impl( + base, + name=name, + category=category, + reason=reason, + ) + + +class VoiceChannel(VocalGuildChannel): + """Represents a Discord guild voice channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ----------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + id: :class:`int` + The channel ID. + nsfw: :class:`bool` + If the channel is marked as "not safe for work" or "age restricted". + + .. versionadded:: 2.0 + category_id: Optional[:class:`int`] + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a voice channel. + rtc_region: Optional[:class:`str`] + The region for the voice channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + + .. versionadded:: 1.7 + + .. versionchanged:: 2.0 + The type of this attribute has changed to :class:`str`. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the voice channel's participants. + + .. versionadded:: 2.0 + last_message_id: Optional[:class:`int`] + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + + .. versionadded:: 2.0 + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of ``0`` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + + .. versionadded:: 2.2 + """ + + __slots__ = () + + def __repr__(self) -> str: + attrs = [ + ('id', self.id), + ('name', self.name), + ('rtc_region', self.rtc_region), + ('position', self.position), + ('bitrate', self.bitrate), + ('video_quality_mode', self.video_quality_mode), + ('user_limit', self.user_limit), + ('category_id', self.category_id), + ] + joined = ' '.join('%s=%r' % t for t in attrs) + return f'<{self.__class__.__name__} {joined}>' + + @property + def _scheduled_event_entity_type(self) -> Optional[EntityType]: + return EntityType.voice + + @property + def type(self) -> Literal[ChannelType.voice]: + """:class:`ChannelType`: The channel's Discord type.""" + return ChannelType.voice + + @overload + async def edit(self) -> None: ... + + @overload + async def edit(self, *, position: int, reason: Optional[str] = ...) -> None: ... + + @overload + async def edit( + self, + *, + name: str = ..., + nsfw: bool = ..., + bitrate: int = ..., + user_limit: int = ..., + position: int = ..., + sync_permissions: int = ..., + category: Optional[CategoryChannel] = ..., + overwrites: Mapping[OverwriteKeyT, PermissionOverwrite] = ..., + rtc_region: Optional[str] = ..., + video_quality_mode: VideoQualityMode = ..., + slowmode_delay: int = ..., + status: Optional[str] = ..., + reason: Optional[str] = ..., + ) -> VoiceChannel: ... + + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[VoiceChannel]: + """|coro| + + Edits the channel. + + You must have :attr:`~Permissions.manage_channels` to do this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + .. versionchanged:: 2.0 + The ``region`` parameter now accepts :class:`str` instead of an enum. + + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` instead of + ``InvalidArgument``. + + Parameters + ---------- + name: :class:`str` + The new channel's name. + bitrate: :class:`int` + The new channel's bitrate. + nsfw: :class:`bool` + To mark the channel as NSFW or not. + user_limit: :class:`int` + The new channel's user limit. + position: :class:`int` + The new channel's position. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + reason: Optional[:class:`str`] + The reason for editing this channel. Shows up on the audit log. + overwrites: :class:`Mapping` + A :class:`Mapping` of target (either a role or a member) to + :class:`PermissionOverwrite` to apply to the channel. + rtc_region: Optional[:class:`str`] + The new region for the voice channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + + .. versionadded:: 1.7 + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the voice channel's participants. + + .. versionadded:: 2.0 + status: Optional[:class:`str`] + The new voice channel status. It can be up to 500 characters. + Can be ``None`` to remove the status. + + .. versionadded:: 2.4 + + Raises + ------ + TypeError + If the permission overwrite information is not in proper form. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + + Returns + -------- + Optional[:class:`.VoiceChannel`] + The newly edited voice channel. If the edit was only positional + then ``None`` is returned instead. + """ + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + + async def send_sound(self, sound: Union[SoundboardSound, SoundboardDefaultSound], /) -> None: + """|coro| + + Sends a soundboard sound for this channel. + + You must have :attr:`~Permissions.speak` and :attr:`~Permissions.use_soundboard` to do this. + Additionally, you must have :attr:`~Permissions.use_external_sounds` if the sound is from + a different guild. + + .. versionadded:: 2.5 + + Parameters + ----------- + sound: Union[:class:`SoundboardSound`, :class:`SoundboardDefaultSound`] + The sound to send for this channel. + + Raises + ------- + Forbidden + You do not have permissions to send a sound for this channel. + HTTPException + Sending the sound failed. + """ + payload = {'sound_id': sound.id} + if not isinstance(sound, SoundboardDefaultSound) and self.guild.id != sound.guild.id: + payload['source_guild_id'] = sound.guild.id + + await self._state.http.send_soundboard_sound(self.id, **payload) + + +class StageChannel(VocalGuildChannel): + """Represents a Discord guild stage channel. + + .. versionadded:: 1.7 + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ----------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + id: :class:`int` + The channel ID. + nsfw: :class:`bool` + If the channel is marked as "not safe for work" or "age restricted". + + .. versionadded:: 2.0 + topic: Optional[:class:`str`] + The channel's topic. ``None`` if it isn't set. + category_id: Optional[:class:`int`] + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a stage channel. + rtc_region: Optional[:class:`str`] + The region for the stage channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the stage channel's participants. + + .. versionadded:: 2.0 + last_message_id: Optional[:class:`int`] + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + + .. versionadded:: 2.2 + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of ``0`` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + + .. versionadded:: 2.2 + """ + + __slots__ = ('topic',) + + def __repr__(self) -> str: + attrs = [ + ('id', self.id), + ('name', self.name), + ('topic', self.topic), + ('rtc_region', self.rtc_region), + ('position', self.position), + ('bitrate', self.bitrate), + ('video_quality_mode', self.video_quality_mode), + ('user_limit', self.user_limit), + ('category_id', self.category_id), + ] + joined = ' '.join('%s=%r' % t for t in attrs) + return f'<{self.__class__.__name__} {joined}>' + + def _update(self, guild: Guild, data: StageChannelPayload) -> None: + super()._update(guild, data) + self.topic: Optional[str] = data.get('topic') + + @property + def _scheduled_event_entity_type(self) -> Optional[EntityType]: + return EntityType.stage_instance + + @property + def requesting_to_speak(self) -> List[Member]: + """List[:class:`Member`]: A list of members who are requesting to speak in the stage channel.""" + return [member for member in self.members if member.voice and member.voice.requested_to_speak_at is not None] + + @property + def speakers(self) -> List[Member]: + """List[:class:`Member`]: A list of members who have been permitted to speak in the stage channel. + + .. versionadded:: 2.0 + """ + return [ + member + for member in self.members + if member.voice and not member.voice.suppress and member.voice.requested_to_speak_at is None + ] + + @property + def listeners(self) -> List[Member]: + """List[:class:`Member`]: A list of members who are listening in the stage channel. + + .. versionadded:: 2.0 + """ + return [member for member in self.members if member.voice and member.voice.suppress] + + @property + def moderators(self) -> List[Member]: + """List[:class:`Member`]: A list of members who are moderating the stage channel. + + .. versionadded:: 2.0 + """ + required_permissions = Permissions.stage_moderator() + return [member for member in self.members if self.permissions_for(member) >= required_permissions] + + @property + def type(self) -> Literal[ChannelType.stage_voice]: + """:class:`ChannelType`: The channel's Discord type.""" + return ChannelType.stage_voice + + @property + def instance(self) -> Optional[StageInstance]: + """Optional[:class:`StageInstance`]: The running stage instance of the stage channel. + + .. versionadded:: 2.0 + """ + return utils.get(self.guild.stage_instances, channel_id=self.id) + + async def create_instance( + self, + *, + topic: str, + privacy_level: PrivacyLevel = MISSING, + send_start_notification: bool = False, + scheduled_event: Snowflake = MISSING, + reason: Optional[str] = None, + ) -> StageInstance: + """|coro| + + Create a stage instance. + + You must have :attr:`~Permissions.manage_channels` to do this. + + .. versionadded:: 2.0 + + Parameters + ----------- + topic: :class:`str` + The stage instance's topic. + privacy_level: :class:`PrivacyLevel` + The stage instance's privacy level. Defaults to :attr:`PrivacyLevel.guild_only`. + send_start_notification: :class:`bool` + Whether to send a start notification. This sends a push notification to @everyone if ``True``. Defaults to ``False``. + You must have :attr:`~Permissions.mention_everyone` to do this. + + .. versionadded:: 2.3 + scheduled_event: :class:`~discord.abc.Snowflake` + The guild scheduled event associated with the stage instance. + + .. versionadded:: 2.4 + reason: :class:`str` + The reason the stage instance was created. Shows up on the audit log. + + Raises + ------ + TypeError + If the ``privacy_level`` parameter is not the proper type. + Forbidden + You do not have permissions to create a stage instance. + HTTPException + Creating a stage instance failed. + + Returns + -------- + :class:`StageInstance` + The newly created stage instance. + """ + + payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic} + + if privacy_level is not MISSING: + if not isinstance(privacy_level, PrivacyLevel): + raise TypeError('privacy_level field must be of type PrivacyLevel') + + payload['privacy_level'] = privacy_level.value + + if scheduled_event is not MISSING: + payload['guild_scheduled_event_id'] = scheduled_event.id + + payload['send_start_notification'] = send_start_notification + + data = await self._state.http.create_stage_instance(**payload, reason=reason) + return StageInstance(guild=self.guild, state=self._state, data=data) + + async def fetch_instance(self) -> StageInstance: + """|coro| + + Gets the running :class:`StageInstance`. + + .. versionadded:: 2.0 + + Raises + ------- + NotFound + The stage instance or channel could not be found. + HTTPException + Getting the stage instance failed. + + Returns + -------- + :class:`StageInstance` + The stage instance. + """ + data = await self._state.http.get_stage_instance(self.id) + return StageInstance(guild=self.guild, state=self._state, data=data) + + @overload + async def edit(self) -> None: ... + + @overload + async def edit(self, *, position: int, reason: Optional[str] = ...) -> None: ... + + @overload + async def edit( + self, + *, + name: str = ..., + nsfw: bool = ..., + bitrate: int = ..., + user_limit: int = ..., + position: int = ..., + sync_permissions: int = ..., + category: Optional[CategoryChannel] = ..., + overwrites: Mapping[OverwriteKeyT, PermissionOverwrite] = ..., + rtc_region: Optional[str] = ..., + video_quality_mode: VideoQualityMode = ..., + slowmode_delay: int = ..., + reason: Optional[str] = ..., + ) -> StageChannel: ... + + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StageChannel]: + """|coro| + + Edits the channel. + + You must have :attr:`~Permissions.manage_channels` to do this. + + .. versionchanged:: 2.0 + The ``topic`` parameter must now be set via :attr:`create_instance`. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + .. versionchanged:: 2.0 + The ``region`` parameter now accepts :class:`str` instead of an enum. + + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` instead of + ``InvalidArgument``. + + Parameters + ---------- + name: :class:`str` + The new channel's name. + bitrate: :class:`int` + The new channel's bitrate. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + To mark the channel as NSFW or not. + user_limit: :class:`int` + The new channel's user limit. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + reason: Optional[:class:`str`] + The reason for editing this channel. Shows up on the audit log. + overwrites: :class:`Mapping` + A :class:`Mapping` of target (either a role or a member) to + :class:`PermissionOverwrite` to apply to the channel. + rtc_region: Optional[:class:`str`] + The new region for the stage channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the stage channel's participants. + + .. versionadded:: 2.0 + + Raises + ------ + ValueError + If the permission overwrite information is not in proper form. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + + Returns + -------- + Optional[:class:`.StageChannel`] + The newly edited stage channel. If the edit was only positional + then ``None`` is returned instead. + """ + + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + + +class CategoryChannel(discord.abc.GuildChannel, Hashable): + """Represents a Discord channel category. + + These are useful to group channels to logical compartments. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the category's hash. + + .. describe:: str(x) + + Returns the category's name. + + Attributes + ----------- + name: :class:`str` + The category name. + guild: :class:`Guild` + The guild the category belongs to. + id: :class:`int` + The category channel ID. + position: :class:`int` + The position in the category list. This is a number that starts at 0. e.g. the + top category is position 0. + nsfw: :class:`bool` + If the channel is marked as "not safe for work". + + .. note:: + + To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. + """ + + __slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id') + + def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): + self._state: ConnectionState = state + self.id: int = int(data['id']) + self._update(guild, data) + + def __repr__(self) -> str: + return f'' + + def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: + self.guild: Guild = guild + self.name: str = data['name'] + self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.nsfw: bool = data.get('nsfw', False) + self.position: int = data['position'] + self._fill_overwrites(data) + + @property + def _sorting_bucket(self) -> int: + return ChannelType.category.value + + @property + def _scheduled_event_entity_type(self) -> Optional[EntityType]: + return None + + @property + def type(self) -> Literal[ChannelType.category]: + """:class:`ChannelType`: The channel's Discord type.""" + return ChannelType.category + + def is_nsfw(self) -> bool: + """:class:`bool`: Checks if the category is NSFW.""" + return self.nsfw + + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, + *, + name: Optional[str] = None, + category: Optional[CategoryChannel] = None, + reason: Optional[str] = None, + ) -> CategoryChannel: + return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) + + @overload + async def edit(self) -> None: ... + + @overload + async def edit(self, *, position: int, reason: Optional[str] = ...) -> None: ... + + @overload + async def edit( + self, + *, + name: str = ..., + position: int = ..., + nsfw: bool = ..., + overwrites: Mapping[OverwriteKeyT, PermissionOverwrite] = ..., + reason: Optional[str] = ..., + ) -> CategoryChannel: ... + + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[CategoryChannel]: + """|coro| + + Edits the channel. + + You must have :attr:`~Permissions.manage_channels` to do this. - Edits the channel. + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. - You must have the :attr:`~Permissions.manage_channels` permission to - use this. + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` or + :exc:`ValueError` instead of ``InvalidArgument``. Parameters ---------- @@ -699,35 +2141,45 @@ async def edit(self, *, reason=None, **options): To mark the category as NSFW or not. reason: Optional[:class:`str`] The reason for editing this category. Shows up on the audit log. + overwrites: :class:`Mapping` + A :class:`Mapping` of target (either a role or a member) to + :class:`PermissionOverwrite` to apply to the channel. Raises ------ - InvalidArgument + ValueError If position is less than 0 or greater than the number of categories. + TypeError + The overwrite information is not in proper form. Forbidden You do not have permissions to edit the category. HTTPException Editing the category failed. + + Returns + -------- + Optional[:class:`.CategoryChannel`] + The newly edited category channel. If the edit was only positional + then ``None`` is returned instead. """ - try: - position = options.pop('position') - except KeyError: - pass - else: - await self._move(position, reason=reason) - self.position = position + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore - if options: - data = await self._state.http.edit_channel(self.id, reason=reason, **options) - self._update(self.guild, data) + @utils.copy_doc(discord.abc.GuildChannel.move) + async def move(self, **kwargs: Any) -> None: + kwargs.pop('category', None) + await super().move(**kwargs) @property - def channels(self): + def channels(self) -> List[GuildChannelType]: """List[:class:`abc.GuildChannel`]: Returns the channels that are under this category. These are sorted by the official Discord UI, which places voice channels below the text channels. """ + def comparator(channel): return (not isinstance(channel, TextChannel), channel.position) @@ -736,158 +2188,1009 @@ def comparator(channel): return ret @property - def text_channels(self): + def text_channels(self) -> List[TextChannel]: """List[:class:`TextChannel`]: Returns the text channels that are under this category.""" - ret = [c for c in self.guild.channels - if c.category_id == self.id - and isinstance(c, TextChannel)] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)] ret.sort(key=lambda c: (c.position, c.id)) return ret @property - def voice_channels(self): + def voice_channels(self) -> List[VoiceChannel]: """List[:class:`VoiceChannel`]: Returns the voice channels that are under this category.""" - ret = [c for c in self.guild.channels - if c.category_id == self.id - and isinstance(c, VoiceChannel)] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)] + ret.sort(key=lambda c: (c.position, c.id)) + return ret + + @property + def stage_channels(self) -> List[StageChannel]: + """List[:class:`StageChannel`]: Returns the stage channels that are under this category. + + .. versionadded:: 1.7 + """ + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)] ret.sort(key=lambda c: (c.position, c.id)) return ret - async def create_text_channel(self, name, *, overwrites=None, reason=None, **options): + @property + def forums(self) -> List[ForumChannel]: + """List[:class:`ForumChannel`]: Returns the forum channels that are under this category. + + .. versionadded:: 2.4 + """ + r = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, ForumChannel)] + r.sort(key=lambda c: (c.position, c.id)) + return r + + async def create_text_channel(self, name: str, **options: Unpack[_CreateTextChannelOptions]) -> TextChannel: """|coro| A shortcut method to :meth:`Guild.create_text_channel` to create a :class:`TextChannel` in the category. + + Returns + ------- + :class:`TextChannel` + The channel that was just created. """ - return await self.guild.create_text_channel(name, overwrites=overwrites, category=self, reason=reason, **options) + return await self.guild.create_text_channel(name, category=self, **options) - async def create_voice_channel(self, name, *, overwrites=None, reason=None, **options): + async def create_voice_channel(self, name: str, **options: Unpack[_CreateVoiceChannelOptions]) -> VoiceChannel: """|coro| A shortcut method to :meth:`Guild.create_voice_channel` to create a :class:`VoiceChannel` in the category. + + Returns + ------- + :class:`VoiceChannel` + The channel that was just created. + """ + return await self.guild.create_voice_channel(name, category=self, **options) + + async def create_stage_channel(self, name: str, **options: Unpack[_CreateStageChannelOptions]) -> StageChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_stage_channel` to create a :class:`StageChannel` in the category. + + .. versionadded:: 1.7 + + Returns + ------- + :class:`StageChannel` + The channel that was just created. + """ + return await self.guild.create_stage_channel(name, category=self, **options) + + async def create_forum(self, name: str, **options: Unpack[_CreateForumChannelOptions]) -> ForumChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_forum` to create a :class:`ForumChannel` in the category. + + .. versionadded:: 2.0 + + Returns + -------- + :class:`ForumChannel` + The channel that was just created. """ - return await self.guild.create_voice_channel(name, overwrites=overwrites, category=self, reason=reason, **options) + return await self.guild.create_forum(name, category=self, **options) + + +class ForumTag(Hashable): + """Represents a forum tag that can be applied to a thread within a :class:`ForumChannel`. + + .. versionadded:: 2.1 + + .. container:: operations + + .. describe:: x == y + + Checks if two forum tags are equal. + + .. describe:: x != y + + Checks if two forum tags are not equal. + + .. describe:: hash(x) + + Returns the forum tag's hash. + + .. describe:: str(x) + + Returns the forum tag's name. + + + Attributes + ----------- + id: :class:`int` + The ID of the tag. If this was manually created then the ID will be ``0``. + name: :class:`str` + The name of the tag. Can only be up to 20 characters. + moderated: :class:`bool` + Whether this tag can only be added or removed by a moderator with + the :attr:`~Permissions.manage_threads` permission. + emoji: Optional[:class:`PartialEmoji`] + The emoji that is used to represent this tag. + Note that if the emoji is a custom emoji, it will *not* have name information. + """ + + __slots__ = ('name', 'id', 'moderated', 'emoji') + + def __init__(self, *, name: str, emoji: Optional[EmojiInputType] = None, moderated: bool = False) -> None: + self.name: str = name + self.id: int = 0 + self.moderated: bool = moderated + self.emoji: Optional[PartialEmoji] = None + if isinstance(emoji, _EmojiTag): + self.emoji = emoji._to_partial() + elif isinstance(emoji, str): + self.emoji = PartialEmoji.from_str(emoji) + elif emoji is not None: + raise TypeError(f'emoji must be a Emoji, PartialEmoji, str or None not {emoji.__class__.__name__}') + + @classmethod + def from_data(cls, *, state: ConnectionState, data: ForumTagPayload) -> Self: + self = cls.__new__(cls) + self.name = data['name'] + self.id = int(data['id']) + self.moderated = data.get('moderated', False) + + emoji_name = data['emoji_name'] or '' + emoji_id = utils._get_as_snowflake(data, 'emoji_id') or None # Coerce 0 -> None + if not emoji_name and not emoji_id: + self.emoji = None + else: + self.emoji = PartialEmoji.with_state(state=state, name=emoji_name, id=emoji_id) + return self + + def to_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = { + 'name': self.name, + 'moderated': self.moderated, + } + if self.emoji is not None: + payload.update(self.emoji._to_forum_tag_payload()) + else: + payload.update(emoji_id=None, emoji_name=None) + + if self.id: + payload['id'] = self.id + + return payload + + def __repr__(self) -> str: + return f'' + + def __str__(self) -> str: + return self.name + + +class ForumChannel(discord.abc.GuildChannel, Hashable): + """Represents a Discord guild forum channel. -class StoreChannel(discord.abc.GuildChannel, Hashable): - """Represents a Discord guild store channel. + .. versionadded:: 2.0 .. container:: operations - .. describe:: x == y + .. describe:: x == y + + Checks if two forums are equal. + + .. describe:: x != y + + Checks if two forums are not equal. + + .. describe:: hash(x) + + Returns the forum's hash. + + .. describe:: str(x) + + Returns the forum's name. + + Attributes + ----------- + name: :class:`str` + The forum name. + guild: :class:`Guild` + The guild the forum belongs to. + id: :class:`int` + The forum ID. + category_id: Optional[:class:`int`] + The category channel ID this forum belongs to, if applicable. + topic: Optional[:class:`str`] + The forum's topic. ``None`` if it doesn't exist. Called "Guidelines" in the UI. + Can be up to 4096 characters long. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. + last_message_id: Optional[:class:`int`] + The last thread ID that was created on this forum. This technically also + coincides with the message ID that started the thread that was created. + It may *not* point to an existing or valid thread or message. + slowmode_delay: :class:`int` + The number of seconds a member must wait between creating threads + in this forum. A value of ``0`` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + nsfw: :class:`bool` + If the forum is marked as "not safe for work" or "age restricted". + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for threads created in this forum. + default_thread_slowmode_delay: :class:`int` + The default slowmode delay in seconds for threads created in this forum. + + .. versionadded:: 2.1 + default_reaction_emoji: Optional[:class:`PartialEmoji`] + The default reaction emoji for threads created in this forum to show in the + add reaction button. + + .. versionadded:: 2.1 + default_layout: :class:`ForumLayoutType` + The default layout for posts in this forum channel. + Defaults to :attr:`ForumLayoutType.not_set`. + + .. versionadded:: 2.2 + default_sort_order: Optional[:class:`ForumOrderType`] + The default sort order for posts in this forum channel. + + .. versionadded:: 2.3 + """ + + __slots__ = ( + 'name', + 'id', + 'guild', + 'topic', + '_state', + '_flags', + '_type', + 'nsfw', + 'category_id', + 'position', + 'slowmode_delay', + '_overwrites', + 'last_message_id', + 'default_auto_archive_duration', + 'default_thread_slowmode_delay', + 'default_reaction_emoji', + 'default_layout', + 'default_sort_order', + '_available_tags', + '_flags', + ) + + def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[ForumChannelPayload, MediaChannelPayload]): + self._state: ConnectionState = state + self.id: int = int(data['id']) + self._type: Literal[15, 16] = data['type'] + self._update(guild, data) + + def __repr__(self) -> str: + attrs = [ + ('id', self.id), + ('name', self.name), + ('position', self.position), + ('nsfw', self.nsfw), + ('category_id', self.category_id), + ] + joined = ' '.join('%s=%r' % t for t in attrs) + return f'<{self.__class__.__name__} {joined}>' + + def _update(self, guild: Guild, data: Union[ForumChannelPayload, MediaChannelPayload]) -> None: + self.guild: Guild = guild + self.name: str = data['name'] + self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.topic: Optional[str] = data.get('topic') + self.position: int = data['position'] + self.nsfw: bool = data.get('nsfw', False) + self.slowmode_delay: int = data.get('rate_limit_per_user', 0) + self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440) + self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') + # This takes advantage of the fact that dicts are ordered since Python 3.7 + tags = [ForumTag.from_data(state=self._state, data=tag) for tag in data.get('available_tags', [])] + self.default_thread_slowmode_delay: int = data.get('default_thread_rate_limit_per_user', 0) + self.default_layout: ForumLayoutType = try_enum(ForumLayoutType, data.get('default_forum_layout', 0)) + self._available_tags: Dict[int, ForumTag] = {tag.id: tag for tag in tags} + + self.default_reaction_emoji: Optional[PartialEmoji] = None + default_reaction_emoji = data.get('default_reaction_emoji') + if default_reaction_emoji: + self.default_reaction_emoji = PartialEmoji.with_state( + state=self._state, + id=utils._get_as_snowflake(default_reaction_emoji, 'emoji_id') or None, # Coerce 0 -> None + name=default_reaction_emoji.get('emoji_name') or '', + ) + + self.default_sort_order: Optional[ForumOrderType] = None + default_sort_order = data.get('default_sort_order') + if default_sort_order is not None: + self.default_sort_order = try_enum(ForumOrderType, default_sort_order) + + self._flags: int = data.get('flags', 0) + self._fill_overwrites(data) + + @property + def type(self) -> Literal[ChannelType.forum, ChannelType.media]: + """:class:`ChannelType`: The channel's Discord type.""" + if self._type == 16: + return ChannelType.media + return ChannelType.forum + + @property + def _sorting_bucket(self) -> int: + return ChannelType.text.value + + @property + def members(self) -> List[Member]: + """List[:class:`Member`]: Returns all members that can see this channel. + + .. versionadded:: 2.5 + """ + return [m for m in self.guild.members if self.permissions_for(m).read_messages] + + @property + def _scheduled_event_entity_type(self) -> Optional[EntityType]: + return None + + @utils.copy_doc(discord.abc.GuildChannel.permissions_for) + def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: + base = super().permissions_for(obj) + self._apply_implicit_permissions(base) + + # text channels do not have voice related permissions + denied = Permissions.voice() + base.value &= ~denied.value + return base + + def get_thread(self, thread_id: int, /) -> Optional[Thread]: + """Returns a thread with the given ID. + + .. note:: + + This does not always retrieve archived threads, as they are not retained in the internal + cache. Use :func:`Guild.fetch_channel` instead. + + .. versionadded:: 2.2 + + Parameters + ----------- + thread_id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[:class:`Thread`] + The returned thread or ``None`` if not found. + """ + thread = self.guild.get_thread(thread_id) + if thread is not None and thread.parent_id == self.id: + return thread + return None + + @property + def threads(self) -> List[Thread]: + """List[:class:`Thread`]: Returns all the threads that you can see.""" + return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id] + + @property + def flags(self) -> ChannelFlags: + """:class:`ChannelFlags`: The flags associated with this thread. + + .. versionadded:: 2.1 + """ + return ChannelFlags._from_value(self._flags) + + @property + def available_tags(self) -> Sequence[ForumTag]: + """Sequence[:class:`ForumTag`]: Returns all the available tags for this forum. + + .. versionadded:: 2.1 + """ + return utils.SequenceProxy(self._available_tags.values()) + + def get_tag(self, tag_id: int, /) -> Optional[ForumTag]: + """Returns the tag with the given ID. + + .. versionadded:: 2.1 + + Parameters + ---------- + tag_id: :class:`int` + The ID to search for. + + Returns + ------- + Optional[:class:`ForumTag`] + The tag with the given ID, or ``None`` if not found. + """ + return self._available_tags.get(tag_id) + + def is_nsfw(self) -> bool: + """:class:`bool`: Checks if the forum is NSFW.""" + return self.nsfw + + def is_media(self) -> bool: + """:class:`bool`: Checks if the channel is a media channel. + + .. versionadded:: 2.4 + """ + return self._type == ChannelType.media.value + + @utils.copy_doc(discord.abc.GuildChannel.clone) + async def clone( + self, + *, + name: Optional[str] = None, + category: Optional[CategoryChannel], + reason: Optional[str] = None, + ) -> ForumChannel: + base = { + 'topic': self.topic, + 'rate_limit_per_user': self.slowmode_delay, + 'nsfw': self.nsfw, + 'default_auto_archive_duration': self.default_auto_archive_duration, + 'available_tags': [tag.to_dict() for tag in self.available_tags], + 'default_thread_rate_limit_per_user': self.default_thread_slowmode_delay, + } + if self.default_sort_order: + base['default_sort_order'] = self.default_sort_order.value + if self.default_reaction_emoji: + base['default_reaction_emoji'] = self.default_reaction_emoji._to_forum_tag_payload() + if not self.is_media() and self.default_layout: + base['default_forum_layout'] = self.default_layout.value + + return await self._clone_impl( + base, + name=name, + category=category, + reason=reason, + ) + + @overload + async def edit(self) -> None: ... + + @overload + async def edit(self, *, position: int, reason: Optional[str] = ...) -> None: ... + + @overload + async def edit( + self, + *, + reason: Optional[str] = ..., + name: str = ..., + topic: str = ..., + position: int = ..., + nsfw: bool = ..., + sync_permissions: bool = ..., + category: Optional[CategoryChannel] = ..., + slowmode_delay: int = ..., + default_auto_archive_duration: ThreadArchiveDuration = ..., + type: ChannelType = ..., + overwrites: Mapping[OverwriteKeyT, PermissionOverwrite] = ..., + available_tags: Sequence[ForumTag] = ..., + default_thread_slowmode_delay: int = ..., + default_reaction_emoji: Optional[EmojiInputType] = ..., + default_layout: ForumLayoutType = ..., + default_sort_order: ForumOrderType = ..., + require_tag: bool = ..., + ) -> ForumChannel: ... + + async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[ForumChannel]: + """|coro| + + Edits the forum. + + You must have :attr:`~Permissions.manage_channels` to do this. + + Parameters + ---------- + name: :class:`str` + The new forum name. + topic: :class:`str` + The new forum's topic. + position: :class:`int` + The new forum's position. + nsfw: :class:`bool` + To mark the forum as NSFW or not. + sync_permissions: :class:`bool` + Whether to sync permissions with the forum's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this forum. Can be ``None`` to remove the + category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this forum, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + type: :class:`ChannelType` + Change the type of this text forum. Currently, only conversion between + :attr:`ChannelType.text` and :attr:`ChannelType.news` is supported. This + is only available to guilds that contain ``NEWS`` in :attr:`Guild.features`. + reason: Optional[:class:`str`] + The reason for editing this forum. Shows up on the audit log. + overwrites: :class:`Mapping` + A :class:`Mapping` of target (either a role or a member) to + :class:`PermissionOverwrite` to apply to the forum. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for threads created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + available_tags: Sequence[:class:`ForumTag`] + The new available tags for this forum. + + .. versionadded:: 2.1 + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay for threads in this channel. + + .. versionadded:: 2.1 + default_reaction_emoji: Optional[Union[:class:`Emoji`, :class:`PartialEmoji`, :class:`str`]] + The new default reaction emoji for threads in this channel. + + .. versionadded:: 2.1 + default_layout: :class:`ForumLayoutType` + The new default layout for posts in this forum. + + .. versionadded:: 2.2 + default_sort_order: Optional[:class:`ForumOrderType`] + The new default sort order for posts in this forum. + + .. versionadded:: 2.3 + require_tag: :class:`bool` + Whether to require a tag for threads in this channel or not. + + .. versionadded:: 2.1 + + Raises + ------ + ValueError + The new ``position`` is less than 0 or greater than the number of channels. + TypeError + The permission overwrite information is not in proper form or a type + is not the expected type. + Forbidden + You do not have permissions to edit the forum. + HTTPException + Editing the forum failed. + + Returns + -------- + Optional[:class:`.ForumChannel`] + The newly edited forum channel. If the edit was only positional + then ``None`` is returned instead. + """ + + try: + tags: Sequence[ForumTag] = options.pop('available_tags') + except KeyError: + pass + else: + options['available_tags'] = [tag.to_dict() for tag in tags] + + try: + default_reaction_emoji: Optional[EmojiInputType] = options.pop('default_reaction_emoji') + except KeyError: + pass + else: + if default_reaction_emoji is None: + options['default_reaction_emoji'] = None + elif isinstance(default_reaction_emoji, _EmojiTag): + options['default_reaction_emoji'] = default_reaction_emoji._to_partial()._to_forum_tag_payload() + elif isinstance(default_reaction_emoji, str): + options['default_reaction_emoji'] = PartialEmoji.from_str(default_reaction_emoji)._to_forum_tag_payload() + + try: + require_tag = options.pop('require_tag') + except KeyError: + pass + else: + flags = self.flags + flags.require_tag = require_tag + options['flags'] = flags.value + + try: + layout = options.pop('default_layout') + except KeyError: + pass + else: + if not isinstance(layout, ForumLayoutType): + raise TypeError(f'default_layout parameter must be a ForumLayoutType not {layout.__class__.__name__}') + + options['default_forum_layout'] = layout.value + + try: + sort_order = options.pop('default_sort_order') + except KeyError: + pass + else: + if sort_order is None: + options['default_sort_order'] = None + else: + if not isinstance(sort_order, ForumOrderType): + raise TypeError( + f'default_sort_order parameter must be a ForumOrderType not {sort_order.__class__.__name__}' + ) + + options['default_sort_order'] = sort_order.value + + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + + async def create_tag( + self, + *, + name: str, + emoji: Optional[PartialEmoji] = None, + moderated: bool = False, + reason: Optional[str] = None, + ) -> ForumTag: + """|coro| + + Creates a new tag in this forum. + + You must have :attr:`~Permissions.manage_channels` to do this. + + Parameters + ---------- + name: :class:`str` + The name of the tag. Can only be up to 20 characters. + emoji: Optional[Union[:class:`str`, :class:`PartialEmoji`]] + The emoji to use for the tag. + moderated: :class:`bool` + Whether the tag can only be applied by moderators. + reason: Optional[:class:`str`] + The reason for creating this tag. Shows up on the audit log. + + Raises + ------ + Forbidden + You do not have permissions to create a tag in this forum. + HTTPException + Creating the tag failed. + + Returns + ------- + :class:`ForumTag` + The newly created tag. + """ + + prior = list(self._available_tags.values()) + result = ForumTag(name=name, emoji=emoji, moderated=moderated) + prior.append(result) + payload = await self._state.http.edit_channel( + self.id, reason=reason, available_tags=[tag.to_dict() for tag in prior] + ) + try: + result.id = int(payload['available_tags'][-1]['id']) # type: ignore + except (KeyError, IndexError, ValueError): + pass + + return result + + @overload + async def create_thread( + self, + *, + name: str, + auto_archive_duration: ThreadArchiveDuration = ..., + slowmode_delay: Optional[int] = ..., + file: File = ..., + files: Sequence[File] = ..., + allowed_mentions: AllowedMentions = ..., + mention_author: bool = ..., + applied_tags: Sequence[ForumTag] = ..., + view: LayoutView, + suppress_embeds: bool = ..., + silent: bool = ..., + reason: Optional[str] = ..., + ) -> ThreadWithMessage: ... + + @overload + async def create_thread( + self, + *, + name: str, + auto_archive_duration: ThreadArchiveDuration = ..., + slowmode_delay: Optional[int] = ..., + content: Optional[str] = ..., + tts: bool = ..., + embed: Embed = ..., + embeds: Sequence[Embed] = ..., + file: File = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + allowed_mentions: AllowedMentions = ..., + mention_author: bool = ..., + applied_tags: Sequence[ForumTag] = ..., + view: View = ..., + suppress_embeds: bool = ..., + silent: bool = ..., + reason: Optional[str] = ..., + ) -> ThreadWithMessage: ... + + async def create_thread( + self, + *, + name: str, + auto_archive_duration: ThreadArchiveDuration = MISSING, + slowmode_delay: Optional[int] = None, + content: Optional[str] = None, + tts: bool = False, + embed: Embed = MISSING, + embeds: Sequence[Embed] = MISSING, + file: File = MISSING, + files: Sequence[File] = MISSING, + stickers: Sequence[Union[GuildSticker, StickerItem]] = MISSING, + allowed_mentions: AllowedMentions = MISSING, + mention_author: bool = MISSING, + applied_tags: Sequence[ForumTag] = MISSING, + view: BaseView = MISSING, + suppress_embeds: bool = False, + silent: bool = False, + reason: Optional[str] = None, + ) -> ThreadWithMessage: + """|coro| + + Creates a thread in this forum. + + This thread is a public thread with the initial message given. Currently in order + to start a thread in this forum, the user needs :attr:`~discord.Permissions.send_messages`. + + You must send at least one of ``content``, ``embed``, ``embeds``, ``file``, ``files``, + or ``view`` to create a thread in a forum, since forum channels must have a starter message. + + Parameters + ----------- + name: :class:`str` + The name of the thread. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically hidden from the channel list. + If not provided, the channel's default auto archive duration is used. + + Must be one of ``60``, ``1440``, ``4320``, or ``10080``, if provided. + slowmode_delay: Optional[:class:`int`] + Specifies the slowmode rate limit for user in this channel, in seconds. + The maximum value possible is ``21600``. By default no slowmode rate limit + if this is ``None``. + content: Optional[:class:`str`] + The content of the message to send with the thread. + tts: :class:`bool` + Indicates if the message should be sent using text-to-speech. + embed: :class:`~discord.Embed` + The rich embed for the content. + embeds: List[:class:`~discord.Embed`] + A list of embeds to upload. Must be a maximum of 10. + file: :class:`~discord.File` + The file to upload. + files: List[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. + mention_author: :class:`bool` + If set, overrides the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions``. + applied_tags: List[:class:`discord.ForumTag`] + A list of tags to apply to the thread. + view: Union[:class:`discord.ui.View`, :class:`discord.ui.LayoutView`] + A Discord UI View to add to the message. + stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]] + A list of stickers to upload. Must be a maximum of 3. + suppress_embeds: :class:`bool` + Whether to suppress embeds for the message. This sends the message without any embeds if set to ``True``. + silent: :class:`bool` + Whether to suppress push and desktop notifications for the message. This will increment the mention counter + in the UI, but will not actually send a notification. + + .. versionadded:: 2.7 + reason: :class:`str` + The reason for creating a new thread. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to create a thread. + HTTPException + Starting the thread failed. + ValueError + The ``files`` or ``embeds`` list is not of the appropriate size. + TypeError + You specified both ``file`` and ``files``, + or you specified both ``embed`` and ``embeds``. - Checks if two channels are equal. + Returns + -------- + Tuple[:class:`Thread`, :class:`Message`] + The created thread with the created message. + This is also accessible as a namedtuple with ``thread`` and ``message`` fields. + """ - .. describe:: x != y + state = self._state + previous_allowed_mention = state.allowed_mentions + if stickers is MISSING: + sticker_ids = MISSING + else: + sticker_ids: SnowflakeList = [s.id for s in stickers] - Checks if two channels are not equal. + if view and not hasattr(view, '__discord_ui_view__'): + raise TypeError(f'view parameter must be View not {view.__class__.__name__}') - .. describe:: hash(x) + if suppress_embeds or silent: + flags = MessageFlags._from_value(0) + flags.suppress_embeds = suppress_embeds + flags.suppress_notifications = silent + else: + flags = MISSING + + content = str(content) if content else MISSING + + channel_payload = { + 'name': name, + 'auto_archive_duration': auto_archive_duration or self.default_auto_archive_duration, + 'rate_limit_per_user': slowmode_delay, + 'type': 11, # Private threads don't seem to be allowed + } + + if applied_tags is not MISSING: + channel_payload['applied_tags'] = [str(tag.id) for tag in applied_tags] + + with handle_message_parameters( + content=content, + tts=tts, + file=file, + files=files, + embed=embed, + embeds=embeds, + allowed_mentions=allowed_mentions, + previous_allowed_mentions=previous_allowed_mention, + mention_author=None if mention_author is MISSING else mention_author, + stickers=sticker_ids, + view=view, + flags=flags, + channel_payload=channel_payload, + ) as params: + # Circular import + from .message import Message + + data = await state.http.start_thread_in_forum(self.id, params=params, reason=reason) + thread = Thread(guild=self.guild, state=self._state, data=data) + message = Message(state=self._state, channel=thread, data=data['message']) + if view and not view.is_finished() and view.is_dispatchable(): + self._state.store_view(view, message.id) + + return ThreadWithMessage(thread=thread, message=message) + + async def webhooks(self) -> List[Webhook]: + """|coro| - Returns the channel's hash. + Gets the list of webhooks from this channel. - .. describe:: str(x) + You must have :attr:`~.Permissions.manage_webhooks` to do this. - Returns the channel's name. + Raises + ------- + Forbidden + You don't have permissions to get the webhooks. - Attributes - ----------- - name: :class:`str` - The channel name. - guild: :class:`Guild` - The guild the channel belongs to. - id: :class:`int` - The channel ID. - category_id: :class:`int` - The category channel ID this channel belongs to. - position: :class:`int` - The position in the channel list. This is a number that starts at 0. e.g. the - top channel is position 0. - """ - __slots__ = ('name', 'id', 'guild', '_state', 'nsfw', - 'category_id', 'position', '_overwrites',) + Returns + -------- + List[:class:`Webhook`] + The webhooks for this channel. + """ - def __init__(self, *, state, guild, data): - self._state = state - self.id = int(data['id']) - self._update(guild, data) + from .webhook import Webhook - def __repr__(self): - return ''.format(self) + data = await self._state.http.channel_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] - def _update(self, guild, data): - self.guild = guild - self.name = data['name'] - self.category_id = utils._get_as_snowflake(data, 'parent_id') - self.position = data['position'] - self.nsfw = data.get('nsfw', False) - self._fill_overwrites(data) + async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook: + """|coro| - @property - def _sorting_bucket(self): - return ChannelType.text.value + Creates a webhook for this channel. - @property - def type(self): - """:class:`ChannelType`: The channel's Discord type.""" - return ChannelType.store + You must have :attr:`~.Permissions.manage_webhooks` to do this. - def permissions_for(self, member): - base = super().permissions_for(member) + Parameters + ------------- + name: :class:`str` + The webhook's name. + avatar: Optional[:class:`bytes`] + A :term:`py:bytes-like object` representing the webhook's default avatar. + This operates similarly to :meth:`~ClientUser.edit`. + reason: Optional[:class:`str`] + The reason for creating this webhook. Shows up in the audit logs. - # store channels do not have voice related permissions - denied = Permissions.voice() - base.value &= ~denied.value - return base + Raises + ------- + HTTPException + Creating the webhook failed. + Forbidden + You do not have permissions to create a webhook. - permissions_for.__doc__ = discord.abc.GuildChannel.permissions_for.__doc__ + Returns + -------- + :class:`Webhook` + The created webhook. + """ - def is_nsfw(self): - """Checks if the channel is NSFW.""" - return self.nsfw + from .webhook import Webhook - async def clone(self, *, name=None, reason=None): - return await self._clone_impl({ - 'nsfw': self.nsfw - }, name=name, reason=reason) + if avatar is not None: + avatar = utils._bytes_to_base64_data(avatar) # type: ignore # Silence reassignment error - clone.__doc__ = discord.abc.GuildChannel.clone.__doc__ + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) + return Webhook.from_state(data, state=self._state) - async def edit(self, *, reason=None, **options): - """|coro| + async def archived_threads( + self, + *, + limit: Optional[int] = 100, + before: Optional[Union[Snowflake, datetime.datetime]] = None, + ) -> AsyncIterator[Thread]: + """Returns an :term:`asynchronous iterator` that iterates over all archived threads in this forum + in order of decreasing :attr:`Thread.archive_timestamp`. - Edits the channel. + You must have :attr:`~Permissions.read_message_history` to do this. - You must have the :attr:`~Permissions.manage_channels` permission to - use this. + .. versionadded:: 2.0 Parameters - ---------- - name: :class:`str` - The new channel name. - position: :class:`int` - The new channel's position. - nsfw: :class:`bool` - To mark the channel as NSFW or not. - sync_permissions: :class:`bool` - Whether to sync permissions with the channel's new or pre-existing - category. Defaults to ``False``. - category: Optional[:class:`CategoryChannel`] - The new category for this channel. Can be ``None`` to remove the - category. - reason: Optional[:class:`str`] - The reason for editing this channel. Shows up on the audit log. + ----------- + limit: Optional[:class:`bool`] + The number of threads to retrieve. + If ``None``, retrieves every archived thread in the channel. Note, however, + that this would make it a slow operation. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve archived channels before the given date or ID. Raises ------ - InvalidArgument - If position is less than 0 or greater than the number of channels. Forbidden - You do not have permissions to edit the channel. + You do not have permissions to get archived threads. HTTPException - Editing the channel failed. + The request to get the archived threads failed. + + Yields + ------- + :class:`Thread` + The archived threads. """ - await self._edit(options, reason=reason) + before_timestamp = None + + if isinstance(before, datetime.datetime): + before_timestamp = before.isoformat() + elif before is not None: + before_timestamp = utils.snowflake_time(before.id).isoformat() + + update_before = lambda data: data['thread_metadata']['archive_timestamp'] + + while True: + retrieve = 100 + if limit is not None: + if limit <= 0: + return + retrieve = max(2, min(retrieve, limit)) + + data = await self.guild._state.http.get_public_archived_threads(self.id, before=before_timestamp, limit=retrieve) + + threads = data.get('threads', []) + for raw_thread in threads: + yield Thread(guild=self.guild, state=self.guild._state, data=raw_thread) + # Currently the API doesn't let you request less than 2 threads. + # Bail out early if we had to retrieve more than what the limit was. + if limit is not None: + limit -= 1 + if limit <= 0: + return + + if not data.get('has_more', False): + return + + before_timestamp = update_before(threads[-1]) + -class DMChannel(discord.abc.Messageable, Hashable): +class DMChannel(discord.abc.Messageable, discord.abc.PrivateChannel, Hashable): """Represents a Discord direct message channel. .. container:: operations @@ -910,42 +3213,84 @@ class DMChannel(discord.abc.Messageable, Hashable): Attributes ---------- - recipient: :class:`User` + recipient: Optional[:class:`User`] The user you are participating with in the direct message channel. + If this channel is received through the gateway, the recipient information + may not be always available. + recipients: List[:class:`User`] + The users you are participating with in the DM channel. + + .. versionadded:: 2.4 me: :class:`ClientUser` The user presenting yourself. id: :class:`int` The direct message channel ID. """ - __slots__ = ('id', 'recipient', 'me', '_state') + __slots__ = ('id', 'recipients', 'me', '_state') - def __init__(self, *, me, state, data): - self._state = state - self.recipient = state.store_user(data['recipients'][0]) - self.me = me - self.id = int(data['id']) + def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): + self._state: ConnectionState = state + self.recipients: List[User] = [state.store_user(u) for u in data.get('recipients', [])] + self.me: ClientUser = me + self.id: int = int(data['id']) - async def _get_channel(self): + async def _get_channel(self) -> Self: return self - def __str__(self): - return 'Direct Message with %s' % self.recipient + def __str__(self) -> str: + if self.recipient: + return f'Direct Message with {self.recipient}' + return 'Direct Message with Unknown User' + + def __repr__(self) -> str: + return f'' + + @classmethod + def _from_message(cls, state: ConnectionState, channel_id: int) -> Self: + self = cls.__new__(cls) + self._state = state + self.id = channel_id + self.recipients = [] + # state.user won't be None here + self.me = state.user # type: ignore + return self - def __repr__(self): - return ''.format(self) + @property + def recipient(self) -> Optional[User]: + if self.recipients: + return self.recipients[0] + return None @property - def type(self): + def type(self) -> Literal[ChannelType.private]: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.private @property - def created_at(self): - """Returns the direct message channel's creation time in UTC.""" + def guild(self) -> Optional[Guild]: + """Optional[:class:`Guild`]: The guild this DM channel belongs to. Always ``None``. + + This is mainly provided for compatibility purposes in duck typing. + + .. versionadded:: 2.0 + """ + return None + + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.0 + """ + return f'https://discord.com/channels/@me/{self.id}' + + @property + def created_at(self) -> datetime.datetime: + """:class:`datetime.datetime`: Returns the direct message channel's creation time in UTC.""" return utils.snowflake_time(self.id) - def permissions_for(self, user=None): + def permissions_for(self, obj: Any = None, /) -> Permissions: """Handles permission resolution for a :class:`User`. This function is there for compatibility with other channel types. @@ -956,25 +3301,61 @@ def permissions_for(self, user=None): - :attr:`~Permissions.send_tts_messages`: You cannot send TTS messages in a DM. - :attr:`~Permissions.manage_messages`: You cannot delete others messages in a DM. + - :attr:`~Permissions.create_private_threads`: There are no threads in a DM. + - :attr:`~Permissions.create_public_threads`: There are no threads in a DM. + - :attr:`~Permissions.manage_threads`: There are no threads in a DM. + - :attr:`~Permissions.send_messages_in_threads`: There are no threads in a DM. + + .. versionchanged:: 2.0 + + ``obj`` parameter is now positional-only. + + .. versionchanged:: 2.1 + + Thread related permissions are now set to ``False``. Parameters ----------- - user: :class:`User` + obj: :class:`User` The user to check permissions for. This parameter is ignored - but kept for compatibility. + but kept for compatibility with other ``permissions_for`` methods. Returns -------- :class:`Permissions` The resolved permissions. """ + return Permissions._dm_permissions() + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 1.6 + + .. versionchanged:: 2.0 + + ``message_id`` parameter is now positional-only. + + Parameters + ------------ + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + --------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) - base = Permissions.text() - base.send_tts_messages = False - base.manage_messages = False - return base -class GroupChannel(discord.abc.Messageable, Hashable): +class GroupChannel(discord.abc.Messageable, discord.abc.PrivateChannel, Hashable): """Represents a Discord group channel. .. container:: operations @@ -1003,41 +3384,40 @@ class GroupChannel(discord.abc.Messageable, Hashable): The user presenting yourself. id: :class:`int` The group channel ID. - owner: :class:`User` + owner: Optional[:class:`User`] The user that owns the group channel. - icon: Optional[:class:`str`] - The group channel's icon hash if provided. + owner_id: :class:`int` + The owner ID that owns the group channel. + + .. versionadded:: 2.0 name: Optional[:class:`str`] The group channel's name if provided. """ - __slots__ = ('id', 'recipients', 'owner', 'icon', 'name', 'me', '_state') + __slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state') - def __init__(self, *, me, state, data): - self._state = state - self.id = int(data['id']) - self.me = me + def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload): + self._state: ConnectionState = state + self.id: int = int(data['id']) + self.me: ClientUser = me self._update_group(data) - def _update_group(self, data): - owner_id = utils._get_as_snowflake(data, 'owner_id') - self.icon = data.get('icon') - self.name = data.get('name') - - try: - self.recipients = [self._state.store_user(u) for u in data['recipients']] - except KeyError: - pass + def _update_group(self, data: GroupChannelPayload) -> None: + self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id') + self._icon: Optional[str] = data.get('icon') + self.name: Optional[str] = data.get('name') + self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])] - if owner_id == self.me.id: + self.owner: Optional[BaseUser] + if self.owner_id == self.me.id: self.owner = self.me else: - self.owner = utils.find(lambda u: u.id == owner_id, self.recipients) + self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients) - async def _get_channel(self): + async def _get_channel(self) -> Self: return self - def __str__(self): + def __str__(self) -> str: if self.name: return self.name @@ -1046,25 +3426,45 @@ def __str__(self): return ', '.join(map(lambda x: x.name, self.recipients)) - def __repr__(self): - return ''.format(self) + def __repr__(self) -> str: + return f'' @property - def type(self): + def type(self) -> Literal[ChannelType.group]: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.group @property - def icon_url(self): - """:class:`Asset`: Returns the channel's icon asset if available.""" - return Asset._from_icon(self._state, self, 'channel') + def guild(self) -> Optional[Guild]: + """Optional[:class:`Guild`]: The guild this group channel belongs to. Always ``None``. + + This is mainly provided for compatibility purposes in duck typing. + + .. versionadded:: 2.0 + """ + return None + + @property + def icon(self) -> Optional[Asset]: + """Optional[:class:`Asset`]: Returns the channel's icon asset if available.""" + if self._icon is None: + return None + return Asset._from_icon(self._state, self.id, self._icon, path='channel') @property - def created_at(self): + def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" return utils.snowflake_time(self.id) - def permissions_for(self, user): + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.0 + """ + return f'https://discord.com/channels/@me/{self.id}' + + def permissions_for(self, obj: Snowflake, /) -> Permissions: """Handles permission resolution for a :class:`User`. This function is there for compatibility with other channel types. @@ -1073,14 +3473,26 @@ def permissions_for(self, user): This returns all the Text related permissions set to ``True`` except: - - send_tts_messages: You cannot send TTS messages in a DM. - - manage_messages: You cannot delete others messages in a DM. + - :attr:`~Permissions.send_tts_messages`: You cannot send TTS messages in a DM. + - :attr:`~Permissions.manage_messages`: You cannot delete others messages in a DM. + - :attr:`~Permissions.create_private_threads`: There are no threads in a DM. + - :attr:`~Permissions.create_public_threads`: There are no threads in a DM. + - :attr:`~Permissions.manage_threads`: There are no threads in a DM. + - :attr:`~Permissions.send_messages_in_threads`: There are no threads in a DM. This also checks the kick_members permission if the user is the owner. + .. versionchanged:: 2.0 + + ``obj`` parameter is now positional-only. + + .. versionchanged:: 2.1 + + Thread related permissions are now set to ``False``. + Parameters ----------- - user: :class:`User` + obj: :class:`~discord.abc.Snowflake` The user to check permissions for. Returns @@ -1089,126 +3501,184 @@ def permissions_for(self, user): The resolved permissions for the user. """ - base = Permissions.text() - base.send_tts_messages = False - base.manage_messages = False + base = Permissions._dm_permissions() base.mention_everyone = True - if user.id == self.owner.id: + if obj.id == self.owner_id: base.kick_members = True return base - async def add_recipients(self, *recipients): - r"""|coro| - - Adds recipients to this group. + async def leave(self) -> None: + """|coro| - A group can only have a maximum of 10 members. - Attempting to add more ends up in an exception. To - add a recipient to the group, you must have a relationship - with the user of type :attr:`RelationshipType.friend`. + Leave the group. - Parameters - ----------- - \*recipients: :class:`User` - An argument list of users to add to this group. + If you are the only one in the group, this deletes it as well. Raises ------- HTTPException - Adding a recipient to this group failed. + Leaving the group failed. """ - # TODO: wait for the corresponding WS event + await self._state.http.leave_group(self.id) - req = self._state.http.add_group_recipient - for recipient in recipients: - await req(self.id, recipient.id) - async def remove_recipients(self, *recipients): - r"""|coro| +class PartialMessageable(discord.abc.Messageable, Hashable): + """Represents a partial messageable to aid with working messageable channels when + only a channel ID is present. - Removes recipients from this group. + The only way to construct this class is through :meth:`Client.get_partial_messageable`. - Parameters - ----------- - \*recipients: :class:`User` - An argument list of users to remove from this group. + Note that this class is trimmed down and has no rich attributes. - Raises - ------- - HTTPException - Removing a recipient from this group failed. - """ + .. versionadded:: 2.0 - # TODO: wait for the corresponding WS event + .. container:: operations + + .. describe:: x == y - req = self._state.http.remove_group_recipient - for recipient in recipients: - await req(self.id, recipient.id) + Checks if two partial messageables are equal. - async def edit(self, **fields): - """|coro| + .. describe:: x != y + + Checks if two partial messageables are not equal. + + .. describe:: hash(x) - Edits the group. + Returns the partial messageable's hash. + + Attributes + ----------- + id: :class:`int` + The channel ID associated with this partial messageable. + guild_id: Optional[:class:`int`] + The guild ID associated with this partial messageable. + type: Optional[:class:`ChannelType`] + The channel type associated with this partial messageable, if given. + """ + + def __init__(self, state: ConnectionState, id: int, guild_id: Optional[int] = None, type: Optional[ChannelType] = None): + self._state: ConnectionState = state + self.id: int = id + self.guild_id: Optional[int] = guild_id + self.type: Optional[ChannelType] = type + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} id={self.id} type={self.type!r}>' + + async def _get_channel(self) -> PartialMessageable: + return self + + @property + def guild(self) -> Optional[Guild]: + """Optional[:class:`Guild`]: The guild this partial messageable is in.""" + return self._state._get_guild(self.guild_id) + + @property + def jump_url(self) -> str: + """:class:`str`: Returns a URL that allows the client to jump to the channel.""" + if self.guild_id is None: + return f'https://discord.com/channels/@me/{self.id}' + return f'https://discord.com/channels/{self.guild_id}/{self.id}' + + @property + def created_at(self) -> datetime.datetime: + """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" + return utils.snowflake_time(self.id) + + def permissions_for(self, obj: Any = None, /) -> Permissions: + """Handles permission resolution for a :class:`User`. + + This function is there for compatibility with other channel types. + + Since partial messageables cannot reasonably have the concept of + permissions, this will always return :meth:`Permissions.none`. Parameters ----------- - name: Optional[:class:`str`] - The new name to change the group to. - Could be ``None`` to remove the name. - icon: Optional[:class:`bytes`] - A :term:`py:bytes-like object` representing the new icon. - Could be ``None`` to remove the icon. + obj: :class:`User` + The user to check permissions for. This parameter is ignored + but kept for compatibility with other ``permissions_for`` methods. - Raises - ------- - HTTPException - Editing the group failed. + Returns + -------- + :class:`Permissions` + The resolved permissions. """ - try: - icon_bytes = fields['icon'] - except KeyError: - pass - else: - if icon_bytes is not None: - fields['icon'] = utils._bytes_to_base64_data(icon_bytes) + return Permissions.none() - data = await self._state.http.edit_group(self.id, **fields) - self._update_group(data) + @property + def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the channel. - async def leave(self): - """|coro| + .. versionadded:: 2.5 + """ + return f'<#{self.id}>' - Leave the group. + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. - If you are the only one in the group, this deletes it as well. + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. - Raises - ------- - HTTPException - Leaving the group failed. + Parameters + ------------ + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + --------- + :class:`PartialMessage` + The partial message. """ - await self._state.http.leave_group(self.id) + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) -def _channel_factory(channel_type): + +def _guild_channel_factory(channel_type: int): value = try_enum(ChannelType, channel_type) if value is ChannelType.text: return TextChannel, value elif value is ChannelType.voice: return VoiceChannel, value - elif value is ChannelType.private: - return DMChannel, value elif value is ChannelType.category: return CategoryChannel, value - elif value is ChannelType.group: - return GroupChannel, value elif value is ChannelType.news: return TextChannel, value - elif value is ChannelType.store: - return StoreChannel, value + elif value is ChannelType.stage_voice: + return StageChannel, value + elif value is ChannelType.forum: + return ForumChannel, value + elif value is ChannelType.media: + return ForumChannel, value else: return None, value + + +def _channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value is ChannelType.private: + return DMChannel, value + elif value is ChannelType.group: + return GroupChannel, value + else: + return cls, value + + +def _threaded_channel_factory(channel_type: int): + cls, value = _channel_factory(channel_type) + if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread): + return Thread, value + return cls, value + + +def _threaded_guild_channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread): + return Thread, value + return cls, value diff --git a/discord/client.py b/discord/client.py index e4804594247b..88c390be0946 100644 --- a/discord/client.py +++ b/discord/client.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,197 +22,353 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import asyncio -from collections import namedtuple +import datetime import logging -import signal -import sys -import traceback +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Callable, + Coroutine, + Dict, + Generator, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + TypedDict, + Union, + overload, +) import aiohttp -import websockets -from .user import User, Profile -from .asset import Asset +from .sku import SKU, Entitlement +from .user import User, ClientUser from .invite import Invite +from .template import Template from .widget import Widget -from .guild import Guild -from .channel import _channel_factory -from .enums import ChannelType -from .member import Member +from .guild import Guild, GuildPreview +from .emoji import Emoji +from .channel import _threaded_channel_factory, PartialMessageable +from .enums import ChannelType, EntitlementOwnerType +from .mentions import AllowedMentions from .errors import * -from .enums import Status, VoiceRegion +from .enums import Status +from .flags import ApplicationFlags, Intents from .gateway import * -from .activity import _ActivityTag, create_activity +from .activity import ActivityTypes, BaseActivity, create_activity from .voice_client import VoiceClient from .http import HTTPClient from .state import ConnectionState from . import utils +from .utils import MISSING, time_snowflake, deprecated from .object import Object from .backoff import ExponentialBackoff from .webhook import Webhook -from .iterators import GuildIterator from .appinfo import AppInfo +from .ui.view import BaseView +from .ui.dynamic import DynamicItem +from .stage_instance import StageInstance +from .threads import Thread +from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory +from .soundboard import SoundboardDefaultSound, SoundboardSound + +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import Self, Unpack + + from .abc import Messageable, PrivateChannel, Snowflake, SnowflakeTime + from .app_commands import Command, ContextMenu + from .automod import AutoModAction, AutoModRule + from .channel import DMChannel, GroupChannel + from .ext.commands import AutoShardedBot, Bot, Context, CommandError + from .guild import GuildChannel + from .integrations import Integration + from .interactions import Interaction + from .member import Member, VoiceState + from .message import Message + from .raw_models import ( + RawAppCommandPermissionsUpdateEvent, + RawBulkMessageDeleteEvent, + RawIntegrationDeleteEvent, + RawMemberRemoveEvent, + RawMessageDeleteEvent, + RawMessageUpdateEvent, + RawReactionActionEvent, + RawReactionClearEmojiEvent, + RawReactionClearEvent, + RawThreadDeleteEvent, + RawThreadMembersUpdate, + RawThreadUpdateEvent, + RawTypingEvent, + RawPollVoteActionEvent, + ) + from .reaction import Reaction + from .role import Role + from .scheduled_event import ScheduledEvent + from .threads import ThreadMember + from .types.guild import Guild as GuildPayload + from .ui.item import Item + from .voice_client import VoiceProtocol + from .audit_logs import AuditLogEntry + from .poll import PollAnswer + from .subscription import Subscription + from .flags import MemberCacheFlags + + class _ClientOptions(TypedDict, total=False): + max_messages: Optional[int] + proxy: Optional[str] + proxy_auth: Optional[aiohttp.BasicAuth] + shard_id: Optional[int] + shard_count: Optional[int] + application_id: int + member_cache_flags: MemberCacheFlags + chunk_guilds_at_startup: bool + status: Optional[Status] + activity: Optional[BaseActivity] + allowed_mentions: Optional[AllowedMentions] + heartbeat_timeout: float + guild_ready_timeout: float + assume_unsync_clock: bool + enable_debug_events: bool + enable_raw_presences: bool + http_trace: aiohttp.TraceConfig + max_ratelimit_timeout: Optional[float] + connector: Optional[aiohttp.BaseConnector] + + +# fmt: off +__all__ = ( + 'Client', +) +# fmt: on + +T = TypeVar('T') +Coro = Coroutine[Any, Any, T] +CoroT = TypeVar('CoroT', bound=Callable[..., Coro[Any]]) + +_log = logging.getLogger(__name__) + + +class _LoopSentinel: + __slots__ = () + + def __getattr__(self, attr: str) -> None: + msg = ( + 'loop attribute cannot be accessed in non-async contexts. ' + 'Consider using either an asynchronous main function and passing it to asyncio.run or ' + 'using asynchronous initialisation hooks such as Client.setup_hook' + ) + raise AttributeError(msg) + + +_loop: Any = _LoopSentinel() -log = logging.getLogger(__name__) - -def _cancel_tasks(loop): - try: - task_retriever = asyncio.Task.all_tasks - except AttributeError: - # future proofing for 3.9 I guess - task_retriever = asyncio.all_tasks - - tasks = {t for t in task_retriever(loop=loop) if not t.done()} - - if not tasks: - return - - log.info('Cleaning up after %d tasks.', len(tasks)) - for task in tasks: - task.cancel() - - loop.run_until_complete(asyncio.gather(*tasks, loop=loop, return_exceptions=True)) - log.info('All tasks finished cancelling.') - - for task in tasks: - if task.cancelled(): - continue - if task.exception() is not None: - loop.call_exception_handler({ - 'message': 'Unhandled exception during Client.run shutdown.', - 'exception': task.exception(), - 'task': task - }) - -def _cleanup_loop(loop): - try: - _cancel_tasks(loop) - if sys.version_info >= (3, 6): - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - log.info('Closing the event loop.') - loop.close() - -class _ClientEventTask(asyncio.Task): - def __init__(self, original_coro, event_name, coro, *, loop): - super().__init__(coro, loop=loop) - self.__event_name = event_name - self.__original_coro = original_coro - - def __repr__(self): - info = [ - ('state', self._state.lower()), - ('event', self.__event_name), - ('coro', repr(self.__original_coro)), - ] - if self._exception is not None: - info.append(('exception', repr(self._exception))) - return ''.format(' '.join('%s=%s' % t for t in info)) class Client: r"""Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. + .. container:: operations + + .. describe:: async with x + + Asynchronously initialises the client and automatically cleans up. + + .. versionadded:: 2.0 + A number of options can be passed to the :class:`Client`. Parameters ----------- max_messages: Optional[:class:`int`] The maximum number of messages to store in the internal message cache. - This defaults to 5000. Passing in ``None`` or a value less than 100 - will use the default instead of the passed in value. - loop: Optional[:class:`asyncio.AbstractEventLoop`] - The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations. - Defaults to ``None``, in which case the default event loop is used via - :func:`asyncio.get_event_loop()`. - connector: :class:`aiohttp.BaseConnector` - The connector to use for connection pooling. + This defaults to ``1000``. Passing in ``None`` disables the message cache. + + .. versionchanged:: 1.3 + Allow disabling the message cache and change the default size to ``1000``. proxy: Optional[:class:`str`] Proxy URL. proxy_auth: Optional[:class:`aiohttp.BasicAuth`] An object that represents proxy HTTP Basic Authorization. shard_id: Optional[:class:`int`] - Integer starting at 0 and less than :attr:`.shard_count`. + Integer starting at ``0`` and less than :attr:`.shard_count`. shard_count: Optional[:class:`int`] The total number of shards. - fetch_offline_members: :class:`bool` - Indicates if :func:`.on_ready` should be delayed to fetch all offline - members from the guilds the bot belongs to. If this is ``False``\, then - no offline members are received and :meth:`request_offline_members` - must be used to fetch the offline members of the guild. + application_id: :class:`int` + The client's application ID. + intents: :class:`Intents` + The intents that you want to enable for the session. This is a way of + disabling and enabling certain gateway events from triggering and being sent. + + .. versionadded:: 1.5 + + .. versionchanged:: 2.0 + Parameter is now required. + member_cache_flags: :class:`MemberCacheFlags` + Allows for finer control over how the library caches members. + If not given, defaults to cache as much as possible with the + currently selected intents. + + .. versionadded:: 1.5 + chunk_guilds_at_startup: :class:`bool` + Indicates if :func:`.on_ready` should be delayed to chunk all guilds + at start-up if necessary. This operation is incredibly slow for large + amounts of guilds. The default is ``True`` if :attr:`Intents.members` + is ``True``. + + .. versionadded:: 1.5 status: Optional[:class:`.Status`] A status to start your presence with upon logging on to Discord. - activity: Optional[Union[:class:`.Activity`, :class:`.Game`, :class:`.Streaming`]] + activity: Optional[:class:`.BaseActivity`] An activity to start your presence with upon logging on to Discord. + allowed_mentions: Optional[:class:`AllowedMentions`] + Control how the client handles mentions by default on every message sent. + + .. versionadded:: 1.4 heartbeat_timeout: :class:`float` The maximum numbers of seconds before timing out and restarting the WebSocket in the case of not receiving a HEARTBEAT_ACK. Useful if processing the initial packets take too long to the point of disconnecting you. The default timeout is 60 seconds. + guild_ready_timeout: :class:`float` + The maximum number of seconds to wait for the GUILD_CREATE stream to end before + preparing the member cache and firing READY. The default timeout is 2 seconds. + + .. versionadded:: 1.4 + assume_unsync_clock: :class:`bool` + Whether to assume the system clock is unsynced. This applies to the ratelimit handling + code. If this is set to ``True``, the default, then the library uses the time to reset + a rate limit bucket given by Discord. If this is ``False`` then your system clock is + used to calculate how long to sleep for. If this is set to ``False`` it is recommended to + sync your system clock to Google's NTP server. + + .. versionadded:: 1.3 + enable_debug_events: :class:`bool` + Whether to enable events that are useful only for debugging gateway related information. + + Right now this involves :func:`on_socket_raw_receive` and :func:`on_socket_raw_send`. If + this is ``False`` then those events will not be dispatched (due to performance considerations). + To enable these events, this must be set to ``True``. Defaults to ``False``. + + .. versionadded:: 2.0 + enable_raw_presences: :class:`bool` + Whether to manually enable or disable the :func:`on_raw_presence_update` event. + + Setting this flag to ``True`` requires :attr:`Intents.presences` to be enabled. + + By default, this flag is set to ``True`` only when :attr:`Intents.presences` is enabled and :attr:`Intents.members` + is disabled, otherwise it's set to ``False``. + + .. versionadded:: 2.5 + http_trace: :class:`aiohttp.TraceConfig` + The trace configuration to use for tracking HTTP requests the library does using ``aiohttp``. + This allows you to check requests the library is using. For more information, check the + `aiohttp documentation `_. + + .. versionadded:: 2.0 + max_ratelimit_timeout: Optional[:class:`float`] + The maximum number of seconds to wait when a non-global rate limit is encountered. + If a request requires sleeping for more than the seconds passed in, then + :exc:`~discord.RateLimited` will be raised. By default, there is no timeout limit. + In order to prevent misuse and unnecessary bans, the minimum value this can be + set to is ``30.0`` seconds. + + .. versionadded:: 2.0 + connector: Optional[:class:`aiohttp.BaseConnector`] + The aiohttp connector to use for this client. This can be used to control underlying aiohttp + behavior, such as setting a dns resolver or sslcontext. + + .. versionadded:: 2.5 Attributes ----------- ws The websocket gateway the client is currently connected to. Could be ``None``. - loop: :class:`asyncio.AbstractEventLoop` - The event loop that the client uses for HTTP requests and websocket operations. """ - def __init__(self, *, loop=None, **options): - self.ws = None - self.loop = asyncio.get_event_loop() if loop is None else loop - self._listeners = {} - self.shard_id = options.get('shard_id') - self.shard_count = options.get('shard_count') - - connector = options.pop('connector', None) - proxy = options.pop('proxy', None) - proxy_auth = options.pop('proxy_auth', None) - self.http = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, loop=self.loop) - - self._handlers = { - 'ready': self._handle_ready + + def __init__(self, *, intents: Intents, **options: Unpack[_ClientOptions]) -> None: + self.loop: asyncio.AbstractEventLoop = _loop + # self.ws is set in the connect method + self.ws: DiscordWebSocket = None # type: ignore + self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} + self.shard_id: Optional[int] = options.get('shard_id') + self.shard_count: Optional[int] = options.get('shard_count') + + connector: Optional[aiohttp.BaseConnector] = options.get('connector', None) + proxy: Optional[str] = options.pop('proxy', None) + proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) + unsync_clock: bool = options.pop('assume_unsync_clock', True) + http_trace: Optional[aiohttp.TraceConfig] = options.pop('http_trace', None) + max_ratelimit_timeout: Optional[float] = options.pop('max_ratelimit_timeout', None) + self.http: HTTPClient = HTTPClient( + self.loop, + connector, + proxy=proxy, + proxy_auth=proxy_auth, + unsync_clock=unsync_clock, + http_trace=http_trace, + max_ratelimit_timeout=max_ratelimit_timeout, + ) + + self._handlers: Dict[str, Callable[..., None]] = { + 'ready': self._handle_ready, } - self._connection = ConnectionState(dispatch=self.dispatch, chunker=self._chunker, handlers=self._handlers, - syncer=self._syncer, http=self.http, loop=self.loop, **options) + self._hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = { + 'before_identify': self._call_before_identify_hook, + } + self._enable_debug_events: bool = options.pop('enable_debug_events', False) + self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options) self._connection.shard_count = self.shard_count - self._closed = False - self._ready = asyncio.Event(loop=self.loop) - self._connection._get_websocket = lambda g: self.ws + self._closing_task: Optional[asyncio.Task[None]] = None + self._ready: asyncio.Event = MISSING + self._application: Optional[AppInfo] = None + self._connection._get_websocket = self._get_websocket + self._connection._get_client = lambda: self if VoiceClient.warn_nacl: VoiceClient.warn_nacl = False - log.warning("PyNaCl is not installed, voice will NOT be supported") + _log.warning('PyNaCl is not installed, voice will NOT be supported') + + async def __aenter__(self) -> Self: + await self._async_setup_hook() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + # This avoids double-calling a user-provided .close() + if self._closing_task: + await self._closing_task + else: + await self.close() # internals - async def _syncer(self, guilds): - await self.ws.request_sync(guilds) - - async def _chunker(self, guild): - try: - guild_id = guild.id - except AttributeError: - guild_id = [s.id for s in guild] - - payload = { - 'op': 8, - 'd': { - 'guild_id': guild_id, - 'query': '', - 'limit': 0 - } - } + def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: + return self.ws - await self.ws.send_as_json(payload) + def _get_state(self, **options: Any) -> ConnectionState[Self]: + return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, hooks=self._hooks, http=self.http, **options) - def _handle_ready(self): + def _handle_ready(self) -> None: self._ready.set() @property - def latency(self): + def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. This could be referred to as the Discord WebSocket protocol latency. @@ -222,32 +376,66 @@ def latency(self): ws = self.ws return float('nan') if not ws else ws.latency + def is_ws_ratelimited(self) -> bool: + """:class:`bool`: Whether the websocket is currently rate limited. + + This can be useful to know when deciding whether you should query members + using HTTP or via the gateway. + + .. versionadded:: 1.6 + """ + if self.ws: + return self.ws.is_ratelimited() + return False + @property - def user(self): - """Optional[:class:`.ClientUser`]: Represents the connected client. None if not logged in.""" + def user(self) -> Optional[ClientUser]: + """Optional[:class:`.ClientUser`]: Represents the connected client. ``None`` if not logged in.""" return self._connection.user @property - def guilds(self): - """List[:class:`.Guild`]: The guilds that the connected client is a member of.""" + def guilds(self) -> Sequence[Guild]: + """Sequence[:class:`.Guild`]: The guilds that the connected client is a member of.""" return self._connection.guilds @property - def emojis(self): - """List[:class:`.Emoji`]: The emojis that the connected client has.""" + def emojis(self) -> Sequence[Emoji]: + """Sequence[:class:`.Emoji`]: The emojis that the connected client has. + + .. note:: + + This does not include the emojis that are owned by the application. + Use :meth:`.fetch_application_emoji` to get those. + """ return self._connection.emojis @property - def cached_messages(self): + def stickers(self) -> Sequence[GuildSticker]: + """Sequence[:class:`.GuildSticker`]: The stickers that the connected client has. + + .. versionadded:: 2.0 + """ + return self._connection.stickers + + @property + def soundboard_sounds(self) -> List[SoundboardSound]: + """List[:class:`.SoundboardSound`]: The soundboard sounds that the connected client has. + + .. versionadded:: 2.5 + """ + return self._connection.soundboard_sounds + + @property + def cached_messages(self) -> Sequence[Message]: """Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached. - .. versionadded:: 1.1.0 + .. versionadded:: 1.1 """ - return utils.SequenceProxy(self._connection._messages) + return utils.SequenceProxy(self._connection._messages or []) @property - def private_channels(self): - """List[:class:`.abc.PrivateChannel`]: The private channels that the connected client is participating on. + def private_channels(self) -> Sequence[PrivateChannel]: + """Sequence[:class:`.abc.PrivateChannel`]: The private channels that the connected client is participating on. .. note:: @@ -257,15 +445,61 @@ def private_channels(self): return self._connection.private_channels @property - def voice_clients(self): - """List[:class:`.VoiceClient`]: Represents a list of voice connections.""" + def voice_clients(self) -> List[VoiceProtocol]: + """List[:class:`.VoiceProtocol`]: Represents a list of voice connections. + + These are usually :class:`.VoiceClient` instances. + """ return self._connection.voice_clients - def is_ready(self): - """Specifies if the client's internal cache is ready for use.""" - return self._ready.is_set() + @property + def application_id(self) -> Optional[int]: + """Optional[:class:`int`]: The client's application ID. + + If this is not passed via ``__init__`` then this is retrieved + through the gateway when an event contains the data or after a call + to :meth:`~discord.Client.login`. Usually after :func:`~discord.on_connect` + is called. + + .. versionadded:: 2.0 + """ + return self._connection.application_id + + @property + def application_flags(self) -> ApplicationFlags: + """:class:`~discord.ApplicationFlags`: The client's application flags. + + .. versionadded:: 2.0 + """ + return self._connection.application_flags + + @property + def application(self) -> Optional[AppInfo]: + """Optional[:class:`~discord.AppInfo`]: The client's application info. + + This is retrieved on :meth:`~discord.Client.login` and is not updated + afterwards. This allows populating the application_id without requiring a + gateway connection. - async def _run_event(self, coro, event_name, *args, **kwargs): + This is ``None`` if accessed before :meth:`~discord.Client.login` is called. + + .. seealso:: The :meth:`~discord.Client.application_info` API call + + .. versionadded:: 2.0 + """ + return self._application + + def is_ready(self) -> bool: + """:class:`bool`: Specifies if the client's internal cache is ready for use.""" + return self._ready is not MISSING and self._ready.is_set() + + async def _run_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> None: try: await coro(*args, **kwargs) except asyncio.CancelledError: @@ -276,13 +510,19 @@ async def _run_event(self, coro, event_name, *args, **kwargs): except asyncio.CancelledError: pass - def _schedule_event(self, coro, event_name, *args, **kwargs): + def _schedule_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) # Schedules the task - return _ClientEventTask(original_coro=coro, event_name=event_name, coro=wrapped, loop=self.loop) + return self.loop.create_task(wrapped, name=f'discord.py: {event_name}') - def dispatch(self, event, *args, **kwargs): - log.debug('Dispatching event %s', event) + def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None: + _log.debug('Dispatching event %s', event) method = 'on_' + event listeners = self._listeners.get(event) @@ -321,112 +561,134 @@ def dispatch(self, event, *args, **kwargs): else: self._schedule_event(coro, method, *args, **kwargs) - async def on_error(self, event_method, *args, **kwargs): + async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None: """|coro| The default error handler provided by the client. - By default this prints to :data:`sys.stderr` however it could be + By default this logs to the library logger however it could be overridden to have a different implementation. Check :func:`~discord.on_error` for more details. + + .. versionchanged:: 2.0 + + ``event_method`` parameter is now positional-only + and instead of writing to ``sys.stderr`` it logs instead. """ - print('Ignoring exception in {}'.format(event_method), file=sys.stderr) - traceback.print_exc() + _log.exception('Ignoring exception in %s', event_method) - async def request_offline_members(self, *guilds): - r"""|coro| + # hooks - Requests previously offline members from the guild to be filled up - into the :attr:`.Guild.members` cache. This function is usually not - called. It should only be used if you have the ``fetch_offline_members`` - parameter set to ``False``. + async def _call_before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None: + # This hook is an internal hook that actually calls the public one. + # It allows the library to have its own hook without stepping on the + # toes of those who need to override their own hook. + await self.before_identify_hook(shard_id, initial=initial) - When the client logs on and connects to the websocket, Discord does - not provide the library with offline members if the number of members - in the guild is larger than 250. You can check if a guild is large - if :attr:`.Guild.large` is ``True``. + async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None: + """|coro| - Parameters - ----------- - \*guilds: :class:`.Guild` - An argument list of guilds to request offline members for. + A hook that is called before IDENTIFYing a session. This is useful + if you wish to have more control over the synchronization of multiple + IDENTIFYing clients. - Raises - ------- - :exc:`.InvalidArgument` - If any guild is unavailable or not large in the collection. + The default implementation sleeps for 5 seconds. + + .. versionadded:: 1.4 + + Parameters + ------------ + shard_id: :class:`int` + The shard ID that requested being IDENTIFY'd + initial: :class:`bool` + Whether this IDENTIFY is the first initial IDENTIFY. """ - if any(not g.large or g.unavailable for g in guilds): - raise InvalidArgument('An unavailable or non-large guild was passed.') - await self._connection.request_offline_members(guilds) + if not initial: + await asyncio.sleep(5.0) - # login state management + async def _async_setup_hook(self) -> None: + # Called whenever the client needs to initialise asyncio objects with a running loop + loop = asyncio.get_running_loop() + self.loop = loop + self.http.loop = loop + self._connection.loop = loop + + self._ready = asyncio.Event() - async def login(self, token, *, bot=True): + async def setup_hook(self) -> None: """|coro| - Logs in the client with the specified credentials. + A coroutine to be called to setup the bot, by default this is blank. - This function can be used in two different ways. + To perform asynchronous setup after the bot is logged in but before + it has connected to the Websocket, overwrite this coroutine. + + This is only called once, in :meth:`login`, and will be called before + any events are dispatched, making it a better solution than doing such + setup in the :func:`~discord.on_ready` event. .. warning:: - Logging on with a user token is against the Discord - `Terms of Service `_ - and doing so might potentially get your account banned. - Use this at your own risk. + Since this is called *before* the websocket connection is made therefore + anything that waits for the websocket will deadlock, this includes things + like :meth:`wait_for` and :meth:`wait_until_ready`. + + .. versionadded:: 2.0 + """ + pass + + # login state management + + async def login(self, token: str) -> None: + """|coro| + + Logs in the client with the specified credentials and + calls the :meth:`setup_hook`. + Parameters ----------- token: :class:`str` The authentication token. Do not prefix this token with anything as the library will do it for you. - bot: :class:`bool` - Keyword argument that specifies if the account logging on is a bot - token or not. Raises ------ - :exc:`.LoginFailure` + LoginFailure The wrong credentials are passed. - :exc:`.HTTPException` + HTTPException An unknown HTTP related error occurred, usually when it isn't 200 or the known incorrect credentials passing status code. """ - log.info('logging in using static token') - await self.http.static_login(token, bot=bot) - self._connection.is_bot = bot + _log.info('logging in using static token') - async def logout(self): - """|coro| + if self.loop is _loop: + await self._async_setup_hook() - Logs out of Discord and closes all connections. + if not isinstance(token, str): + raise TypeError(f'expected token to be a str, received {token.__class__.__name__} instead') + token = token.strip() - .. note:: + data = await self.http.static_login(token) + self._connection.user = ClientUser(state=self._connection, data=data) + self._application = await self.application_info() + if self._connection.application_id is None: + self._connection.application_id = self._application.id - This is just an alias to :meth:`close`. If you want - to do extraneous cleanup when subclassing, it is suggested - to override :meth:`close` instead. - """ - await self.close() + if self._application.interactions_endpoint_url is not None: + _log.warning( + 'Application has an interaction endpoint URL set, this means registered components and app commands will not be received by the library.' + ) - async def _connect(self): - coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id) - self.ws = await asyncio.wait_for(coro, timeout=180.0, loop=self.loop) - while True: - try: - await self.ws.poll_event() - except ResumeWebSocket: - log.info('Got a request to RESUME the websocket.') - self.dispatch('disconnect') - coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id, - sequence=self.ws.sequence, resume=True) - self.ws = await asyncio.wait_for(coro, timeout=180.0, loop=self.loop) + if not self._connection.application_flags: + self._connection.application_flags = self._application.flags + + await self.setup_hook() - async def connect(self, *, reconnect=True): + async def connect(self, *, reconnect: bool = True) -> None: """|coro| Creates a websocket connection and lets the websocket listen @@ -444,26 +706,40 @@ async def connect(self, *, reconnect=True): Raises ------- - :exc:`.GatewayNotFound` + GatewayNotFound If the gateway to connect to Discord is not found. Usually if this is thrown then there is a Discord API outage. - :exc:`.ConnectionClosed` + ConnectionClosed The websocket connection has been terminated. """ backoff = ExponentialBackoff() + ws_params = { + 'initial': True, + 'shard_id': self.shard_id, + } while not self.is_closed(): try: - await self._connect() - except (OSError, - HTTPException, - GatewayNotFound, - ConnectionClosed, - aiohttp.ClientError, - asyncio.TimeoutError, - websockets.InvalidHandshake, - websockets.WebSocketProtocolError) as exc: - + coro = DiscordWebSocket.from_client(self, **ws_params) + self.ws = await asyncio.wait_for(coro, timeout=60.0) + ws_params['initial'] = False + while True: + await self.ws.poll_event() + except ReconnectWebSocket as e: + _log.debug('Got a request to %s the websocket.', e.op) + self.dispatch('disconnect') + ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) + if e.resume: + ws_params['gateway'] = self.ws.gateway + continue + except ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + ) as exc: self.dispatch('disconnect') if not reconnect: await self.close() @@ -475,74 +751,111 @@ async def connect(self, *, reconnect=True): if self.is_closed(): return + # If we get connection reset by peer then try to RESUME + if isinstance(exc, OSError) and exc.errno in (54, 10054): + ws_params.update( + sequence=self.ws.sequence, + gateway=self.ws.gateway, + initial=False, + resume=True, + session=self.ws.session_id, + ) + continue + # We should only get this when an unhandled close code happens, # such as a clean disconnect (1000) or a bad state (bad token, no sharding, etc) # sometimes, discord sends us 1000 for unknown reasons so we should reconnect # regardless and rely on is_closed instead if isinstance(exc, ConnectionClosed): + if exc.code == 4014: + raise PrivilegedIntentsRequired(exc.shard_id) from None if exc.code != 1000: await self.close() raise retry = backoff.delay() - log.exception("Attempting a reconnect in %.2fs", retry) - await asyncio.sleep(retry, loop=self.loop) - - async def close(self): + _log.exception('Attempting a reconnect in %.2fs', retry) + await asyncio.sleep(retry) + # Always try to RESUME the connection + # If the connection is not RESUME-able then the gateway will invalidate the session. + # This is apparently what the official Discord client does. + ws_params.update( + sequence=self.ws.sequence, + gateway=self.ws.gateway, + resume=True, + session=self.ws.session_id, + ) + + async def close(self) -> None: """|coro| Closes the connection to Discord. """ - if self._closed: - return + if self._closing_task: + return await self._closing_task - await self.http.close() - self._closed = True + async def _close(): + await self._connection.close() - for voice in self.voice_clients: - try: - await voice.disconnect() - except Exception: - # if an error happens during disconnects, disregard it. - pass + if self.ws is not None and self.ws.open: + await self.ws.close(code=1000) - if self.ws is not None and self.ws.open: - await self.ws.close() + await self.http.close() - self._ready.clear() + if self._ready is not MISSING: + self._ready.clear() + + self.loop = MISSING - def clear(self): + self._closing_task = asyncio.create_task(_close()) + await self._closing_task + + def clear(self) -> None: """Clears the internal state of the bot. After this, the bot can be considered "re-opened", i.e. :meth:`is_closed` and :meth:`is_ready` both return ``False`` along with the bot's internal cache cleared. """ - self._closed = False + self._closing_task = None self._ready.clear() self._connection.clear() - self.http.recreate() + self.http.clear() - async def start(self, *args, **kwargs): + async def start(self, token: str, *, reconnect: bool = True) -> None: """|coro| A shorthand coroutine for :meth:`login` + :meth:`connect`. + Parameters + ----------- + token: :class:`str` + The authentication token. Do not prefix this token with + anything as the library will do it for you. + reconnect: :class:`bool` + If we should attempt reconnecting, either due to internet + failure or a specific failure on Discord's part. Certain + disconnects that lead to bad state will not be handled (such as + invalid sharding payloads or bad tokens). + Raises ------- TypeError An unexpected keyword argument was received. """ - bot = kwargs.pop('bot', True) - reconnect = kwargs.pop('reconnect', True) - - if kwargs: - raise TypeError("unexpected keyword argument(s) %s" % list(kwargs.keys())) - - await self.login(*args, bot=bot) + await self.login(token) await self.connect(reconnect=reconnect) - def run(self, *args, **kwargs): + def run( + self, + token: str, + *, + reconnect: bool = True, + log_handler: Optional[logging.Handler] = MISSING, + log_formatter: logging.Formatter = MISSING, + log_level: int = MISSING, + root_logger: bool = False, + ) -> None: """A blocking call that abstracts away the event loop initialisation from you. @@ -550,112 +863,312 @@ def run(self, *args, **kwargs): function should not be used. Use :meth:`start` coroutine or :meth:`connect` + :meth:`login`. - Roughly Equivalent to: :: - - try: - loop.run_until_complete(start(*args, **kwargs)) - except KeyboardInterrupt: - loop.run_until_complete(logout()) - # cancel all tasks lingering - finally: - loop.close() + This function also sets up the logging library to make it easier + for beginners to know what is going on with the library. For more + advanced users, this can be disabled by passing ``None`` to + the ``log_handler`` parameter. .. warning:: This function must be the last function to call due to the fact that it is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. - """ - loop = self.loop - try: - loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) - loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) - except NotImplementedError: - pass + Parameters + ----------- + token: :class:`str` + The authentication token. Do not prefix this token with + anything as the library will do it for you. + reconnect: :class:`bool` + If we should attempt reconnecting, either due to internet + failure or a specific failure on Discord's part. Certain + disconnects that lead to bad state will not be handled (such as + invalid sharding payloads or bad tokens). + log_handler: Optional[:class:`logging.Handler`] + The log handler to use for the library's logger. If this is ``None`` + then the library will not set up anything logging related. Logging + will still work if ``None`` is passed, though it is your responsibility + to set it up. + + The default log handler if not provided is :class:`logging.StreamHandler`. + + .. versionadded:: 2.0 + log_formatter: :class:`logging.Formatter` + The formatter to use with the given log handler. If not provided then it + defaults to a colour based logging formatter (if available). + + .. versionadded:: 2.0 + log_level: :class:`int` + The default log level for the library's logger. This is only applied if the + ``log_handler`` parameter is not ``None``. Defaults to ``logging.INFO``. + + .. versionadded:: 2.0 + root_logger: :class:`bool` + Whether to set up the root logger rather than the library logger. + By default, only the library logger (``'discord'``) is set up. If this + is set to ``True`` then the root logger is set up as well. + + Defaults to ``False``. + + .. versionadded:: 2.0 + """ async def runner(): - try: - await self.start(*args, **kwargs) - finally: - await self.close() + async with self: + await self.start(token, reconnect=reconnect) - def stop_loop_on_completion(f): - loop.stop() + if log_handler is not None: + utils.setup_logging( + handler=log_handler, + formatter=log_formatter, + level=log_level, + root=root_logger, + ) - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) try: - loop.run_forever() + asyncio.run(runner()) except KeyboardInterrupt: - log.info('Received signal to terminate bot and event loop.') - finally: - future.remove_done_callback(stop_loop_on_completion) - log.info('Cleaning up tasks.') - _cleanup_loop(loop) - - if not future.cancelled(): - return future.result() + # nothing to do here + # `asyncio.run` handles the loop cleanup + # and `self.start` closes all sockets and the HTTPClient instance. + return # properties - def is_closed(self): - """Indicates if the websocket connection is closed.""" - return self._closed + def is_closed(self) -> bool: + """:class:`bool`: Indicates if the websocket connection is closed.""" + return self._closing_task is not None @property - def activity(self): - """Optional[Union[:class:`.Activity`, :class:`.Game`, :class:`.Streaming`]]: The activity being used upon + def activity(self) -> Optional[ActivityTypes]: + """Optional[:class:`.BaseActivity`]: The activity being used upon logging in. """ - return create_activity(self._connection._activity) + return create_activity(self._connection._activity, self._connection) @activity.setter - def activity(self, value): + def activity(self, value: Optional[ActivityTypes]) -> None: if value is None: self._connection._activity = None - elif isinstance(value, _ActivityTag): - self._connection._activity = value.to_dict() + elif isinstance(value, BaseActivity): + # ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any] + self._connection._activity = value.to_dict() # type: ignore + else: + raise TypeError('activity must derive from BaseActivity.') + + @property + def status(self) -> Status: + """:class:`.Status`: + The status being used upon logging on to Discord. + + .. versionadded: 2.0 + """ + if self._connection._status in set(state.value for state in Status): + return Status(self._connection._status) + return Status.online + + @status.setter + def status(self, value: Status) -> None: + if value is Status.offline: + self._connection._status = 'invisible' + elif isinstance(value, Status): + self._connection._status = str(value) + else: + raise TypeError('status must derive from Status.') + + @property + def allowed_mentions(self) -> Optional[AllowedMentions]: + """Optional[:class:`~discord.AllowedMentions`]: The allowed mention configuration. + + .. versionadded:: 1.4 + """ + return self._connection.allowed_mentions + + @allowed_mentions.setter + def allowed_mentions(self, value: Optional[AllowedMentions]) -> None: + if value is None or isinstance(value, AllowedMentions): + self._connection.allowed_mentions = value else: - raise TypeError('activity must be one of Game, Streaming, or Activity.') + raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__.__name__}') + + @property + def intents(self) -> Intents: + """:class:`~discord.Intents`: The intents configured for this connection. + + .. versionadded:: 1.5 + """ + return self._connection.intents # helpers/getters @property - def users(self): + def users(self) -> List[User]: """List[:class:`~discord.User`]: Returns a list of all the users the bot can see.""" return list(self._connection._users.values()) - def get_channel(self, id): - """Optional[Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`]]: Returns a channel with the - given ID. + def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: + """Returns a channel or thread with the given ID. + + .. versionchanged:: 2.0 + + ``id`` parameter is now positional-only. + + Parameters + ----------- + id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]] + The returned channel or ``None`` if not found. + """ + return self._connection.get_channel(id) # type: ignore # The cache contains all channel types + + def get_partial_messageable( + self, id: int, *, guild_id: Optional[int] = None, type: Optional[ChannelType] = None + ) -> PartialMessageable: + """Returns a partial messageable with the given channel ID. + + This is useful if you have a channel_id but don't want to do an API call + to send messages to it. + + .. versionadded:: 2.0 + + Parameters + ----------- + id: :class:`int` + The channel ID to create a partial messageable for. + guild_id: Optional[:class:`int`] + The optional guild ID to create a partial messageable for. + + This is not required to actually send messages, but it does allow the + :meth:`~discord.PartialMessageable.jump_url` and + :attr:`~discord.PartialMessageable.guild` properties to function properly. + type: Optional[:class:`.ChannelType`] + The underlying channel type for the partial messageable. + + Returns + -------- + :class:`.PartialMessageable` + The partial messageable + """ + return PartialMessageable(state=self._connection, id=id, guild_id=guild_id, type=type) + + def get_stage_instance(self, id: int, /) -> Optional[StageInstance]: + """Returns a stage instance with the given stage channel ID. + + .. versionadded:: 2.0 - If not found, returns ``None``. + Parameters + ----------- + id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[:class:`.StageInstance`] + The stage instance or ``None`` if not found. """ - return self._connection.get_channel(id) + from .channel import StageChannel + + channel = self._connection.get_channel(id) + + if isinstance(channel, StageChannel): + return channel.instance + + def get_guild(self, id: int, /) -> Optional[Guild]: + """Returns a guild with the given ID. - def get_guild(self, id): - """Optional[:class:`.Guild`]: Returns a guild with the given ID. + .. versionchanged:: 2.0 + + ``id`` parameter is now positional-only. + + Parameters + ----------- + id: :class:`int` + The ID to search for. - If not found, returns ``None``. + Returns + -------- + Optional[:class:`.Guild`] + The guild or ``None`` if not found. """ return self._connection._get_guild(id) - def get_user(self, id): - """Optional[:class:`~discord.User`]: Returns a user with the given ID. + def get_user(self, id: int, /) -> Optional[User]: + """Returns a user with the given ID. + + .. versionchanged:: 2.0 + + ``id`` parameter is now positional-only. - If not found, returns ``None``. + Parameters + ----------- + id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[:class:`~discord.User`] + The user or ``None`` if not found. """ return self._connection.get_user(id) - def get_emoji(self, id): - """Optional[:class:`.Emoji`]: Returns an emoji with the given ID. + def get_emoji(self, id: int, /) -> Optional[Emoji]: + """Returns an emoji with the given ID. + + .. versionchanged:: 2.0 + + ``id`` parameter is now positional-only. + + Parameters + ----------- + id: :class:`int` + The ID to search for. - If not found, returns ``None``. + Returns + -------- + Optional[:class:`.Emoji`] + The custom emoji or ``None`` if not found. """ return self._connection.get_emoji(id) - def get_all_channels(self): + def get_sticker(self, id: int, /) -> Optional[GuildSticker]: + """Returns a guild sticker with the given ID. + + .. versionadded:: 2.0 + + .. note:: + + To retrieve standard stickers, use :meth:`.fetch_sticker`. + or :meth:`.fetch_premium_sticker_packs`. + + Returns + -------- + Optional[:class:`.GuildSticker`] + The sticker or ``None`` if not found. + """ + return self._connection.get_sticker(id) + + def get_soundboard_sound(self, id: int, /) -> Optional[SoundboardSound]: + """Returns a soundboard sound with the given ID. + + .. versionadded:: 2.5 + + Parameters + ---------- + id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[:class:`.SoundboardSound`] + The soundboard sound or ``None`` if not found. + """ + return self._connection.get_soundboard_sound(id) + + def get_all_channels(self) -> Generator[GuildChannel, None, None]: """A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'. This is equivalent to: :: @@ -669,13 +1182,17 @@ def get_all_channels(self): Just because you receive a :class:`.abc.GuildChannel` does not mean that you can communicate in said channel. :meth:`.abc.GuildChannel.permissions_for` should be used for that. + + Yields + ------ + :class:`.abc.GuildChannel` + A channel the client can 'access'. """ for guild in self.guilds: - for channel in guild.channels: - yield channel + yield from guild.channels - def get_all_members(self): + def get_all_members(self) -> Generator[Member, None, None]: """Returns a generator with every :class:`.Member` the client can see. This is equivalent to: :: @@ -683,21 +1200,727 @@ def get_all_members(self): for guild in client.guilds: for member in guild.members: yield member + + Yields + ------ + :class:`.Member` + A member the client can see. """ for guild in self.guilds: - for member in guild.members: - yield member + yield from guild.members # listeners/waiters - async def wait_until_ready(self): + async def wait_until_ready(self) -> None: """|coro| Waits until the client's internal cache is all ready. - """ - await self._ready.wait() - def wait_for(self, event, *, check=None, timeout=None): + .. warning:: + + Calling this inside :meth:`setup_hook` can lead to a deadlock. + """ + if self._ready is not MISSING: + await self._ready.wait() + else: + raise RuntimeError( + 'Client has not been properly initialised. ' + 'Please use the login method or asynchronous context manager before calling this method' + ) + + # App Commands + + @overload + async def wait_for( + self, + event: Literal['raw_app_command_permissions_update'], + /, + *, + check: Optional[Callable[[RawAppCommandPermissionsUpdateEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawAppCommandPermissionsUpdateEvent: ... + + @overload + async def wait_for( + self, + event: Literal['app_command_completion'], + /, + *, + check: Optional[Callable[[Interaction[Self], Union[Command[Any, ..., Any], ContextMenu]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Interaction[Self], Union[Command[Any, ..., Any], ContextMenu]]: ... + + # AutoMod + + @overload + async def wait_for( + self, + event: Literal['automod_rule_create', 'automod_rule_update', 'automod_rule_delete'], + /, + *, + check: Optional[Callable[[AutoModRule], bool]] = ..., + timeout: Optional[float] = ..., + ) -> AutoModRule: ... + + @overload + async def wait_for( + self, + event: Literal['automod_action'], + /, + *, + check: Optional[Callable[[AutoModAction], bool]] = ..., + timeout: Optional[float] = ..., + ) -> AutoModAction: ... + + # Channels + + @overload + async def wait_for( + self, + event: Literal['private_channel_update'], + /, + *, + check: Optional[Callable[[GroupChannel, GroupChannel], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[GroupChannel, GroupChannel]: ... + + @overload + async def wait_for( + self, + event: Literal['private_channel_pins_update'], + /, + *, + check: Optional[Callable[[PrivateChannel, datetime.datetime], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[PrivateChannel, datetime.datetime]: ... + + @overload + async def wait_for( + self, + event: Literal['guild_channel_delete', 'guild_channel_create'], + /, + *, + check: Optional[Callable[[GuildChannel], bool]] = ..., + timeout: Optional[float] = ..., + ) -> GuildChannel: ... + + @overload + async def wait_for( + self, + event: Literal['guild_channel_update'], + /, + *, + check: Optional[Callable[[GuildChannel, GuildChannel], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[GuildChannel, GuildChannel]: ... + + @overload + async def wait_for( + self, + event: Literal['guild_channel_pins_update'], + /, + *, + check: Optional[ + Callable[ + [Union[GuildChannel, Thread], Optional[datetime.datetime]], + bool, + ] + ], + timeout: Optional[float] = ..., + ) -> Tuple[Union[GuildChannel, Thread], Optional[datetime.datetime]]: ... + + @overload + async def wait_for( + self, + event: Literal['typing'], + /, + *, + check: Optional[Callable[[Messageable, Union[User, Member], datetime.datetime], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Messageable, Union[User, Member], datetime.datetime]: ... + + @overload + async def wait_for( + self, + event: Literal['raw_typing'], + /, + *, + check: Optional[Callable[[RawTypingEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawTypingEvent: ... + + # Debug & Gateway events + + @overload + async def wait_for( + self, + event: Literal['connect', 'disconnect', 'ready', 'resumed'], + /, + *, + check: Optional[Callable[[], bool]] = ..., + timeout: Optional[float] = ..., + ) -> None: ... + + @overload + async def wait_for( + self, + event: Literal['shard_connect', 'shard_disconnect', 'shard_ready', 'shard_resumed'], + /, + *, + check: Optional[Callable[[int], bool]] = ..., + timeout: Optional[float] = ..., + ) -> int: ... + + @overload + async def wait_for( + self, + event: Literal['socket_event_type', 'socket_raw_receive'], + /, + *, + check: Optional[Callable[[str], bool]] = ..., + timeout: Optional[float] = ..., + ) -> str: ... + + @overload + async def wait_for( + self, + event: Literal['socket_raw_send'], + /, + *, + check: Optional[Callable[[Union[str, bytes]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Union[str, bytes]: ... + + # Entitlements + @overload + async def wait_for( + self, + event: Literal['entitlement_create', 'entitlement_update', 'entitlement_delete'], + /, + *, + check: Optional[Callable[[Entitlement], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Entitlement: ... + + # Guilds + + @overload + async def wait_for( + self, + event: Literal[ + 'guild_available', + 'guild_unavailable', + 'guild_join', + 'guild_remove', + ], + /, + *, + check: Optional[Callable[[Guild], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Guild: ... + + @overload + async def wait_for( + self, + event: Literal['guild_update'], + /, + *, + check: Optional[Callable[[Guild, Guild], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Guild, Guild]: ... + + @overload + async def wait_for( + self, + event: Literal['guild_emojis_update'], + /, + *, + check: Optional[Callable[[Guild, Sequence[Emoji], Sequence[Emoji]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Guild, Sequence[Emoji], Sequence[Emoji]]: ... + + @overload + async def wait_for( + self, + event: Literal['guild_stickers_update'], + /, + *, + check: Optional[Callable[[Guild, Sequence[GuildSticker], Sequence[GuildSticker]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Guild, Sequence[GuildSticker], Sequence[GuildSticker]]: ... + + @overload + async def wait_for( + self, + event: Literal['invite_create', 'invite_delete'], + /, + *, + check: Optional[Callable[[Invite], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Invite: ... + + @overload + async def wait_for( + self, + event: Literal['audit_log_entry_create'], + /, + *, + check: Optional[Callable[[AuditLogEntry], bool]] = ..., + timeout: Optional[float] = ..., + ) -> AuditLogEntry: ... + + # Integrations + + @overload + async def wait_for( + self, + event: Literal['integration_create', 'integration_update'], + /, + *, + check: Optional[Callable[[Integration], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Integration: ... + + @overload + async def wait_for( + self, + event: Literal['guild_integrations_update'], + /, + *, + check: Optional[Callable[[Guild], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Guild: ... + + @overload + async def wait_for( + self, + event: Literal['webhooks_update'], + /, + *, + check: Optional[Callable[[GuildChannel], bool]] = ..., + timeout: Optional[float] = ..., + ) -> GuildChannel: ... + + @overload + async def wait_for( + self, + event: Literal['raw_integration_delete'], + /, + *, + check: Optional[Callable[[RawIntegrationDeleteEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawIntegrationDeleteEvent: ... + + # Interactions + + @overload + async def wait_for( + self, + event: Literal['interaction'], + /, + *, + check: Optional[Callable[[Interaction[Self]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Interaction[Self]: ... + + # Members + + @overload + async def wait_for( + self, + event: Literal['member_join', 'member_remove'], + /, + *, + check: Optional[Callable[[Member], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Member: ... + + @overload + async def wait_for( + self, + event: Literal['raw_member_remove'], + /, + *, + check: Optional[Callable[[RawMemberRemoveEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawMemberRemoveEvent: ... + + @overload + async def wait_for( + self, + event: Literal['member_update', 'presence_update'], + /, + *, + check: Optional[Callable[[Member, Member], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Member, Member]: ... + + @overload + async def wait_for( + self, + event: Literal['user_update'], + /, + *, + check: Optional[Callable[[User, User], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[User, User]: ... + + @overload + async def wait_for( + self, + event: Literal['member_ban'], + /, + *, + check: Optional[Callable[[Guild, Union[User, Member]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Guild, Union[User, Member]]: ... + + @overload + async def wait_for( + self, + event: Literal['member_unban'], + /, + *, + check: Optional[Callable[[Guild, User], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Guild, User]: ... + + # Messages + + @overload + async def wait_for( + self, + event: Literal['message', 'message_delete'], + /, + *, + check: Optional[Callable[[Message], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Message: ... + + @overload + async def wait_for( + self, + event: Literal['message_edit'], + /, + *, + check: Optional[Callable[[Message, Message], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Message, Message]: ... + + @overload + async def wait_for( + self, + event: Literal['bulk_message_delete'], + /, + *, + check: Optional[Callable[[List[Message]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> List[Message]: ... + + @overload + async def wait_for( + self, + event: Literal['raw_message_edit'], + /, + *, + check: Optional[Callable[[RawMessageUpdateEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawMessageUpdateEvent: ... + + @overload + async def wait_for( + self, + event: Literal['raw_message_delete'], + /, + *, + check: Optional[Callable[[RawMessageDeleteEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawMessageDeleteEvent: ... + + @overload + async def wait_for( + self, + event: Literal['raw_bulk_message_delete'], + /, + *, + check: Optional[Callable[[RawBulkMessageDeleteEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawBulkMessageDeleteEvent: ... + + # Reactions + + @overload + async def wait_for( + self, + event: Literal['reaction_add', 'reaction_remove'], + /, + *, + check: Optional[Callable[[Reaction, Union[Member, User]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Reaction, Union[Member, User]]: ... + + @overload + async def wait_for( + self, + event: Literal['reaction_clear'], + /, + *, + check: Optional[Callable[[Message, List[Reaction]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Message, List[Reaction]]: ... + + @overload + async def wait_for( + self, + event: Literal['reaction_clear_emoji'], + /, + *, + check: Optional[Callable[[Reaction], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Reaction: ... + + @overload + async def wait_for( + self, + event: Literal['raw_reaction_add', 'raw_reaction_remove'], + /, + *, + check: Optional[Callable[[RawReactionActionEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawReactionActionEvent: ... + + @overload + async def wait_for( + self, + event: Literal['raw_reaction_clear'], + /, + *, + check: Optional[Callable[[RawReactionClearEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawReactionClearEvent: ... + + @overload + async def wait_for( + self, + event: Literal['raw_reaction_clear_emoji'], + /, + *, + check: Optional[Callable[[RawReactionClearEmojiEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawReactionClearEmojiEvent: ... + + # Roles + + @overload + async def wait_for( + self, + event: Literal['guild_role_create', 'guild_role_delete'], + /, + *, + check: Optional[Callable[[Role], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Role: ... + + @overload + async def wait_for( + self, + event: Literal['guild_role_update'], + /, + *, + check: Optional[Callable[[Role, Role], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Role, Role]: ... + + # Scheduled Events + + @overload + async def wait_for( + self, + event: Literal['scheduled_event_create', 'scheduled_event_delete'], + /, + *, + check: Optional[Callable[[ScheduledEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> ScheduledEvent: ... + + @overload + async def wait_for( + self, + event: Literal['scheduled_event_user_add', 'scheduled_event_user_remove'], + /, + *, + check: Optional[Callable[[ScheduledEvent, User], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[ScheduledEvent, User]: ... + + # Stages + + @overload + async def wait_for( + self, + event: Literal['stage_instance_create', 'stage_instance_delete'], + /, + *, + check: Optional[Callable[[StageInstance], bool]] = ..., + timeout: Optional[float] = ..., + ) -> StageInstance: ... + + @overload + async def wait_for( + self, + event: Literal['stage_instance_update'], + /, + *, + check: Optional[Callable[[StageInstance, StageInstance], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Coroutine[Any, Any, Tuple[StageInstance, StageInstance]]: ... + + # Subscriptions + @overload + async def wait_for( + self, + event: Literal['subscription_create', 'subscription_update', 'subscription_delete'], + /, + *, + check: Optional[Callable[[Subscription], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Subscription: ... + + # Threads + @overload + async def wait_for( + self, + event: Literal['thread_create', 'thread_join', 'thread_remove', 'thread_delete'], + /, + *, + check: Optional[Callable[[Thread], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Thread: ... + + @overload + async def wait_for( + self, + event: Literal['thread_update'], + /, + *, + check: Optional[Callable[[Thread, Thread], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Thread, Thread]: ... + + @overload + async def wait_for( + self, + event: Literal['raw_thread_update'], + /, + *, + check: Optional[Callable[[RawThreadUpdateEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawThreadUpdateEvent: ... + + @overload + async def wait_for( + self, + event: Literal['raw_thread_delete'], + /, + *, + check: Optional[Callable[[RawThreadDeleteEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawThreadDeleteEvent: ... + + @overload + async def wait_for( + self, + event: Literal['thread_member_join', 'thread_member_remove'], + /, + *, + check: Optional[Callable[[ThreadMember], bool]] = ..., + timeout: Optional[float] = ..., + ) -> ThreadMember: ... + + @overload + async def wait_for( + self, + event: Literal['raw_thread_member_remove'], + /, + *, + check: Optional[Callable[[RawThreadMembersUpdate], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawThreadMembersUpdate: ... + + # Voice + + @overload + async def wait_for( + self, + event: Literal['voice_state_update'], + /, + *, + check: Optional[Callable[[Member, VoiceState, VoiceState], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Member, VoiceState, VoiceState]: ... + + # Polls + + @overload + async def wait_for( + self, + event: Literal['poll_vote_add', 'poll_vote_remove'], + /, + *, + check: Optional[Callable[[Union[User, Member], PollAnswer], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Union[User, Member], PollAnswer]: ... + + @overload + async def wait_for( + self, + event: Literal['raw_poll_vote_add', 'raw_poll_vote_remove'], + /, + *, + check: Optional[Callable[[RawPollVoteActionEvent], bool]] = ..., + timeout: Optional[float] = ..., + ) -> RawPollVoteActionEvent: ... + + # Commands + + @overload + async def wait_for( + self: Union[Bot, AutoShardedBot], + event: Literal['command', 'command_completion'], + /, + *, + check: Optional[Callable[[Context[Any]], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Context[Any]: ... + + @overload + async def wait_for( + self: Union[Bot, AutoShardedBot], + event: Literal['command_error'], + /, + *, + check: Optional[Callable[[Context[Any], CommandError], bool]] = ..., + timeout: Optional[float] = ..., + ) -> Tuple[Context[Any], CommandError]: ... + + @overload + async def wait_for( + self, + event: str, + /, + *, + check: Optional[Callable[..., bool]] = ..., + timeout: Optional[float] = ..., + ) -> Any: ... + + def wait_for( + self, + event: str, + /, + *, + check: Optional[Callable[..., bool]] = None, + timeout: Optional[float] = None, + ) -> Coro[Any]: """|coro| Waits for a WebSocket event to be dispatched. @@ -733,7 +1956,7 @@ def check(m): return m.content == 'hello' and m.channel == channel msg = await client.wait_for('message', check=check) - await channel.send('Hello {.author}!'.format(msg)) + await channel.send(f'Hello {msg.author}!') Waiting for a thumbs up reaction from the message author: :: @@ -753,6 +1976,10 @@ def check(reaction, user): else: await channel.send('\N{THUMBS UP SIGN}') + .. versionchanged:: 2.0 + + ``event`` parameter is now positional-only. + Parameters ------------ @@ -781,8 +2008,10 @@ def check(reaction, user): future = self.loop.create_future() if check is None: + def _check(*args): return True + check = _check ev = event.lower() @@ -793,11 +2022,11 @@ def _check(*args): self._listeners[ev] = listeners listeners.append((future, check)) - return asyncio.wait_for(future, timeout, loop=self.loop) + return asyncio.wait_for(future, timeout) # event registration - def event(self, coro): + def event(self, coro: CoroT, /) -> CoroT: """A decorator that registers an event to listen to. You can find more info about the events on the :ref:`documentation below `. @@ -813,6 +2042,10 @@ def event(self, coro): async def on_ready(): print('Ready!') + .. versionchanged:: 2.0 + + ``coro`` parameter is now positional-only. + Raises -------- TypeError @@ -823,18 +2056,19 @@ async def on_ready(): raise TypeError('event registered must be a coroutine function') setattr(self, coro.__name__, coro) - log.debug('%s has successfully been registered as an event', coro.__name__) + _log.debug('%s has successfully been registered as an event', coro.__name__) return coro - async def change_presence(self, *, activity=None, status=None, afk=False): + async def change_presence( + self, + *, + activity: Optional[BaseActivity] = None, + status: Optional[Status] = None, + ) -> None: """|coro| Changes the client's presence. - The activity parameter is a :class:`.Activity` object (not a string) that represents - the activity being done currently. This could also be the slimmed down versions, - :class:`.Game` and :class:`.Streaming`. - Example --------- @@ -843,55 +2077,67 @@ async def change_presence(self, *, activity=None, status=None, afk=False): game = discord.Game("with the API") await client.change_presence(status=discord.Status.idle, activity=game) + .. versionchanged:: 2.0 + Removed the ``afk`` keyword-only parameter. + + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` instead of + ``InvalidArgument``. + Parameters ---------- - activity: Optional[Union[:class:`.Game`, :class:`.Streaming`, :class:`.Activity`]] + activity: Optional[:class:`.BaseActivity`] The activity being done. ``None`` if no currently active activity is done. status: Optional[:class:`.Status`] Indicates what status to change to. If ``None``, then :attr:`.Status.online` is used. - afk: Optional[:class:`bool`] - Indicates if you are going AFK. This allows the discord - client to know how to handle push notifications better - for you in case you are actually idle and not lying. Raises ------ - :exc:`.InvalidArgument` + TypeError If the ``activity`` parameter is not the proper type. """ if status is None: - status = 'online' - status_enum = Status.online + status_str = 'online' + status = Status.online elif status is Status.offline: - status = 'invisible' - status_enum = Status.offline + status_str = 'invisible' + status = Status.offline else: - status_enum = status - status = str(status) + status_str = str(status) - await self.ws.change_presence(activity=activity, status=status, afk=afk) + await self.ws.change_presence(activity=activity, status=status_str) for guild in self._connection.guilds: me = guild.me if me is None: continue - me.activities = (activity,) - me.status = status_enum + if activity is not None: + me.activities = (activity,) # type: ignore # Type checker does not understand the downcast here + else: + me.activities = () + + me.status = status # Guild stuff - def fetch_guilds(self, *, limit=100, before=None, after=None): - """|coro| - - Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. + async def fetch_guilds( + self, + *, + limit: Optional[int] = 200, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + with_counts: bool = True, + ) -> AsyncIterator[Guild]: + """Retrieves an :term:`asynchronous iterator` that enables receiving your guilds. .. note:: Using this, you will only receive :attr:`.Guild.owner`, :attr:`.Guild.icon`, - :attr:`.Guild.id`, and :attr:`.Guild.name` per :class:`.Guild`. + :attr:`.Guild.id`, :attr:`.Guild.name`, :attr:`.Guild.approximate_member_count`, + and :attr:`.Guild.approximate_presence_count` per :class:`.Guild`. .. note:: @@ -907,7 +2153,7 @@ def fetch_guilds(self, *, limit=100, before=None, after=None): Flattening into a list :: - guilds = await client.fetch_guilds(limit=150).flatten() + guilds = [guild async for guild in client.fetch_guilds(limit=150)] # guilds is now a list of Guild... All parameters are optional. @@ -918,17 +2164,30 @@ def fetch_guilds(self, *, limit=100, before=None, after=None): The number of guilds to retrieve. If ``None``, it retrieves every guild you have access to. Note, however, that this would make it a slow operation. - Defaults to 100. + Defaults to ``200``. + + .. versionchanged:: 2.0 + + The default has been changed to 200. + before: Union[:class:`.abc.Snowflake`, :class:`datetime.datetime`] Retrieves guilds before this date or object. - If a date is provided it must be a timezone-naive datetime representing UTC time. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. after: Union[:class:`.abc.Snowflake`, :class:`datetime.datetime`] Retrieve guilds after this date or object. - If a date is provided it must be a timezone-naive datetime representing UTC time. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + with_counts: :class:`bool` + Whether to include count information in the guilds. This fills the + :attr:`.Guild.approximate_member_count` and :attr:`.Guild.approximate_presence_count` + attributes without needing any privileged intents. Defaults to ``True``. + + .. versionadded:: 2.3 Raises ------ - :exc:`.HTTPException` + HTTPException Getting the guilds failed. Yields @@ -936,65 +2195,199 @@ def fetch_guilds(self, *, limit=100, before=None, after=None): :class:`.Guild` The guild with the guild data parsed. """ - return GuildIterator(self, limit=limit, before=before, after=after) - async def fetch_guild(self, guild_id): - """|coro| + async def _before_strategy(retrieve: int, before: Optional[Snowflake], limit: Optional[int]): + before_id = before.id if before else None + data = await self.http.get_guilds(retrieve, before=before_id, with_counts=with_counts) - Retrieves a :class:`.Guild` from an ID. + if data: + if limit is not None: + limit -= len(data) - .. note:: + before = Object(id=int(data[0]['id'])) - Using this, you will **not** receive :attr:`.Guild.channels`, :class:`.Guild.members`, - :attr:`.Member.activity` and :attr:`.Member.voice` per :class:`.Member`. + return data, before, limit - .. note:: + async def _after_strategy(retrieve: int, after: Optional[Snowflake], limit: Optional[int]): + after_id = after.id if after else None + data = await self.http.get_guilds(retrieve, after=after_id, with_counts=with_counts) - This method is an API call. For general usage, consider :meth:`get_guild` instead. + if data: + if limit is not None: + limit -= len(data) - Parameters - ----------- - guild_id: :class:`int` - The guild's ID to fetch from. + after = Object(id=int(data[-1]['id'])) - Raises - ------ - :exc:`.Forbidden` - You do not have access to the guild. - :exc:`.HTTPException` - Getting the guild failed. + return data, after, limit - Returns - -------- - :class:`.Guild` + if isinstance(before, datetime.datetime): + before = Object(id=time_snowflake(before, high=False)) + if isinstance(after, datetime.datetime): + after = Object(id=time_snowflake(after, high=True)) + + predicate: Optional[Callable[[GuildPayload], bool]] = None + strategy, state = _after_strategy, after + + if before: + strategy, state = _before_strategy, before + + if before and after: + predicate = lambda m: int(m['id']) > after.id + + while True: + retrieve = 200 if limit is None else min(limit, 200) + if retrieve < 1: + return + + data, state, limit = await strategy(retrieve, state, limit) + + if predicate: + data = filter(predicate, data) + + count = 0 + + for count, raw_guild in enumerate(data, 1): + yield Guild(state=self._connection, data=raw_guild) + + if count < 200: + # There's no data left after this + break + + async def fetch_template(self, code: Union[Template, str]) -> Template: + """|coro| + + Gets a :class:`.Template` from a discord.new URL or code. + + Parameters + ----------- + code: Union[:class:`.Template`, :class:`str`] + The Discord Template Code or URL (must be a discord.new URL). + + Raises + ------- + NotFound + The template is invalid. + HTTPException + Getting the template failed. + + Returns + -------- + :class:`.Template` + The template from the URL/code. + """ + code = utils.resolve_template(code) + data = await self.http.get_template(code) + return Template(data=data, state=self._connection) + + async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild: + """|coro| + + Retrieves a :class:`.Guild` from an ID. + + .. note:: + + Using this, you will **not** receive :attr:`.Guild.channels`, :attr:`.Guild.members`, + :attr:`.Member.activity` and :attr:`.Member.voice` per :class:`.Member`. + + .. note:: + + This method is an API call. For general usage, consider :meth:`get_guild` instead. + + .. versionchanged:: 2.0 + + ``guild_id`` parameter is now positional-only. + + + Parameters + ----------- + guild_id: :class:`int` + The guild's ID to fetch from. + with_counts: :class:`bool` + Whether to include count information in the guild. This fills the + :attr:`.Guild.approximate_member_count` and :attr:`.Guild.approximate_presence_count` + attributes without needing any privileged intents. Defaults to ``True``. + + .. versionadded:: 2.0 + + Raises + ------ + NotFound + The guild doesn't exist or you got no access to it. + HTTPException + Getting the guild failed. + + Returns + -------- + :class:`.Guild` The guild from the ID. """ - data = await self.http.get_guild(guild_id) + data = await self.http.get_guild(guild_id, with_counts=with_counts) return Guild(data=data, state=self._connection) - async def create_guild(self, name, region=None, icon=None): + async def fetch_guild_preview(self, guild_id: int) -> GuildPreview: + """|coro| + + Retrieves a preview of a :class:`.Guild` from an ID. If the guild is discoverable, + you don't have to be a member of it. + + .. versionadded:: 2.5 + + Raises + ------ + NotFound + The guild doesn't exist, or is not discoverable and you are not in it. + HTTPException + Getting the guild failed. + + Returns + -------- + :class:`.GuildPreview` + The guild preview from the ID. + """ + data = await self.http.get_guild_preview(guild_id) + return GuildPreview(data=data, state=self._connection) + + @deprecated() + async def create_guild( + self, + *, + name: str, + icon: bytes = MISSING, + code: str = MISSING, + ) -> Guild: """|coro| Creates a :class:`.Guild`. Bot accounts in more than 10 guilds are not allowed to create guilds. + .. versionchanged:: 2.0 + ``name`` and ``icon`` parameters are now keyword-only. The ``region`` parameter has been removed. + + .. versionchanged:: 2.0 + This function will now raise :exc:`ValueError` instead of + ``InvalidArgument``. + + .. deprecated:: 2.6 + This function is deprecated and will be removed in a future version. + Parameters ---------- name: :class:`str` The name of the guild. - region: :class:`.VoiceRegion` - The region for the voice communication server. - Defaults to :attr:`.VoiceRegion.us_west`. - icon: :class:`bytes` + icon: Optional[:class:`bytes`] The :term:`py:bytes-like object` representing the icon. See :meth:`.ClientUser.edit` for more details on what is expected. + code: :class:`str` + The code for a template to create the guild with. + + .. versionadded:: 1.4 Raises ------ - :exc:`.HTTPException` + HTTPException Guild creation failed. - :exc:`.InvalidArgument` + ValueError Invalid icon image format given. Must be PNG or JPG. Returns @@ -1003,20 +2396,56 @@ async def create_guild(self, name, region=None, icon=None): The guild created. This is not the same guild that is added to cache. """ - if icon is not None: - icon = utils._bytes_to_base64_data(icon) - - if region is None: - region = VoiceRegion.us_west.value + if icon is not MISSING: + icon_base64 = utils._bytes_to_base64_data(icon) else: - region = region.value + icon_base64 = None - data = await self.http.create_guild(name, region, icon) + if code: + data = await self.http.create_from_template(code, name, icon_base64) + else: + data = await self.http.create_guild(name, icon_base64) return Guild(data=data, state=self._connection) + async def fetch_stage_instance(self, channel_id: int, /) -> StageInstance: + """|coro| + + Gets a :class:`.StageInstance` for a stage channel id. + + .. versionadded:: 2.0 + + Parameters + ----------- + channel_id: :class:`int` + The stage channel ID. + + Raises + ------- + NotFound + The stage instance or channel could not be found. + HTTPException + Getting the stage instance failed. + + Returns + -------- + :class:`.StageInstance` + The stage instance from the stage channel ID. + """ + data = await self.http.get_stage_instance(channel_id) + guild = self.get_guild(int(data['guild_id'])) + # Guild can technically be None here but this is being explicitly silenced right now. + return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore + # Invite management - async def fetch_invite(self, url, *, with_counts=True): + async def fetch_invite( + self, + url: Union[Invite, str], + *, + with_counts: bool = True, + with_expiration: bool = True, + scheduled_event_id: Optional[int] = None, + ) -> Invite: """|coro| Gets an :class:`.Invite` from a discord.gg URL or ID. @@ -1025,22 +2454,41 @@ async def fetch_invite(self, url, *, with_counts=True): If the invite is for a guild you have not joined, the guild and channel attributes of the returned :class:`.Invite` will be :class:`.PartialInviteGuild` and - :class:`PartialInviteChannel` respectively. + :class:`.PartialInviteChannel` respectively. Parameters ----------- - url: :class:`str` + url: Union[:class:`.Invite`, :class:`str`] The Discord invite ID or URL (must be a discord.gg URL). with_counts: :class:`bool` Whether to include count information in the invite. This fills the :attr:`.Invite.approximate_member_count` and :attr:`.Invite.approximate_presence_count` fields. + with_expiration: :class:`bool` + Whether to include the expiration date of the invite. This fills the + :attr:`.Invite.expires_at` field. + + .. versionadded:: 2.0 + .. deprecated:: 2.6 + This parameter is deprecated and will be removed in a future version as it is no + longer needed to fill the :attr:`.Invite.expires_at` field. + scheduled_event_id: Optional[:class:`int`] + The ID of the scheduled event this invite is for. + + .. note:: + + It is not possible to provide a url that contains an ``event_id`` parameter + when using this parameter. + + .. versionadded:: 2.0 Raises ------- - :exc:`.NotFound` + ValueError + The url contains an ``event_id``, but ``scheduled_event_id`` has also been provided. + NotFound The invite has expired or is invalid. - :exc:`.HTTPException` + HTTPException Getting the invite failed. Returns @@ -1049,39 +2497,56 @@ async def fetch_invite(self, url, *, with_counts=True): The invite from the URL/ID. """ - invite_id = utils.resolve_invite(url) - data = await self.http.get_invite(invite_id, with_counts=with_counts) + resolved = utils.resolve_invite(url) + + if scheduled_event_id and resolved.event: + raise ValueError('Cannot specify scheduled_event_id and contain an event_id in the url.') + + scheduled_event_id = scheduled_event_id or resolved.event + + data = await self.http.get_invite( + resolved.code, + with_counts=with_counts, + guild_scheduled_event_id=scheduled_event_id, + ) return Invite.from_incomplete(state=self._connection, data=data) - async def delete_invite(self, invite): + async def delete_invite(self, invite: Union[Invite, str], /, *, reason: Optional[str] = None) -> Invite: """|coro| Revokes an :class:`.Invite`, URL, or ID to an invite. - You must have the :attr:`~.Permissions.manage_channels` permission in + You must have :attr:`~.Permissions.manage_channels` in the associated guild to do this. + .. versionchanged:: 2.0 + + ``invite`` parameter is now positional-only. + Parameters ---------- invite: Union[:class:`.Invite`, :class:`str`] The invite to revoke. + reason: Optional[:class:`str`] + The reason for deleting the invite. Shows up on the audit log. Raises ------- - :exc:`.Forbidden` + Forbidden You do not have permissions to revoke invites. - :exc:`.NotFound` + NotFound The invite is invalid or expired. - :exc:`.HTTPException` + HTTPException Revoking the invite failed. """ - invite_id = utils.resolve_invite(invite) - await self.http.delete_invite(invite_id) + resolved = utils.resolve_invite(invite) + data = await self.http.delete_invite(resolved.code, reason=reason) + return Invite.from_incomplete(state=self._connection, data=data) # Miscellaneous stuff - async def fetch_widget(self, guild_id): + async def fetch_widget(self, guild_id: int, /) -> Widget: """|coro| Gets a :class:`.Widget` from a guild ID. @@ -1090,6 +2555,10 @@ async def fetch_widget(self, guild_id): The guild must have the widget enabled to get this information. + .. versionchanged:: 2.0 + + ``guild_id`` parameter is now positional-only. + Parameters ----------- guild_id: :class:`int` @@ -1097,9 +2566,9 @@ async def fetch_widget(self, guild_id): Raises ------- - :exc:`.Forbidden` + Forbidden The widget for this guild is disabled. - :exc:`.HTTPException` + HTTPException Retrieving the widget failed. Returns @@ -1111,14 +2580,14 @@ async def fetch_widget(self, guild_id): return Widget(state=self._connection, data=data) - async def application_info(self): + async def application_info(self) -> AppInfo: """|coro| Retrieves the bot's application information. Raises ------- - :exc:`.HTTPException` + HTTPException Retrieving the information failed somehow. Returns @@ -1127,21 +2596,22 @@ async def application_info(self): The bot's application information. """ data = await self.http.application_info() - if 'rpc_origins' not in data: - data['rpc_origins'] = None return AppInfo(self._connection, data) - async def fetch_user(self, user_id): + async def fetch_user(self, user_id: int, /) -> User: """|coro| - Retrieves a :class:`~discord.User` based on their ID. This can only - be used by bot accounts. You do not have to share any guilds - with the user to get this information, however many operations - do require that you do. + Retrieves a :class:`~discord.User` based on their ID. + You do not have to share any guilds with the user to get this information, + however many operations do require that you do. .. note:: - This method is an API call. For general usage, consider :meth:`get_user` instead. + This method is an API call. If you have :attr:`discord.Intents.members` and member cache enabled, consider :meth:`get_user` instead. + + .. versionchanged:: 2.0 + + ``user_id`` parameter is now positional-only. Parameters ----------- @@ -1150,9 +2620,9 @@ async def fetch_user(self, user_id): Raises ------- - :exc:`.NotFound` + NotFound A user with this ID does not exist. - :exc:`.HTTPException` + HTTPException Fetching the user failed. Returns @@ -1163,98 +2633,71 @@ async def fetch_user(self, user_id): data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_user_profile(self, user_id): + async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]: """|coro| - Gets an arbitrary user's profile. This can only be used by non-bot accounts. - - Parameters - ------------ - user_id: :class:`int` - The ID of the user to fetch their profile for. - - Raises - ------- - :exc:`.Forbidden` - Not allowed to fetch profiles. - :exc:`.HTTPException` - Fetching the profile failed. - - Returns - -------- - :class:`.Profile` - The profile of the user. - """ - - state = self._connection - data = await self.http.get_user_profile(user_id) - - def transform(d): - return state._get_guild(int(d['id'])) - - since = data.get('premium_since') - mutual_guilds = list(filter(None, map(transform, data.get('mutual_guilds', [])))) - user = data['user'] - return Profile(flags=user.get('flags', 0), - premium_since=utils.parse_time(since), - mutual_guilds=mutual_guilds, - user=User(data=user, state=state), - connected_accounts=data['connected_accounts']) - - async def fetch_channel(self, channel_id): - """|coro| - - Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID. + Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. .. note:: This method is an API call. For general usage, consider :meth:`get_channel` instead. - .. versionadded:: 1.2.0 + .. versionadded:: 1.2 + + .. versionchanged:: 2.0 + + ``channel_id`` parameter is now positional-only. Raises ------- - :exc:`.InvalidData` + InvalidData An unknown channel type was received from Discord. - :exc:`.HTTPException` + HTTPException Retrieving the channel failed. - :exc:`.NotFound` + NotFound Invalid Channel ID. - :exc:`.Forbidden` + Forbidden You do not have permission to fetch this channel. Returns -------- - Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`] + Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, :class:`.Thread`] The channel from the ID. """ data = await self.http.get_channel(channel_id) - factory, ch_type = _channel_factory(data['type']) + factory, ch_type = _threaded_channel_factory(data['type']) if factory is None: raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): - channel = factory(me=self.user, data=data, state=self._connection) + # the factory will be a DMChannel or GroupChannel here + channel = factory(me=self.user, data=data, state=self._connection) # type: ignore else: - guild_id = int(data['guild_id']) - guild = self.get_guild(guild_id) or Object(id=guild_id) - channel = factory(guild=guild, state=self._connection, data=data) + # the factory can't be a DMChannel or GroupChannel here + guild_id = int(data['guild_id']) # type: ignore + guild = self._connection._get_or_create_unavailable_guild(guild_id) + # the factory should be a GuildChannel or Thread + channel = factory(guild=guild, state=self._connection, data=data) # type: ignore return channel - async def fetch_webhook(self, webhook_id): + async def fetch_webhook(self, webhook_id: int, /) -> Webhook: """|coro| Retrieves a :class:`.Webhook` with the specified ID. + .. versionchanged:: 2.0 + + ``webhook_id`` parameter is now positional-only. + Raises -------- - :exc:`.HTTPException` + HTTPException Retrieving the webhook failed. - :exc:`.NotFound` + NotFound Invalid webhook ID. - :exc:`.Forbidden` + Forbidden You do not have permission to fetch this webhook. Returns @@ -1264,3 +2707,551 @@ async def fetch_webhook(self, webhook_id): """ data = await self.http.get_webhook(webhook_id) return Webhook.from_state(data, state=self._connection) + + async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]: + """|coro| + + Retrieves a :class:`.Sticker` with the specified ID. + + .. versionadded:: 2.0 + + Raises + -------- + HTTPException + Retrieving the sticker failed. + NotFound + Invalid sticker ID. + + Returns + -------- + Union[:class:`.StandardSticker`, :class:`.GuildSticker`] + The sticker you requested. + """ + data = await self.http.get_sticker(sticker_id) + cls, _ = _sticker_factory(data['type']) + # The type checker is not smart enough to figure out the constructor is correct + return cls(state=self._connection, data=data) # type: ignore + + async def fetch_skus(self) -> List[SKU]: + """|coro| + + Retrieves the bot's available SKUs. + + .. versionadded:: 2.4 + + Raises + ------- + MissingApplicationID + The application ID could not be found. + HTTPException + Retrieving the SKUs failed. + + Returns + -------- + List[:class:`.SKU`] + The bot's available SKUs. + """ + + if self.application_id is None: + raise MissingApplicationID + + data = await self.http.get_skus(self.application_id) + return [SKU(state=self._connection, data=sku) for sku in data] + + async def fetch_entitlement(self, entitlement_id: int, /) -> Entitlement: + """|coro| + + Retrieves a :class:`.Entitlement` with the specified ID. + + .. versionadded:: 2.4 + + Parameters + ----------- + entitlement_id: :class:`int` + The entitlement's ID to fetch from. + + Raises + ------- + NotFound + An entitlement with this ID does not exist. + MissingApplicationID + The application ID could not be found. + HTTPException + Fetching the entitlement failed. + + Returns + -------- + :class:`.Entitlement` + The entitlement you requested. + """ + + if self.application_id is None: + raise MissingApplicationID + + data = await self.http.get_entitlement(self.application_id, entitlement_id) + return Entitlement(state=self._connection, data=data) + + async def entitlements( + self, + *, + limit: Optional[int] = 100, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + skus: Optional[Sequence[Snowflake]] = None, + user: Optional[Snowflake] = None, + guild: Optional[Snowflake] = None, + exclude_ended: bool = False, + exclude_deleted: bool = True, + ) -> AsyncIterator[Entitlement]: + """Retrieves an :term:`asynchronous iterator` of the :class:`.Entitlement` that applications has. + + .. versionadded:: 2.4 + + Examples + --------- + + Usage :: + + async for entitlement in client.entitlements(limit=100): + print(entitlement.user_id, entitlement.ends_at) + + Flattening into a list :: + + entitlements = [entitlement async for entitlement in client.entitlements(limit=100)] + # entitlements is now a list of Entitlement... + + All parameters are optional. + + Parameters + ----------- + limit: Optional[:class:`int`] + The number of entitlements to retrieve. If ``None``, it retrieves every entitlement for this application. + Note, however, that this would make it a slow operation. Defaults to ``100``. + before: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve entitlements before this date or entitlement. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + after: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]] + Retrieve entitlements after this date or entitlement. + If a datetime is provided, it is recommended to use a UTC aware datetime. + If the datetime is naive, it is assumed to be local time. + skus: Optional[Sequence[:class:`~discord.abc.Snowflake`]] + A list of SKUs to filter by. + user: Optional[:class:`~discord.abc.Snowflake`] + The user to filter by. + guild: Optional[:class:`~discord.abc.Snowflake`] + The guild to filter by. + exclude_ended: :class:`bool` + Whether to exclude ended entitlements. Defaults to ``False``. + exclude_deleted: :class:`bool` + Whether to exclude deleted entitlements. Defaults to ``True``. + + .. versionadded:: 2.5 + + Raises + ------- + MissingApplicationID + The application ID could not be found. + HTTPException + Fetching the entitlements failed. + TypeError + Both ``after`` and ``before`` were provided, as Discord does not + support this type of pagination. + + Yields + -------- + :class:`.Entitlement` + The entitlement with the application. + """ + + if self.application_id is None: + raise MissingApplicationID + + if before is not None and after is not None: + raise TypeError('entitlements pagination does not support both before and after') + + # This endpoint paginates in ascending order. + endpoint = self.http.get_entitlements + + async def _before_strategy(retrieve: int, before: Optional[Snowflake], limit: Optional[int]): + before_id = before.id if before else None + data = await endpoint( + self.application_id, # type: ignore # We already check for None above + limit=retrieve, + before=before_id, + sku_ids=[sku.id for sku in skus] if skus else None, + user_id=user.id if user else None, + guild_id=guild.id if guild else None, + exclude_ended=exclude_ended, + exclude_deleted=exclude_deleted, + ) + + if data: + if limit is not None: + limit -= len(data) + + before = Object(id=int(data[0]['id'])) + + return data, before, limit + + async def _after_strategy(retrieve: int, after: Optional[Snowflake], limit: Optional[int]): + after_id = after.id if after else None + data = await endpoint( + self.application_id, # type: ignore # We already check for None above + limit=retrieve, + after=after_id, + sku_ids=[sku.id for sku in skus] if skus else None, + user_id=user.id if user else None, + guild_id=guild.id if guild else None, + exclude_ended=exclude_ended, + ) + + if data: + if limit is not None: + limit -= len(data) + + after = Object(id=int(data[-1]['id'])) + + return data, after, limit + + if isinstance(before, datetime.datetime): + before = Object(id=utils.time_snowflake(before, high=False)) + if isinstance(after, datetime.datetime): + after = Object(id=utils.time_snowflake(after, high=True)) + + if before: + strategy, state = _before_strategy, before + else: + strategy, state = _after_strategy, after + + while True: + retrieve = 100 if limit is None else min(limit, 100) + if retrieve < 1: + return + + data, state, limit = await strategy(retrieve, state, limit) + + # Terminate loop on next iteration; there's no data left after this + if len(data) < 100: + limit = 0 + + for e in data: + yield Entitlement(self._connection, e) + + async def create_entitlement( + self, + sku: Snowflake, + owner: Snowflake, + owner_type: EntitlementOwnerType, + ) -> None: + """|coro| + + Creates a test :class:`.Entitlement` for the application. + + .. versionadded:: 2.4 + + Parameters + ----------- + sku: :class:`~discord.abc.Snowflake` + The SKU to create the entitlement for. + owner: :class:`~discord.abc.Snowflake` + The ID of the owner. + owner_type: :class:`.EntitlementOwnerType` + The type of the owner. + + Raises + ------- + MissingApplicationID + The application ID could not be found. + NotFound + The SKU or owner could not be found. + HTTPException + Creating the entitlement failed. + """ + + if self.application_id is None: + raise MissingApplicationID + + await self.http.create_entitlement(self.application_id, sku.id, owner.id, owner_type.value) + + async def fetch_premium_sticker_packs(self) -> List[StickerPack]: + """|coro| + + Retrieves all available premium sticker packs. + + .. versionadded:: 2.0 + + Raises + ------- + HTTPException + Retrieving the sticker packs failed. + + Returns + --------- + List[:class:`.StickerPack`] + All available premium sticker packs. + """ + data = await self.http.list_premium_sticker_packs() + return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']] + + async def fetch_premium_sticker_pack(self, sticker_pack_id: int, /) -> StickerPack: + """|coro| + + Retrieves a premium sticker pack with the specified ID. + + .. versionadded:: 2.5 + + Parameters + ---------- + sticker_pack_id: :class:`int` + The sticker pack's ID to fetch from. + + Raises + ------- + NotFound + A sticker pack with this ID does not exist. + HTTPException + Retrieving the sticker pack failed. + + Returns + ------- + :class:`.StickerPack` + The retrieved premium sticker pack. + """ + data = await self.http.get_sticker_pack(sticker_pack_id) + return StickerPack(state=self._connection, data=data) + + async def fetch_soundboard_default_sounds(self) -> List[SoundboardDefaultSound]: + """|coro| + + Retrieves all default soundboard sounds. + + .. versionadded:: 2.5 + + Raises + ------- + HTTPException + Retrieving the default soundboard sounds failed. + + Returns + --------- + List[:class:`.SoundboardDefaultSound`] + All default soundboard sounds. + """ + data = await self.http.get_soundboard_default_sounds() + return [SoundboardDefaultSound(state=self._connection, data=sound) for sound in data] + + async def create_dm(self, user: Snowflake) -> DMChannel: + """|coro| + + Creates a :class:`.DMChannel` with this user. + + This should be rarely called, as this is done transparently for most + people. + + .. versionadded:: 2.0 + + Parameters + ----------- + user: :class:`~discord.abc.Snowflake` + The user to create a DM with. + + Returns + ------- + :class:`.DMChannel` + The channel that was created. + """ + state = self._connection + found = state._get_private_channel_by_user(user.id) + if found: + return found + + data = await state.http.start_private_message(user.id) + return state.add_dm_channel(data) + + def add_dynamic_items(self, *items: Type[DynamicItem[Item[Any]]]) -> None: + r"""Registers :class:`~discord.ui.DynamicItem` classes for persistent listening. + + This method accepts *class types* rather than instances. + + .. versionadded:: 2.4 + + Parameters + ----------- + \*items: Type[:class:`~discord.ui.DynamicItem`] + The classes of dynamic items to add. + + Raises + ------- + TypeError + A class is not a subclass of :class:`~discord.ui.DynamicItem`. + """ + + for item in items: + if not issubclass(item, DynamicItem): + raise TypeError(f'expected subclass of DynamicItem not {item.__name__}') + + self._connection.store_dynamic_items(*items) + + def remove_dynamic_items(self, *items: Type[DynamicItem[Item[Any]]]) -> None: + r"""Removes :class:`~discord.ui.DynamicItem` classes from persistent listening. + + This method accepts *class types* rather than instances. + + .. versionadded:: 2.4 + + Parameters + ----------- + \*items: Type[:class:`~discord.ui.DynamicItem`] + The classes of dynamic items to remove. + + Raises + ------- + TypeError + A class is not a subclass of :class:`~discord.ui.DynamicItem`. + """ + + for item in items: + if not issubclass(item, DynamicItem): + raise TypeError(f'expected subclass of DynamicItem not {item.__name__}') + + self._connection.remove_dynamic_items(*items) + + def add_view(self, view: BaseView, *, message_id: Optional[int] = None) -> None: + """Registers a :class:`~discord.ui.View` for persistent listening. + + This method should be used for when a view is comprised of components + that last longer than the lifecycle of the program. + + .. versionadded:: 2.0 + + Parameters + ------------ + view: Union[:class:`discord.ui.View`, :class:`discord.ui.LayoutView`] + The view to register for dispatching. + message_id: Optional[:class:`int`] + The message ID that the view is attached to. This is currently used to + refresh the view's state during message update events. If not given + then message update events are not propagated for the view. + + Raises + ------- + TypeError + A view was not passed. + ValueError + The view is not persistent or is already finished. A persistent view has no timeout + and all their components have an explicitly provided custom_id. + """ + + if not isinstance(view, BaseView): + raise TypeError(f'expected an instance of View not {view.__class__.__name__}') + + if not view.is_persistent(): + raise ValueError('View is not persistent. Items need to have a custom_id set and View must have no timeout') + + if view.is_finished(): + raise ValueError('View is already finished.') + + self._connection.store_view(view, message_id) + + @property + def persistent_views(self) -> Sequence[BaseView]: + """Sequence[Union[:class:`.View`, :class:`.LayoutView`]]: A sequence of persistent views added to the client. + + .. versionadded:: 2.0 + """ + return self._connection.persistent_views + + async def create_application_emoji( + self, + *, + name: str, + image: bytes, + ) -> Emoji: + """|coro| + + Create an emoji for the current application. + + .. versionadded:: 2.5 + + Parameters + ---------- + name: :class:`str` + The emoji name. Must be between 2 and 32 characters long. + image: :class:`bytes` + The :term:`py:bytes-like object` representing the image data to use. + Only JPG, PNG and GIF images are supported. + + Raises + ------ + MissingApplicationID + The application ID could not be found. + HTTPException + Creating the emoji failed. + + Returns + ------- + :class:`.Emoji` + The emoji that was created. + """ + if self.application_id is None: + raise MissingApplicationID + + img = utils._bytes_to_base64_data(image) + data = await self.http.create_application_emoji(self.application_id, name, img) + return Emoji(guild=Object(0), state=self._connection, data=data) + + async def fetch_application_emoji(self, emoji_id: int, /) -> Emoji: + """|coro| + + Retrieves an emoji for the current application. + + .. versionadded:: 2.5 + + Parameters + ---------- + emoji_id: :class:`int` + The emoji ID to retrieve. + + Raises + ------ + MissingApplicationID + The application ID could not be found. + HTTPException + Retrieving the emoji failed. + + Returns + ------- + :class:`.Emoji` + The emoji requested. + """ + if self.application_id is None: + raise MissingApplicationID + + data = await self.http.get_application_emoji(self.application_id, emoji_id) + return Emoji(guild=Object(0), state=self._connection, data=data) + + async def fetch_application_emojis(self) -> List[Emoji]: + """|coro| + + Retrieves all emojis for the current application. + + .. versionadded:: 2.5 + + Raises + ------- + MissingApplicationID + The application ID could not be found. + HTTPException + Retrieving the emojis failed. + + Returns + ------- + List[:class:`.Emoji`] + The list of emojis for the current application. + """ + if self.application_id is None: + raise MissingApplicationID + + data = await self.http.get_application_emojis(self.application_id) + return [Emoji(guild=Object(0), state=self._connection, data=emoji) for emoji in data['items']] diff --git a/discord/collectible.py b/discord/collectible.py new file mode 100644 index 000000000000..b2ad7e4e01ba --- /dev/null +++ b/discord/collectible.py @@ -0,0 +1,109 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + + +from .asset import Asset +from .enums import NameplatePalette, CollectibleType, try_enum +from .utils import parse_time + + +if TYPE_CHECKING: + from datetime import datetime + + from .state import ConnectionState + from .types.user import ( + Collectible as CollectiblePayload, + ) + + +__all__ = ('Collectible',) + + +class Collectible: + """Represents a user's collectible. + + .. versionadded:: 2.7 + + Attributes + ---------- + label: :class:`str` + The label of the collectible. + palette: Optional[:class:`NameplatePalette`] + The palette of the collectible. + This is only available if ``type`` is + :class:`CollectibleType.nameplate`. + sku_id: :class:`int` + The SKU ID of the collectible. + type: :class:`CollectibleType` + The type of the collectible. + expires_at: Optional[:class:`datetime.datetime`] + The expiration date of the collectible. If applicable. + """ + + __slots__ = ( + 'type', + 'sku_id', + 'label', + 'expires_at', + 'palette', + '_state', + '_asset', + ) + + def __init__(self, *, state: ConnectionState, type: str, data: CollectiblePayload) -> None: + self._state: ConnectionState = state + self.type: CollectibleType = try_enum(CollectibleType, type) + self._asset: str = data['asset'] + self.sku_id: int = int(data['sku_id']) + self.label: str = data['label'] + self.expires_at: Optional[datetime] = parse_time(data.get('expires_at')) + + # nameplate + self.palette: Optional[NameplatePalette] + try: + self.palette = try_enum(NameplatePalette, data['palette']) # type: ignore + except KeyError: + self.palette = None + + @property + def static(self) -> Asset: + """:class:`Asset`: The static asset of the collectible.""" + return Asset._from_user_collectible(self._state, self._asset) + + @property + def animated(self) -> Asset: + """:class:`Asset`: The animated asset of the collectible.""" + return Asset._from_user_collectible(self._state, self._asset, animated=True) + + def __repr__(self) -> str: + attrs = ['sku_id'] + if self.palette: + attrs.append('palette') + + joined_attrs = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in attrs) + return f'<{self.type.name.title()} {joined_attrs}>' diff --git a/discord/colour.py b/discord/colour.py index 1dc9662ea77a..8c40dac35a1e 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,11 +22,64 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import colorsys +import random +import re + +from typing import TYPE_CHECKING, Optional, Tuple, Union + +if TYPE_CHECKING: + from typing_extensions import Self + +__all__ = ( + 'Colour', + 'Color', +) + +RGB_REGEX = re.compile(r'rgb\s*\((?P[0-9.]+%?)\s*,\s*(?P[0-9.]+%?)\s*,\s*(?P[0-9.]+%?)\s*\)') + + +def parse_hex_number(argument: str) -> Colour: + arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument + try: + value = int(arg, base=16) + if not (0 <= value <= 0xFFFFFF): + raise ValueError('hex number out of range for 24-bit colour') + except ValueError: + raise ValueError('invalid hex digit given') from None + else: + return Color(value=value) + + +def parse_rgb_number(number: str) -> int: + if number[-1] == '%': + value = float(number[:-1]) + if not (0 <= value <= 100): + raise ValueError('rgb percentage can only be between 0 to 100') + return round(255 * (value / 100)) + + value = int(number) + if not (0 <= value <= 255): + raise ValueError('rgb number can only be between 0 to 255') + return value + + +def parse_rgb(argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> Colour: + match = regex.match(argument) + if match is None: + raise ValueError('invalid rgb syntax found') + + red = parse_rgb_number(match.group('r')) + green = parse_rgb_number(match.group('g')) + blue = parse_rgb_number(match.group('b')) + return Color.from_rgb(red, green, blue) + class Colour: """Represents a Discord role colour. This class is similar - to an (red, green, blue) :class:`tuple`. + to a (red, green, blue) :class:`tuple`. There is an alias for this called Color. @@ -50,6 +101,15 @@ class Colour: Returns the hex format for the colour. + .. describe:: int(x) + + Returns the raw colour value. + + .. note:: + + The colour values in the classmethods are mostly provided as-is and can change between + versions should the Discord client's representation of that colour also change. + Attributes ------------ value: :class:`int` @@ -58,173 +118,476 @@ class Colour: __slots__ = ('value',) - def __init__(self, value): + def __init__(self, value: int): if not isinstance(value, int): - raise TypeError('Expected int parameter, received %s instead.' % value.__class__.__name__) + raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.') - self.value = value + self.value: int = value - def _get_byte(self, byte): - return (self.value >> (8 * byte)) & 0xff + def _get_byte(self, byte: int) -> int: + return (self.value >> (8 * byte)) & 0xFF - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, Colour) and self.value == other.value - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __str__(self): - return '#{:0>6x}'.format(self.value) + def __str__(self) -> str: + return f'#{self.value:0>6x}' - def __repr__(self): - return '' % self.value + def __int__(self) -> int: + return self.value - def __hash__(self): + def __repr__(self) -> str: + return f'' + + def __hash__(self) -> int: return hash(self.value) @property - def r(self): + def r(self) -> int: """:class:`int`: Returns the red component of the colour.""" return self._get_byte(2) @property - def g(self): + def g(self) -> int: """:class:`int`: Returns the green component of the colour.""" return self._get_byte(1) @property - def b(self): + def b(self) -> int: """:class:`int`: Returns the blue component of the colour.""" return self._get_byte(0) - def to_rgb(self): + def to_rgb(self) -> Tuple[int, int, int]: """Tuple[:class:`int`, :class:`int`, :class:`int`]: Returns an (r, g, b) tuple representing the colour.""" return (self.r, self.g, self.b) @classmethod - def from_rgb(cls, r, g, b): + def from_rgb(cls, r: int, g: int, b: int) -> Self: """Constructs a :class:`Colour` from an RGB tuple.""" return cls((r << 16) + (g << 8) + b) @classmethod - def from_hsv(cls, h, s, v): + def from_hsv(cls, h: float, s: float, v: float) -> Self: """Constructs a :class:`Colour` from an HSV tuple.""" rgb = colorsys.hsv_to_rgb(h, s, v) return cls.from_rgb(*(int(x * 255) for x in rgb)) @classmethod - def default(cls): - """A factory method that returns a :class:`Colour` with a value of 0.""" + def from_str(cls, value: str) -> Colour: + """Constructs a :class:`Colour` from a string. + + The following formats are accepted: + + - ``0x`` + - ``#`` + - ``0x#`` + - ``rgb(, , )`` + + Like CSS, ```` can be either 0-255 or 0-100% and ```` can be + either a 6 digit hex number or a 3 digit hex shortcut (e.g. #FFF). + + .. versionadded:: 2.0 + + Raises + ------- + ValueError + The string could not be converted into a colour. + """ + + if not value: + raise ValueError('unknown colour format given') + + if value[0] == '#': + return parse_hex_number(value[1:]) + + if value[0:2] == '0x': + rest = value[2:] + # Legacy backwards compatible syntax + if rest.startswith('#'): + return parse_hex_number(rest[1:]) + return parse_hex_number(rest) + + arg = value.lower() + if arg[0:3] == 'rgb': + return parse_rgb(arg) + + raise ValueError('unknown colour format given') + + @classmethod + def default(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0``. + + .. colour:: #000000 + """ return cls(0) @classmethod - def teal(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``.""" - return cls(0x1abc9c) + def random(cls, *, seed: Optional[Union[int, str, float, bytes, bytearray]] = None) -> Self: + """A factory method that returns a :class:`Colour` with a random hue. + + .. note:: + + The random algorithm works by choosing a colour with a random hue but + with maxed out saturation and value. + + .. versionadded:: 1.6 + + Parameters + ------------ + seed: Optional[Union[:class:`int`, :class:`str`, :class:`float`, :class:`bytes`, :class:`bytearray`]] + The seed to initialize the RNG with. If ``None`` is passed the default RNG is used. + + .. versionadded:: 1.7 + """ + rand = random if seed is None else random.Random(seed) + return cls.from_hsv(rand.random(), 1, 1) @classmethod - def dark_teal(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x11806a``.""" - return cls(0x11806a) + def teal(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x1ABC9C``. + + .. colour:: #1ABC9C + """ + return cls(0x1ABC9C) @classmethod - def green(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``.""" - return cls(0x2ecc71) + def dark_teal(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x11806A``. + + .. colour:: #11806A + """ + return cls(0x11806A) @classmethod - def dark_green(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``.""" - return cls(0x1f8b4c) + def brand_green(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x57F287``. + + .. colour:: #57F287 + + + .. versionadded:: 2.0 + """ + return cls(0x57F287) + + @classmethod + def green(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x2ECC71``. + + .. colour:: #2ECC71 + """ + return cls(0x2ECC71) @classmethod - def blue(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x3498db``.""" - return cls(0x3498db) + def dark_green(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x1F8B4C``. + + .. colour:: #1F8B4C + """ + return cls(0x1F8B4C) + + @classmethod + def blue(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x3498DB``. + + .. colour:: #3498DB + """ + return cls(0x3498DB) @classmethod - def dark_blue(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x206694``.""" + def dark_blue(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x206694``. + + .. colour:: #206694 + """ return cls(0x206694) @classmethod - def purple(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``.""" - return cls(0x9b59b6) + def purple(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x9B59B6``. + + .. colour:: #9B59B6 + """ + return cls(0x9B59B6) + + @classmethod + def dark_purple(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x71368A``. + + .. colour:: #71368A + """ + return cls(0x71368A) + + @classmethod + def magenta(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xE91E63``. + + .. colour:: #E91E63 + """ + return cls(0xE91E63) + + @classmethod + def dark_magenta(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xAD1457``. + + .. colour:: #AD1457 + """ + return cls(0xAD1457) + + @classmethod + def gold(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xF1C40F``. + + .. colour:: #F1C40F + """ + return cls(0xF1C40F) + + @classmethod + def dark_gold(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xC27C0E``. + + .. colour:: #C27C0E + """ + return cls(0xC27C0E) @classmethod - def dark_purple(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x71368a``.""" - return cls(0x71368a) + def orange(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xE67E22``. + + .. colour:: #E67E22 + """ + return cls(0xE67E22) + + @classmethod + def dark_orange(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xA84300``. + + .. colour:: #A84300 + """ + return cls(0xA84300) + + @classmethod + def brand_red(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xED4245``. + + .. colour:: #ED4245 + + .. versionadded:: 2.0 + """ + return cls(0xED4245) @classmethod - def magenta(cls): - """A factory method that returns a :class:`Colour` with a value of ``0xe91e63``.""" - return cls(0xe91e63) + def red(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xE74C3C``. + + .. colour:: #E74C3C + """ + return cls(0xE74C3C) + + @classmethod + def dark_red(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x992D22``. + + .. colour:: #992D22 + """ + return cls(0x992D22) + + @classmethod + def lighter_grey(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x95A5A6``. + + .. colour:: #95A5A6 + """ + return cls(0x95A5A6) + + lighter_gray = lighter_grey @classmethod - def dark_magenta(cls): - """A factory method that returns a :class:`Colour` with a value of ``0xad1457``.""" - return cls(0xad1457) + def dark_grey(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``. + + .. colour:: #607d8b + """ + return cls(0x607D8B) + + dark_gray = dark_grey + + @classmethod + def light_grey(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x979C9F``. + + .. colour:: #979C9F + """ + return cls(0x979C9F) + + light_gray = light_grey + + @classmethod + def darker_grey(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x546E7A``. + + .. colour:: #546E7A + """ + return cls(0x546E7A) + + darker_gray = darker_grey + + @classmethod + def og_blurple(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x7289DA``. + + .. colour:: #7289DA + """ + return cls(0x7289DA) + + @classmethod + def blurple(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x5865F2``. + + .. colour:: #5865F2 + """ + return cls(0x5865F2) @classmethod - def gold(cls): - """A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``.""" - return cls(0xf1c40f) + def greyple(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x99AAB5``. + + .. colour:: #99AAB5 + """ + return cls(0x99AAB5) @classmethod - def dark_gold(cls): - """A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``.""" - return cls(0xc27c0e) + def ash_theme(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x2E2E34``. + + This will appear transparent on Discord's ash theme. + + .. colour:: #2E2E34 + + .. versionadded:: 2.6 + """ + return cls(0x2E2E34) @classmethod - def orange(cls): - """A factory method that returns a :class:`Colour` with a value of ``0xe67e22``.""" - return cls(0xe67e22) + def dark_theme(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x1A1A1E``. + + This will appear transparent on Discord's dark theme. + + .. colour:: #1A1A1E + + .. versionadded:: 1.5 + + .. versionchanged:: 2.2 + Updated colour from previous ``0x36393F`` to reflect discord theme changes. + + .. versionchanged:: 2.6 + Updated colour from previous ``0x313338`` to reflect discord theme changes. + """ + return cls(0x1A1A1E) @classmethod - def dark_orange(cls): - """A factory method that returns a :class:`Colour` with a value of ``0xa84300``.""" - return cls(0xa84300) + def onyx_theme(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x070709``. + + This will appear transparent on Discord's onyx theme. + + .. colour:: #070709 + + .. versionadded:: 2.6 + """ + return cls(0x070709) @classmethod - def red(cls): - """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" - return cls(0xe74c3c) + def light_theme(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xFBFBFB``. + + This will appear transparent on Discord's light theme. + + .. colour:: #FBFBFB + + .. versionadded:: 2.6 + """ + return cls(0xFBFBFB) @classmethod - def dark_red(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x992d22``.""" - return cls(0x992d22) + def fuchsia(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xEB459E``. + + .. colour:: #EB459E + + .. versionadded:: 2.0 + """ + return cls(0xEB459E) @classmethod - def lighter_grey(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" - return cls(0x95a5a6) + def yellow(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xFEE75C``. + + .. colour:: #FEE75C + + .. versionadded:: 2.0 + """ + return cls(0xFEE75C) @classmethod - def dark_grey(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" - return cls(0x607d8b) + def ash_embed(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x37373E``. + + .. colour:: #37373E + + .. versionadded:: 2.6 + + """ + return cls(0x37373E) @classmethod - def light_grey(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" - return cls(0x979c9f) + def dark_embed(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x242429``. + + .. colour:: #242429 + + .. versionadded:: 2.2 + + .. versionchanged:: 2.6 + Updated colour from previous ``0x2B2D31`` to reflect discord theme changes. + """ + return cls(0x242429) @classmethod - def darker_grey(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" - return cls(0x546e7a) + def onyx_embed(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0x131416``. + + .. colour:: #131416 + + .. versionadded:: 2.6 + """ + return cls(0x131416) @classmethod - def blurple(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x7289da``.""" - return cls(0x7289da) + def light_embed(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xFFFFFF``. + + .. colour:: #EEEFF1 + + .. versionadded:: 2.2 + + .. versionchanged:: 2.6 + Updated colour from previous ``0xEEEFF1`` to reflect discord theme changes. + """ + return cls(0xFFFFFF) @classmethod - def greyple(cls): - """A factory method that returns a :class:`Colour` with a value of ``0x99aab5``.""" - return cls(0x99aab5) + def pink(cls) -> Self: + """A factory method that returns a :class:`Colour` with a value of ``0xEB459F``. + + .. colour:: #EB459F + + .. versionadded:: 2.3 + """ + return cls(0xEB459F) + Color = Colour diff --git a/discord/components.py b/discord/components.py new file mode 100644 index 000000000000..06caf24f2f4a --- /dev/null +++ b/discord/components.py @@ -0,0 +1,1482 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import ( + ClassVar, + List, + Literal, + Optional, + TYPE_CHECKING, + Tuple, + Union, +) + +from .asset import AssetMixin +from .enums import ( + try_enum, + ComponentType, + ButtonStyle, + TextStyle, + ChannelType, + SelectDefaultValueType, + SeparatorSpacing, + MediaItemLoadingState, +) +from .flags import AttachmentFlags +from .colour import Colour +from .file import File +from .utils import get_slots, MISSING, _get_as_snowflake +from .partial_emoji import PartialEmoji, _EmojiTag + +if TYPE_CHECKING: + from typing_extensions import Self + + from .types.components import ( + Component as ComponentPayload, + ButtonComponent as ButtonComponentPayload, + SelectMenu as SelectMenuPayload, + SelectOption as SelectOptionPayload, + ActionRow as ActionRowPayload, + TextInput as TextInputPayload, + SelectDefaultValues as SelectDefaultValuesPayload, + SectionComponent as SectionComponentPayload, + TextComponent as TextComponentPayload, + MediaGalleryComponent as MediaGalleryComponentPayload, + FileComponent as FileComponentPayload, + SeparatorComponent as SeparatorComponentPayload, + MediaGalleryItem as MediaGalleryItemPayload, + ThumbnailComponent as ThumbnailComponentPayload, + ContainerComponent as ContainerComponentPayload, + UnfurledMediaItem as UnfurledMediaItemPayload, + LabelComponent as LabelComponentPayload, + FileUploadComponent as FileUploadComponentPayload, + ) + + from .emoji import Emoji + from .abc import Snowflake + from .state import ConnectionState + + ActionRowChildComponentType = Union['Button', 'SelectMenu', 'TextInput'] + SectionComponentType = Union['TextDisplay'] + MessageComponentType = Union[ + ActionRowChildComponentType, + SectionComponentType, + 'ActionRow', + 'SectionComponent', + 'ThumbnailComponent', + 'MediaGalleryComponent', + 'FileComponent', + 'SectionComponent', + 'Component', + ] + + +__all__ = ( + 'Component', + 'ActionRow', + 'Button', + 'SelectMenu', + 'SelectOption', + 'TextInput', + 'SelectDefaultValue', + 'SectionComponent', + 'ThumbnailComponent', + 'UnfurledMediaItem', + 'MediaGalleryItem', + 'MediaGalleryComponent', + 'FileComponent', + 'SectionComponent', + 'Container', + 'TextDisplay', + 'SeparatorComponent', + 'LabelComponent', + 'FileUploadComponent', +) + + +class Component: + """Represents a Discord Bot UI Kit Component. + + The components supported by Discord are: + + - :class:`ActionRow` + - :class:`Button` + - :class:`SelectMenu` + - :class:`TextInput` + - :class:`SectionComponent` + - :class:`TextDisplay` + - :class:`ThumbnailComponent` + - :class:`MediaGalleryComponent` + - :class:`FileComponent` + - :class:`SeparatorComponent` + - :class:`Container` + - :class:`LabelComponent` + - :class:`FileUploadComponent` + + This class is abstract and cannot be instantiated. + + .. versionadded:: 2.0 + """ + + __slots__: Tuple[str, ...] = () + + __repr_info__: ClassVar[Tuple[str, ...]] + + def __repr__(self) -> str: + attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) + return f'<{self.__class__.__name__} {attrs}>' + + @property + def type(self) -> ComponentType: + """:class:`ComponentType`: The type of component.""" + raise NotImplementedError + + @classmethod + def _raw_construct(cls, **kwargs) -> Self: + self = cls.__new__(cls) + for slot in get_slots(cls): + try: + value = kwargs[slot] + except KeyError: + pass + else: + setattr(self, slot, value) + return self + + def to_dict(self) -> ComponentPayload: + raise NotImplementedError + + +class ActionRow(Component): + """Represents a Discord Bot UI Kit Action Row. + + This is a component that holds up to 5 children components in a row. + + This inherits from :class:`Component`. + + .. versionadded:: 2.0 + + Attributes + ------------ + children: List[Union[:class:`Button`, :class:`SelectMenu`, :class:`TextInput`]] + The children components that this holds, if any. + id: Optional[:class:`int`] + The ID of this component. + + .. versionadded:: 2.6 + """ + + __slots__: Tuple[str, ...] = ('children', 'id') + + __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ + + def __init__(self, data: ActionRowPayload, /) -> None: + self.id: Optional[int] = data.get('id') + self.children: List[ActionRowChildComponentType] = [] + + for component_data in data.get('components', []): + component = _component_factory(component_data) + + if component is not None: + self.children.append(component) # type: ignore # should be the correct type here + + @property + def type(self) -> Literal[ComponentType.action_row]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.action_row + + def to_dict(self) -> ActionRowPayload: + payload: ActionRowPayload = { + 'type': self.type.value, + 'components': [child.to_dict() for child in self.children], + } + if self.id is not None: + payload['id'] = self.id + return payload + + +class Button(Component): + """Represents a button from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type to create a button is :class:`discord.ui.Button` + not this one. + + .. versionadded:: 2.0 + + Attributes + ----------- + style: :class:`.ButtonStyle` + The style of the button. + custom_id: Optional[:class:`str`] + The ID of the button that gets received during an interaction. + If this button is for a URL, it does not have a custom ID. + url: Optional[:class:`str`] + The URL this button sends you to. + disabled: :class:`bool` + Whether the button is disabled or not. + label: Optional[:class:`str`] + The label of the button, if any. + emoji: Optional[:class:`PartialEmoji`] + The emoji of the button, if available. + sku_id: Optional[:class:`int`] + The SKU ID this button sends you to, if available. + + .. versionadded:: 2.4 + id: Optional[:class:`int`] + The ID of this component. + + .. versionadded:: 2.6 + """ + + __slots__: Tuple[str, ...] = ( + 'style', + 'custom_id', + 'url', + 'disabled', + 'label', + 'emoji', + 'sku_id', + 'id', + ) + + __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ + + def __init__(self, data: ButtonComponentPayload, /) -> None: + self.id: Optional[int] = data.get('id') + self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) + self.custom_id: Optional[str] = data.get('custom_id') + self.url: Optional[str] = data.get('url') + self.disabled: bool = data.get('disabled', False) + self.label: Optional[str] = data.get('label') + self.emoji: Optional[PartialEmoji] + try: + self.emoji = PartialEmoji.from_dict(data['emoji']) # pyright: ignore[reportTypedDictNotRequiredAccess] + except KeyError: + self.emoji = None + + try: + self.sku_id: Optional[int] = int(data['sku_id']) # pyright: ignore[reportTypedDictNotRequiredAccess] + except KeyError: + self.sku_id = None + + @property + def type(self) -> Literal[ComponentType.button]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.button + + def to_dict(self) -> ButtonComponentPayload: + payload: ButtonComponentPayload = { + 'type': 2, + 'style': self.style.value, + 'disabled': self.disabled, + } + + if self.id is not None: + payload['id'] = self.id + + if self.sku_id: + payload['sku_id'] = str(self.sku_id) + + if self.label: + payload['label'] = self.label + + if self.custom_id: + payload['custom_id'] = self.custom_id + + if self.url: + payload['url'] = self.url + + if self.emoji: + payload['emoji'] = self.emoji.to_dict() + + return payload + + +class SelectMenu(Component): + """Represents a select menu from the Discord Bot UI Kit. + + A select menu is functionally the same as a dropdown, however + on mobile it renders a bit differently. + + .. note:: + + The user constructible and usable type to create a select menu is + :class:`discord.ui.Select` not this one. + + .. versionadded:: 2.0 + + Attributes + ------------ + type: :class:`ComponentType` + The type of component. + custom_id: Optional[:class:`str`] + The ID of the select menu that gets received during an interaction. + placeholder: Optional[:class:`str`] + The placeholder text that is shown if nothing is selected, if any. + min_values: :class:`int` + The minimum number of items that must be chosen for this select menu. + Defaults to 1 and must be between 0 and 25. + max_values: :class:`int` + The maximum number of items that must be chosen for this select menu. + Defaults to 1 and must be between 1 and 25. + options: List[:class:`SelectOption`] + A list of options that can be selected in this menu. + disabled: :class:`bool` + Whether the select is disabled or not. + channel_types: List[:class:`.ChannelType`] + A list of channel types that are allowed to be chosen in this select menu. + id: Optional[:class:`int`] + The ID of this component. + + .. versionadded:: 2.6 + required: :class:`bool` + Whether the select is required. Only applicable within modals. + + .. versionadded:: 2.6 + """ + + __slots__: Tuple[str, ...] = ( + 'type', + 'custom_id', + 'placeholder', + 'min_values', + 'max_values', + 'options', + 'disabled', + 'channel_types', + 'default_values', + 'required', + 'id', + ) + + __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ + + def __init__(self, data: SelectMenuPayload, /) -> None: + self.type: ComponentType = try_enum(ComponentType, data['type']) + self.custom_id: str = data['custom_id'] + self.placeholder: Optional[str] = data.get('placeholder') + self.min_values: int = data.get('min_values', 1) + self.max_values: int = data.get('max_values', 1) + self.required: bool = data.get('required', False) + self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] + self.disabled: bool = data.get('disabled', False) + self.channel_types: List[ChannelType] = [try_enum(ChannelType, t) for t in data.get('channel_types', [])] + self.default_values: List[SelectDefaultValue] = [ + SelectDefaultValue.from_dict(d) for d in data.get('default_values', []) + ] + self.id: Optional[int] = data.get('id') + + def to_dict(self) -> SelectMenuPayload: + payload: SelectMenuPayload = { + 'type': self.type.value, # type: ignore # we know this is a select menu. + 'custom_id': self.custom_id, + 'min_values': self.min_values, + 'max_values': self.max_values, + 'disabled': self.disabled, + 'required': self.required, + } + if self.id is not None: + payload['id'] = self.id + if self.placeholder: + payload['placeholder'] = self.placeholder + if self.options: + payload['options'] = [op.to_dict() for op in self.options] + if self.channel_types: + payload['channel_types'] = [t.value for t in self.channel_types] + if self.default_values: + payload['default_values'] = [v.to_dict() for v in self.default_values] + + return payload + + +class SelectOption: + """Represents a select menu's option. + + These can be created by users. + + .. versionadded:: 2.0 + + Parameters + ----------- + label: :class:`str` + The label of the option. This is displayed to users. + Can only be up to 100 characters. + value: :class:`str` + The value of the option. This is not displayed to users. + If not provided when constructed then it defaults to the label. + Can only be up to 100 characters. + description: Optional[:class:`str`] + An additional description of the option, if any. + Can only be up to 100 characters. + emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]] + The emoji of the option, if available. + default: :class:`bool` + Whether this option is selected by default. + + Attributes + ----------- + label: :class:`str` + The label of the option. This is displayed to users. + value: :class:`str` + The value of the option. This is not displayed to users. + If not provided when constructed then it defaults to the + label. + description: Optional[:class:`str`] + An additional description of the option, if any. + default: :class:`bool` + Whether this option is selected by default. + """ + + __slots__: Tuple[str, ...] = ( + 'label', + 'value', + 'description', + '_emoji', + 'default', + ) + + def __init__( + self, + *, + label: str, + value: str = MISSING, + description: Optional[str] = None, + emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, + default: bool = False, + ) -> None: + self.label: str = label + self.value: str = label if value is MISSING else value + self.description: Optional[str] = description + + self.emoji = emoji + self.default: bool = default + + def __repr__(self) -> str: + return ( + f'' + ) + + def __str__(self) -> str: + if self.emoji: + base = f'{self.emoji} {self.label}' + else: + base = self.label + + if self.description: + return f'{base}\n{self.description}' + return base + + @property + def emoji(self) -> Optional[PartialEmoji]: + """Optional[:class:`.PartialEmoji`]: The emoji of the option, if available.""" + return self._emoji + + @emoji.setter + def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]) -> None: + if value is not None: + if isinstance(value, str): + self._emoji = PartialEmoji.from_str(value) + elif isinstance(value, _EmojiTag): + self._emoji = value._to_partial() + else: + raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__.__name__} instead') + else: + self._emoji = None + + @classmethod + def from_dict(cls, data: SelectOptionPayload) -> SelectOption: + try: + emoji = PartialEmoji.from_dict(data['emoji']) # pyright: ignore[reportTypedDictNotRequiredAccess] + except KeyError: + emoji = None + + return cls( + label=data['label'], + value=data['value'], + description=data.get('description'), + emoji=emoji, + default=data.get('default', False), + ) + + def to_dict(self) -> SelectOptionPayload: + payload: SelectOptionPayload = { + 'label': self.label, + 'value': self.value, + 'default': self.default, + } + + if self.emoji: + payload['emoji'] = self.emoji.to_dict() + + if self.description: + payload['description'] = self.description + + return payload + + def copy(self) -> SelectOption: + return self.__class__.from_dict(self.to_dict()) + + +class TextInput(Component): + """Represents a text input from the Discord Bot UI Kit. + + .. note:: + The user constructible and usable type to create a text input is + :class:`discord.ui.TextInput` not this one. + + .. versionadded:: 2.0 + + Attributes + ------------ + custom_id: Optional[:class:`str`] + The ID of the text input that gets received during an interaction. + label: Optional[:class:`str`] + The label to display above the text input. + style: :class:`TextStyle` + The style of the text input. + placeholder: Optional[:class:`str`] + The placeholder text to display when the text input is empty. + value: Optional[:class:`str`] + The default value of the text input. + required: :class:`bool` + Whether the text input is required. + min_length: Optional[:class:`int`] + The minimum length of the text input. + max_length: Optional[:class:`int`] + The maximum length of the text input. + id: Optional[:class:`int`] + The ID of this component. + + .. versionadded:: 2.6 + """ + + __slots__: Tuple[str, ...] = ( + 'style', + 'label', + 'custom_id', + 'placeholder', + 'value', + 'required', + 'min_length', + 'max_length', + 'id', + ) + + __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ + + def __init__(self, data: TextInputPayload, /) -> None: + self.style: TextStyle = try_enum(TextStyle, data['style']) + self.label: Optional[str] = data.get('label') + self.custom_id: str = data['custom_id'] + self.placeholder: Optional[str] = data.get('placeholder') + self.value: Optional[str] = data.get('value') + self.required: bool = data.get('required', True) + self.min_length: Optional[int] = data.get('min_length') + self.max_length: Optional[int] = data.get('max_length') + self.id: Optional[int] = data.get('id') + + @property + def type(self) -> Literal[ComponentType.text_input]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.text_input + + def to_dict(self) -> TextInputPayload: + payload: TextInputPayload = { + 'type': self.type.value, + 'style': self.style.value, + 'label': self.label, + 'custom_id': self.custom_id, + 'required': self.required, + } + + if self.id is not None: + payload['id'] = self.id + + if self.placeholder: + payload['placeholder'] = self.placeholder + + if self.value: + payload['value'] = self.value + + if self.min_length: + payload['min_length'] = self.min_length + + if self.max_length: + payload['max_length'] = self.max_length + + return payload + + @property + def default(self) -> Optional[str]: + """Optional[:class:`str`]: The default value of the text input. + + This is an alias to :attr:`value`. + """ + return self.value + + +class SelectDefaultValue: + """Represents a select menu's default value. + + These can be created by users. + + .. versionadded:: 2.4 + + Parameters + ----------- + id: :class:`int` + The id of a role, user, or channel. + type: :class:`SelectDefaultValueType` + The type of value that ``id`` represents. + """ + + def __init__( + self, + *, + id: int, + type: SelectDefaultValueType, + ) -> None: + self.id: int = id + self._type: SelectDefaultValueType = type + + @property + def type(self) -> SelectDefaultValueType: + """:class:`SelectDefaultValueType`: The type of value that ``id`` represents.""" + return self._type + + @type.setter + def type(self, value: SelectDefaultValueType) -> None: + if not isinstance(value, SelectDefaultValueType): + raise TypeError(f'expected SelectDefaultValueType, received {value.__class__.__name__} instead') + + self._type = value + + def __repr__(self) -> str: + return f'' + + @classmethod + def from_dict(cls, data: SelectDefaultValuesPayload) -> SelectDefaultValue: + return cls( + id=data['id'], + type=try_enum(SelectDefaultValueType, data['type']), + ) + + def to_dict(self) -> SelectDefaultValuesPayload: + return { + 'id': self.id, + 'type': self._type.value, + } + + @classmethod + def from_channel(cls, channel: Snowflake, /) -> Self: + """Creates a :class:`SelectDefaultValue` with the type set to :attr:`~SelectDefaultValueType.channel`. + + Parameters + ----------- + channel: :class:`~discord.abc.Snowflake` + The channel to create the default value for. + + Returns + -------- + :class:`SelectDefaultValue` + The default value created with the channel. + """ + return cls( + id=channel.id, + type=SelectDefaultValueType.channel, + ) + + @classmethod + def from_role(cls, role: Snowflake, /) -> Self: + """Creates a :class:`SelectDefaultValue` with the type set to :attr:`~SelectDefaultValueType.role`. + + Parameters + ----------- + role: :class:`~discord.abc.Snowflake` + The role to create the default value for. + + Returns + -------- + :class:`SelectDefaultValue` + The default value created with the role. + """ + return cls( + id=role.id, + type=SelectDefaultValueType.role, + ) + + @classmethod + def from_user(cls, user: Snowflake, /) -> Self: + """Creates a :class:`SelectDefaultValue` with the type set to :attr:`~SelectDefaultValueType.user`. + + Parameters + ----------- + user: :class:`~discord.abc.Snowflake` + The user to create the default value for. + + Returns + -------- + :class:`SelectDefaultValue` + The default value created with the user. + """ + return cls( + id=user.id, + type=SelectDefaultValueType.user, + ) + + +class SectionComponent(Component): + """Represents a section from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type to create a section is :class:`discord.ui.Section` + not this one. + + .. versionadded:: 2.6 + + Attributes + ---------- + children: List[:class:`TextDisplay`] + The components on this section. + accessory: :class:`Component` + The section accessory. + id: Optional[:class:`int`] + The ID of this component. + """ + + __slots__ = ( + 'children', + 'accessory', + 'id', + ) + + __repr_info__ = __slots__ + + def __init__(self, data: SectionComponentPayload, state: Optional[ConnectionState]) -> None: + self.children: List[SectionComponentType] = [] + self.accessory: Component = _component_factory(data['accessory'], state) # type: ignore + self.id: Optional[int] = data.get('id') + + for component_data in data['components']: + component = _component_factory(component_data, state) + if component is not None: + self.children.append(component) # type: ignore # should be the correct type here + + @property + def type(self) -> Literal[ComponentType.section]: + return ComponentType.section + + def to_dict(self) -> SectionComponentPayload: + payload: SectionComponentPayload = { + 'type': self.type.value, + 'components': [c.to_dict() for c in self.children], + 'accessory': self.accessory.to_dict(), + } + + if self.id is not None: + payload['id'] = self.id + + return payload + + +class ThumbnailComponent(Component): + """Represents a Thumbnail from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type to create a thumbnail is :class:`discord.ui.Thumbnail` + not this one. + + .. versionadded:: 2.6 + + Attributes + ---------- + media: :class:`UnfurledMediaItem` + The media for this thumbnail. + description: Optional[:class:`str`] + The description shown within this thumbnail. + spoiler: :class:`bool` + Whether this thumbnail is flagged as a spoiler. + id: Optional[:class:`int`] + The ID of this component. + """ + + __slots__ = ( + 'media', + 'spoiler', + 'description', + 'id', + ) + + __repr_info__ = __slots__ + + def __init__( + self, + data: ThumbnailComponentPayload, + state: Optional[ConnectionState], + ) -> None: + self.media: UnfurledMediaItem = UnfurledMediaItem._from_data(data['media'], state) + self.description: Optional[str] = data.get('description') + self.spoiler: bool = data.get('spoiler', False) + self.id: Optional[int] = data.get('id') + + @property + def type(self) -> Literal[ComponentType.thumbnail]: + return ComponentType.thumbnail + + def to_dict(self) -> ThumbnailComponentPayload: + payload = { + 'media': self.media.to_dict(), + 'description': self.description, + 'spoiler': self.spoiler, + 'type': self.type.value, + } + + if self.id is not None: + payload['id'] = self.id + + return payload # type: ignore + + +class TextDisplay(Component): + """Represents a text display from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type to create a text display is + :class:`discord.ui.TextDisplay` not this one. + + .. versionadded:: 2.6 + + Attributes + ---------- + content: :class:`str` + The content that this display shows. + id: Optional[:class:`int`] + The ID of this component. + """ + + __slots__ = ('content', 'id') + + __repr_info__ = __slots__ + + def __init__(self, data: TextComponentPayload) -> None: + self.content: str = data['content'] + self.id: Optional[int] = data.get('id') + + @property + def type(self) -> Literal[ComponentType.text_display]: + return ComponentType.text_display + + def to_dict(self) -> TextComponentPayload: + payload: TextComponentPayload = { + 'type': self.type.value, + 'content': self.content, + } + if self.id is not None: + payload['id'] = self.id + return payload + + +class UnfurledMediaItem(AssetMixin): + """Represents an unfurled media item. + + .. versionadded:: 2.6 + + Parameters + ---------- + url: :class:`str` + The URL of this media item. This can be an arbitrary url or a reference to a local + file uploaded as an attachment within the message, which can be accessed with the + ``attachment://`` format. + + Attributes + ---------- + url: :class:`str` + The URL of this media item. + proxy_url: Optional[:class:`str`] + The proxy URL. This is a cached version of the :attr:`.url` in the + case of images. When the message is deleted, this URL might be valid for a few minutes + or not valid at all. + height: Optional[:class:`int`] + The media item's height, in pixels. Only applicable to images and videos. + width: Optional[:class:`int`] + The media item's width, in pixels. Only applicable to images and videos. + content_type: Optional[:class:`str`] + The media item's `media type `_ + placeholder: Optional[:class:`str`] + The media item's placeholder. + loading_state: Optional[:class:`MediaItemLoadingState`] + The loading state of this media item. + attachment_id: Optional[:class:`int`] + The attachment id this media item points to, only available if the url points to a local file + uploaded within the component message. + """ + + __slots__ = ( + 'url', + 'proxy_url', + 'height', + 'width', + 'content_type', + '_flags', + 'placeholder', + 'loading_state', + 'attachment_id', + '_state', + ) + + def __init__(self, url: str) -> None: + self.url: str = url + + self.proxy_url: Optional[str] = None + self.height: Optional[int] = None + self.width: Optional[int] = None + self.content_type: Optional[str] = None + self._flags: int = 0 + self.placeholder: Optional[str] = None + self.loading_state: Optional[MediaItemLoadingState] = None + self.attachment_id: Optional[int] = None + self._state: Optional[ConnectionState] = None + + @property + def flags(self) -> AttachmentFlags: + """:class:`AttachmentFlags`: This media item's flags.""" + return AttachmentFlags._from_value(self._flags) + + @classmethod + def _from_data(cls, data: UnfurledMediaItemPayload, state: Optional[ConnectionState]): + self = cls(data['url']) + self._update(data, state) + return self + + def _update(self, data: UnfurledMediaItemPayload, state: Optional[ConnectionState]) -> None: + self.proxy_url = data.get('proxy_url') + self.height = data.get('height') + self.width = data.get('width') + self.content_type = data.get('content_type') + self._flags = data.get('flags', 0) + self.placeholder = data.get('placeholder') + + loading_state = data.get('loading_state') + if loading_state is not None: + self.loading_state = try_enum(MediaItemLoadingState, loading_state) + self.attachment_id = _get_as_snowflake(data, 'attachment_id') + self._state = state + + def __repr__(self) -> str: + return f'' + + def to_dict(self): + return { + 'url': self.url, + } + + +class MediaGalleryItem: + """Represents a :class:`MediaGalleryComponent` media item. + + .. versionadded:: 2.6 + + Parameters + ---------- + media: Union[:class:`str`, :class:`discord.File`, :class:`UnfurledMediaItem`] + The media item data. This can be a string representing a local + file uploaded as an attachment in the message, which can be accessed + using the ``attachment://`` format, or an arbitrary url. + description: Optional[:class:`str`] + The description to show within this item. Up to 256 characters. Defaults + to ``None``. + spoiler: :class:`bool` + Whether this item should be flagged as a spoiler. + """ + + __slots__ = ( + '_media', + 'description', + 'spoiler', + '_state', + ) + + def __init__( + self, + media: Union[str, File, UnfurledMediaItem], + *, + description: Optional[str] = MISSING, + spoiler: bool = MISSING, + ) -> None: + self.media = media + + if isinstance(media, File): + if description is MISSING: + description = media.description + if spoiler is MISSING: + spoiler = media.spoiler + + self.description: Optional[str] = None if description is MISSING else description + self.spoiler: bool = bool(spoiler) + self._state: Optional[ConnectionState] = None + + def __repr__(self) -> str: + return f'' + + @property + def media(self) -> UnfurledMediaItem: + """:class:`UnfurledMediaItem`: This item's media data.""" + return self._media + + @media.setter + def media(self, value: Union[str, File, UnfurledMediaItem]) -> None: + if isinstance(value, str): + self._media = UnfurledMediaItem(value) + elif isinstance(value, UnfurledMediaItem): + self._media = value + elif isinstance(value, File): + self._media = UnfurledMediaItem(value.uri) + else: + raise TypeError(f'Expected a str or UnfurledMediaItem, not {value.__class__.__name__}') + + @classmethod + def _from_data(cls, data: MediaGalleryItemPayload, state: Optional[ConnectionState]) -> MediaGalleryItem: + media = data['media'] + self = cls( + media=UnfurledMediaItem._from_data(media, state), + description=data.get('description'), + spoiler=data.get('spoiler', False), + ) + self._state = state + return self + + @classmethod + def _from_gallery( + cls, + items: List[MediaGalleryItemPayload], + state: Optional[ConnectionState], + ) -> List[MediaGalleryItem]: + return [cls._from_data(item, state) for item in items] + + def to_dict(self) -> MediaGalleryItemPayload: + payload: MediaGalleryItemPayload = { + 'media': self.media.to_dict(), # type: ignore + 'spoiler': self.spoiler, + } + + if self.description: + payload['description'] = self.description + + return payload + + +class MediaGalleryComponent(Component): + """Represents a Media Gallery component from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type for creating a media gallery is + :class:`discord.ui.MediaGallery` not this one. + + .. versionadded:: 2.6 + + Attributes + ---------- + items: List[:class:`MediaGalleryItem`] + The items this gallery has. + id: Optional[:class:`int`] + The ID of this component. + """ + + __slots__ = ('items', 'id') + + __repr_info__ = __slots__ + + def __init__(self, data: MediaGalleryComponentPayload, state: Optional[ConnectionState]) -> None: + self.items: List[MediaGalleryItem] = MediaGalleryItem._from_gallery(data['items'], state) + self.id: Optional[int] = data.get('id') + + @property + def type(self) -> Literal[ComponentType.media_gallery]: + return ComponentType.media_gallery + + def to_dict(self) -> MediaGalleryComponentPayload: + payload: MediaGalleryComponentPayload = { + 'type': self.type.value, + 'items': [item.to_dict() for item in self.items], + } + if self.id is not None: + payload['id'] = self.id + return payload + + +class FileComponent(Component): + """Represents a File component from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type for create a file component is + :class:`discord.ui.File` not this one. + + .. versionadded:: 2.6 + + Attributes + ---------- + media: :class:`UnfurledMediaItem` + The unfurled attachment contents of the file. + spoiler: :class:`bool` + Whether this file is flagged as a spoiler. + id: Optional[:class:`int`] + The ID of this component. + name: Optional[:class:`str`] + The displayed file name, only available when received from the API. + size: Optional[:class:`int`] + The file size in MiB, only available when received from the API. + """ + + __slots__ = ( + 'media', + 'spoiler', + 'id', + 'name', + 'size', + ) + + __repr_info__ = __slots__ + + def __init__(self, data: FileComponentPayload, state: Optional[ConnectionState]) -> None: + self.media: UnfurledMediaItem = UnfurledMediaItem._from_data(data['file'], state) + self.spoiler: bool = data.get('spoiler', False) + self.id: Optional[int] = data.get('id') + self.name: Optional[str] = data.get('name') + self.size: Optional[int] = data.get('size') + + @property + def type(self) -> Literal[ComponentType.file]: + return ComponentType.file + + def to_dict(self) -> FileComponentPayload: + payload: FileComponentPayload = { + 'type': self.type.value, + 'file': self.media.to_dict(), # type: ignore + 'spoiler': self.spoiler, + } + if self.id is not None: + payload['id'] = self.id + return payload + + +class SeparatorComponent(Component): + """Represents a Separator from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type for creating a separator is + :class:`discord.ui.Separator` not this one. + + .. versionadded:: 2.6 + + Attributes + ---------- + spacing: :class:`SeparatorSpacing` + The spacing size of the separator. + visible: :class:`bool` + Whether this separator is visible and shows a divider. + id: Optional[:class:`int`] + The ID of this component. + """ + + __slots__ = ( + 'spacing', + 'visible', + 'id', + ) + + __repr_info__ = __slots__ + + def __init__( + self, + data: SeparatorComponentPayload, + ) -> None: + self.spacing: SeparatorSpacing = try_enum(SeparatorSpacing, data.get('spacing', 1)) + self.visible: bool = data.get('divider', True) + self.id: Optional[int] = data.get('id') + + @property + def type(self) -> Literal[ComponentType.separator]: + return ComponentType.separator + + def to_dict(self) -> SeparatorComponentPayload: + payload: SeparatorComponentPayload = { + 'type': self.type.value, + 'divider': self.visible, + 'spacing': self.spacing.value, + } + if self.id is not None: + payload['id'] = self.id + return payload + + +class Container(Component): + """Represents a Container from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type for creating a container is + :class:`discord.ui.Container` not this one. + + .. versionadded:: 2.6 + + Attributes + ---------- + children: :class:`Component` + This container's children. + spoiler: :class:`bool` + Whether this container is flagged as a spoiler. + id: Optional[:class:`int`] + The ID of this component. + """ + + __slots__ = ( + 'children', + 'id', + 'spoiler', + '_colour', + ) + + __repr_info__ = ( + 'children', + 'id', + 'spoiler', + 'accent_colour', + ) + + def __init__(self, data: ContainerComponentPayload, state: Optional[ConnectionState]) -> None: + self.children: List[Component] = [] + self.id: Optional[int] = data.get('id') + + for child in data['components']: + comp = _component_factory(child, state) + + if comp: + self.children.append(comp) + + self.spoiler: bool = data.get('spoiler', False) + + colour = data.get('accent_color') + self._colour: Optional[Colour] = None + if colour is not None: + self._colour = Colour(colour) + + @property + def accent_colour(self) -> Optional[Colour]: + """Optional[:class:`Colour`]: The container's accent colour.""" + return self._colour + + accent_color = accent_colour + + @property + def type(self) -> Literal[ComponentType.container]: + return ComponentType.container + + def to_dict(self) -> ContainerComponentPayload: + payload: ContainerComponentPayload = { + 'type': self.type.value, + 'spoiler': self.spoiler, + 'components': [c.to_dict() for c in self.children], # pyright: ignore[reportAssignmentType] + } + if self.id is not None: + payload['id'] = self.id + if self._colour: + payload['accent_color'] = self._colour.value + return payload + + +class LabelComponent(Component): + """Represents a label component from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type for creating a label is + :class:`discord.ui.Label` not this one. + + .. versionadded:: 2.6 + + Attributes + ---------- + label: :class:`str` + The label text to display. + description: Optional[:class:`str`] + The description text to display below the label, if any. + component: :class:`Component` + The component that this label is associated with. + id: Optional[:class:`int`] + The ID of this component. + """ + + __slots__ = ( + 'label', + 'description', + 'component', + 'id', + ) + + __repr_info__ = ('label', 'description', 'component') + + def __init__(self, data: LabelComponentPayload, state: Optional[ConnectionState]) -> None: + self.component: Component = _component_factory(data['component'], state) # type: ignore + self.label: str = data['label'] + self.id: Optional[int] = data.get('id') + self.description: Optional[str] = data.get('description') + + @property + def type(self) -> Literal[ComponentType.label]: + return ComponentType.label + + def to_dict(self) -> LabelComponentPayload: + payload: LabelComponentPayload = { + 'type': self.type.value, + 'label': self.label, + 'component': self.component.to_dict(), # type: ignore + } + if self.description: + payload['description'] = self.description + if self.id is not None: + payload['id'] = self.id + return payload + + +class FileUploadComponent(Component): + """Represents a file upload component from the Discord Bot UI Kit. + + This inherits from :class:`Component`. + + .. note:: + + The user constructible and usable type for creating a file upload is + :class:`discord.ui.FileUpload` not this one. + + .. versionadded:: 2.7 + + Attributes + ------------ + custom_id: Optional[:class:`str`] + The ID of the component that gets received during an interaction. + min_values: :class:`int` + The minimum number of files that must be uploaded for this component. + Defaults to 1 and must be between 0 and 10. + max_values: :class:`int` + The maximum number of files that must be uploaded for this component. + Defaults to 1 and must be between 1 and 10. + id: Optional[:class:`int`] + The ID of this component. + required: :class:`bool` + Whether the component is required. + Defaults to ``True``. + """ + + __slots__: Tuple[str, ...] = ( + 'custom_id', + 'min_values', + 'max_values', + 'required', + 'id', + ) + + __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ + + def __init__(self, data: FileUploadComponentPayload, /) -> None: + self.custom_id: str = data['custom_id'] + self.min_values: int = data.get('min_values', 1) + self.max_values: int = data.get('max_values', 1) + self.required: bool = data.get('required', True) + self.id: Optional[int] = data.get('id') + + @property + def type(self) -> Literal[ComponentType.file_upload]: + """:class:`ComponentType`: The type of component.""" + return ComponentType.file_upload + + def to_dict(self) -> FileUploadComponentPayload: + payload: FileUploadComponentPayload = { + 'type': self.type.value, + 'custom_id': self.custom_id, + 'min_values': self.min_values, + 'max_values': self.max_values, + 'required': self.required, + } + if self.id is not None: + payload['id'] = self.id + + return payload + + +def _component_factory(data: ComponentPayload, state: Optional[ConnectionState] = None) -> Optional[Component]: + if data['type'] == 1: + return ActionRow(data) + elif data['type'] == 2: + return Button(data) + elif data['type'] == 4: + return TextInput(data) + elif data['type'] in (3, 5, 6, 7, 8): + return SelectMenu(data) # type: ignore + elif data['type'] == 9: + return SectionComponent(data, state) + elif data['type'] == 10: + return TextDisplay(data) + elif data['type'] == 11: + return ThumbnailComponent(data, state) + elif data['type'] == 12: + return MediaGalleryComponent(data, state) + elif data['type'] == 13: + return FileComponent(data, state) + elif data['type'] == 14: + return SeparatorComponent(data) + elif data['type'] == 17: + return Container(data, state) + elif data['type'] == 18: + return LabelComponent(data, state) + elif data['type'] == 19: + return FileUploadComponent(data) diff --git a/discord/context_managers.py b/discord/context_managers.py index e7ac501c61ab..09803c95320e 100644 --- a/discord/context_managers.py +++ b/discord/context_managers.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,44 +22,71 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import asyncio +from typing import TYPE_CHECKING, Generator, Optional, Type, TypeVar + +if TYPE_CHECKING: + from .abc import Messageable, MessageableChannel + + from types import TracebackType + + BE = TypeVar('BE', bound=BaseException) -def _typing_done_callback(fut): +# fmt: off +__all__ = ( + 'Typing', +) +# fmt: on + + +def _typing_done_callback(fut: asyncio.Future) -> None: # just retrieve any exception and call it a day try: fut.exception() except (asyncio.CancelledError, Exception): pass + class Typing: - def __init__(self, messageable): - self.loop = messageable._state.loop - self.messageable = messageable + def __init__(self, messageable: Messageable) -> None: + self.loop: asyncio.AbstractEventLoop = messageable._state.loop + self.messageable: Messageable = messageable + self.channel: Optional[MessageableChannel] = None - async def do_typing(self): - try: - channel = self._channel - except AttributeError: - channel = await self.messageable._get_channel() + async def _get_channel(self) -> MessageableChannel: + if self.channel: + return self.channel + self.channel = channel = await self.messageable._get_channel() + return channel + + async def wrapped_typer(self) -> None: + channel = await self._get_channel() + await channel._state.http.send_typing(channel.id) + + def __await__(self) -> Generator[None, None, None]: + return self.wrapped_typer().__await__() + + async def do_typing(self) -> None: + channel = await self._get_channel() typing = channel._state.http.send_typing while True: - await typing(channel.id) await asyncio.sleep(5) + await typing(channel.id) - def __enter__(self): - self.task = asyncio.ensure_future(self.do_typing(), loop=self.loop) - self.task.add_done_callback(_typing_done_callback) - return self - - def __exit__(self, exc_type, exc, tb): - self.task.cancel() - - async def __aenter__(self): - self._channel = channel = await self.messageable._get_channel() + async def __aenter__(self) -> None: + channel = await self._get_channel() await channel._state.http.send_typing(channel.id) - return self.__enter__() + self.task: asyncio.Task[None] = self.loop.create_task(self.do_typing()) + self.task.add_done_callback(_typing_done_callback) - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__( + self, + exc_type: Optional[Type[BE]], + exc: Optional[BE], + traceback: Optional[TracebackType], + ) -> None: self.task.cancel() diff --git a/discord/embeds.py b/discord/embeds.py index e8b83985d816..b1c98e66b330 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,35 +22,89 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import datetime +from typing import Any, Dict, List, Mapping, Optional, Protocol, TYPE_CHECKING, TypeVar, Union from . import utils from .colour import Colour +from .flags import AttachmentFlags, EmbedFlags -class _EmptyEmbed: - def __bool__(self): - return False - - def __repr__(self): - return 'Embed.Empty' - - def __len__(self): - return 0 +# fmt: off +__all__ = ( + 'Embed', +) +# fmt: on -EmptyEmbed = _EmptyEmbed() class EmbedProxy: - def __init__(self, layer): + def __init__(self, layer: Dict[str, Any]): self.__dict__.update(layer) - def __len__(self): + def __len__(self) -> int: return len(self.__dict__) - def __repr__(self): - return 'EmbedProxy(%s)' % ', '.join(('%s=%r' % (k, v) for k, v in self.__dict__.items() if not k.startswith('_'))) + def __repr__(self) -> str: + inner = ', '.join((f'{k}={getattr(self, k)!r}' for k in dir(self) if not k.startswith('_'))) + return f'EmbedProxy({inner})' + + def __getattr__(self, attr: str) -> None: + return None + + def __eq__(self, other: object) -> bool: + return isinstance(other, EmbedProxy) and self.__dict__ == other.__dict__ + + +class EmbedMediaProxy(EmbedProxy): + def __init__(self, layer: Dict[str, Any]): + super().__init__(layer) + self._flags = self.__dict__.pop('flags', 0) + + def __bool__(self) -> bool: + # This is a nasty check to see if we only have the `_flags` attribute which is created regardless in init. + # Had we had any of the other items, like image/video data this would be >1 and therefor + # would not be "empty". + return len(self.__dict__) > 1 + + @property + def flags(self) -> AttachmentFlags: + return AttachmentFlags._from_value(self._flags or 0) + + +if TYPE_CHECKING: + from typing_extensions import Self + + from .types.embed import Embed as EmbedData, EmbedType + + T = TypeVar('T') + + class _EmbedFooterProxy(Protocol): + text: Optional[str] + icon_url: Optional[str] + + class _EmbedFieldProxy(Protocol): + name: Optional[str] + value: Optional[str] + inline: bool + + class _EmbedMediaProxy(Protocol): + url: Optional[str] + proxy_url: Optional[str] + height: Optional[int] + width: Optional[int] + flags: AttachmentFlags + + class _EmbedProviderProxy(Protocol): + name: Optional[str] + url: Optional[str] + + class _EmbedAuthorProxy(Protocol): + name: Optional[str] + url: Optional[str] + icon_url: Optional[str] + proxy_icon_url: Optional[str] - def __getattr__(self, attr): - return EmptyEmbed class Embed: """Represents a Discord embed. @@ -64,75 +116,104 @@ class Embed: Returns the total size of the embed. Useful for checking if it's within the 6000 character limit. - Certain properties return an ``EmbedProxy``, a type - that acts similar to a regular :class:`dict` except using dotted access, - e.g. ``embed.author.icon_url``. If the attribute - is invalid or empty, then a special sentinel value is returned, - :attr:`Embed.Empty`. + .. describe:: bool(b) + + Returns whether the embed has any data set. + + .. versionadded:: 2.0 + + .. describe:: x == y + + Checks if two embeds are equal. + + .. versionadded:: 2.0 For ease of use, all parameters that expect a :class:`str` are implicitly casted to :class:`str` for you. + .. versionchanged:: 2.0 + ``Embed.Empty`` has been removed in favour of ``None``. + Attributes ----------- - title: :class:`str` + title: Optional[:class:`str`] The title of the embed. This can be set during initialisation. + Can only be up to 256 characters. type: :class:`str` The type of embed. Usually "rich". This can be set during initialisation. - description: :class:`str` + Possible strings for embed types can be found on discord's + :ddocs:`api docs ` + description: Optional[:class:`str`] The description of the embed. This can be set during initialisation. - url: :class:`str` + Can only be up to 4096 characters. + url: Optional[:class:`str`] The URL of the embed. This can be set during initialisation. - timestamp: :class:`datetime.datetime` - The timestamp of the embed content. This could be a naive or aware datetime. - colour: Union[:class:`Colour`, :class:`int`] + timestamp: Optional[:class:`datetime.datetime`] + The timestamp of the embed content. This is an aware datetime. + If a naive datetime is passed, it is converted to an aware + datetime with the local timezone. + colour: Optional[Union[:class:`Colour`, :class:`int`]] The colour code of the embed. Aliased to ``color`` as well. This can be set during initialisation. - Empty - A special sentinel value used by ``EmbedProxy`` and this class - to denote that the value or attribute is empty. """ - __slots__ = ('title', 'url', 'type', '_timestamp', '_colour', '_footer', - '_image', '_thumbnail', '_video', '_provider', '_author', - '_fields', 'description') - - Empty = EmptyEmbed - - def __init__(self, **kwargs): - # swap the colour/color aliases - try: - colour = kwargs['colour'] - except KeyError: - colour = kwargs.get('color', EmptyEmbed) - - self.colour = colour - self.title = kwargs.get('title', EmptyEmbed) - self.type = kwargs.get('type', 'rich') - self.url = kwargs.get('url', EmptyEmbed) - self.description = kwargs.get('description', EmptyEmbed) - - try: - timestamp = kwargs['timestamp'] - except KeyError: - pass - else: + __slots__ = ( + 'title', + 'url', + 'type', + '_timestamp', + '_colour', + '_footer', + '_image', + '_thumbnail', + '_video', + '_provider', + '_author', + '_fields', + 'description', + '_flags', + ) + + def __init__( + self, + *, + colour: Optional[Union[int, Colour]] = None, + color: Optional[Union[int, Colour]] = None, + title: Optional[Any] = None, + type: EmbedType = 'rich', + url: Optional[Any] = None, + description: Optional[Any] = None, + timestamp: Optional[datetime.datetime] = None, + ): + self.colour = colour if colour is not None else color + self.title: Optional[str] = title + self.type: EmbedType = type + self.url: Optional[str] = url + self.description: Optional[str] = description + self._flags: int = 0 + + if self.title is not None: + self.title = str(self.title) + + if self.description is not None: + self.description = str(self.description) + + if self.url is not None: + self.url = str(self.url) + + if timestamp is not None: self.timestamp = timestamp @classmethod - def from_dict(cls, data): + def from_dict(cls, data: Mapping[str, Any]) -> Self: """Converts a :class:`dict` to a :class:`Embed` provided it is in the format that Discord expects it to be in. - You can find out about this format in the `official Discord documentation`__. - - .. _DiscordDocs: https://discordapp.com/developers/docs/resources/channel#embed-object - - __ DiscordDocs_ + You can find out about this format in the :ddocs:`official Discord documentation `. Parameters ----------- @@ -144,10 +225,20 @@ def from_dict(cls, data): # fill in the basic fields - self.title = data.get('title', EmptyEmbed) - self.type = data.get('type', EmptyEmbed) - self.description = data.get('description', EmptyEmbed) - self.url = data.get('url', EmptyEmbed) + self.title = data.get('title', None) + self.type = data.get('type', None) + self.description = data.get('description', None) + self.url = data.get('url', None) + self._flags = data.get('flags', 0) + + if self.title is not None: + self.title = str(self.title) + + if self.description is not None: + self.description = str(self.description) + + if self.url is not None: + self.url = str(self.url) # try to fill in the more rich fields @@ -171,21 +262,21 @@ def from_dict(cls, data): return self - def copy(self): + def copy(self) -> Self: """Returns a shallow copy of the embed.""" - return Embed.from_dict(self.to_dict()) + return self.__class__.from_dict(self.to_dict()) - def __len__(self): - total = len(self.title) + len(self.description) + def __len__(self) -> int: + total = len(self.title or '') + len(self.description or '') for field in getattr(self, '_fields', []): total += len(field['name']) + len(field['value']) try: - footer = self._footer - except AttributeError: + footer_text = self._footer['text'] + except (AttributeError, KeyError): pass else: - total += len(footer['text']) + total += len(footer_text) try: author = self._author @@ -196,43 +287,94 @@ def __len__(self): return total + def __bool__(self) -> bool: + return any( + ( + self.title, + self.url, + self.description, + self.colour, + self.fields, + self.timestamp, + self.author, + self.thumbnail, + self.footer, + self.image, + self.provider, + self.video, + ) + ) + + def __eq__(self, other: Embed) -> bool: + return isinstance(other, Embed) and ( + self.type == other.type + and self.title == other.title + and self.url == other.url + and self.description == other.description + and self.colour == other.colour + and self.fields == other.fields + and self.timestamp == other.timestamp + and self.author == other.author + and self.thumbnail == other.thumbnail + and self.footer == other.footer + and self.image == other.image + and self.provider == other.provider + and self.video == other.video + and self._flags == other._flags + ) + @property - def colour(self): - return getattr(self, '_colour', EmptyEmbed) + def flags(self) -> EmbedFlags: + """:class:`EmbedFlags`: The flags of this embed. + + .. versionadded:: 2.5 + """ + return EmbedFlags._from_value(self._flags or 0) + + @property + def colour(self) -> Optional[Colour]: + return getattr(self, '_colour', None) @colour.setter - def colour(self, value): - if isinstance(value, (Colour, _EmptyEmbed)): + def colour(self, value: Optional[Union[int, Colour]]) -> None: + if value is None: + self._colour = None + elif isinstance(value, Colour): self._colour = value elif isinstance(value, int): self._colour = Colour(value=value) else: - raise TypeError('Expected discord.Colour, int, or Embed.Empty but received %s instead.' % value.__class__.__name__) + raise TypeError(f'Expected discord.Colour, int, or None but received {value.__class__.__name__} instead.') color = colour @property - def timestamp(self): - return getattr(self, '_timestamp', EmptyEmbed) + def timestamp(self) -> Optional[datetime.datetime]: + return getattr(self, '_timestamp', None) @timestamp.setter - def timestamp(self, value): - if isinstance(value, (datetime.datetime, _EmptyEmbed)): + def timestamp(self, value: Optional[datetime.datetime]) -> None: + if isinstance(value, datetime.datetime): + if value.tzinfo is None: + value = value.astimezone() self._timestamp = value + elif value is None: + self._timestamp = None else: - raise TypeError("Expected datetime.datetime or Embed.Empty received %s instead" % value.__class__.__name__) + raise TypeError(f'Expected datetime.datetime or None received {value.__class__.__name__} instead') @property - def footer(self): + def footer(self) -> _EmbedFooterProxy: """Returns an ``EmbedProxy`` denoting the footer contents. See :meth:`set_footer` for possible values you can access. - If the attribute has no value then :attr:`Empty` is returned. + If the attribute has no value then ``None`` is returned. """ - return EmbedProxy(getattr(self, '_footer', {})) + # Lying to the type checker for better developer UX. + return EmbedProxy(getattr(self, '_footer', {})) # type: ignore - def set_footer(self, *, text=EmptyEmbed, icon_url=EmptyEmbed): + def set_footer(self, *, text: Optional[Any] = None, icon_url: Optional[Any] = None) -> Self: """Sets the footer for the embed content. This function returns the class instance to allow for fluent-style @@ -241,36 +383,54 @@ def set_footer(self, *, text=EmptyEmbed, icon_url=EmptyEmbed): Parameters ----------- text: :class:`str` - The footer text. + The footer text. Can only be up to 2048 characters. icon_url: :class:`str` The URL of the footer icon. Only HTTP(S) is supported. + Inline attachment URLs are also supported, see :ref:`local_image`. """ self._footer = {} - if text is not EmptyEmbed: + if text is not None: self._footer['text'] = str(text) - if icon_url is not EmptyEmbed: + if icon_url is not None: self._footer['icon_url'] = str(icon_url) return self + def remove_footer(self) -> Self: + """Clears embed's footer information. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionadded:: 2.0 + """ + try: + del self._footer + except AttributeError: + pass + + return self + @property - def image(self): + def image(self) -> _EmbedMediaProxy: """Returns an ``EmbedProxy`` denoting the image contents. Possible attributes you can access are: - - ``url`` - - ``proxy_url`` - - ``width`` - - ``height`` + - ``url`` for the image URL. + - ``proxy_url`` for the proxied image URL. + - ``width`` for the image width. + - ``height`` for the image height. + - ``flags`` for the image's attachment flags. - If the attribute has no value then :attr:`Empty` is returned. + If the attribute has no value then ``None`` is returned. """ - return EmbedProxy(getattr(self, '_image', {})) + # Lying to the type checker for better developer UX. + return EmbedMediaProxy(getattr(self, '_image', {})) # type: ignore - def set_image(self, *, url): + def set_image(self, *, url: Optional[Any]) -> Self: """Sets the image for the embed content. This function returns the class instance to allow for fluent-style @@ -278,32 +438,42 @@ def set_image(self, *, url): Parameters ----------- - url: :class:`str` + url: Optional[:class:`str`] The source URL for the image. Only HTTP(S) is supported. + If ``None`` is passed, any existing image is removed. + Inline attachment URLs are also supported, see :ref:`local_image`. """ - self._image = { - 'url': str(url) - } + if url is None: + try: + del self._image + except AttributeError: + pass + else: + self._image = { + 'url': str(url), + } return self @property - def thumbnail(self): + def thumbnail(self) -> _EmbedMediaProxy: """Returns an ``EmbedProxy`` denoting the thumbnail contents. Possible attributes you can access are: - - ``url`` - - ``proxy_url`` - - ``width`` - - ``height`` + - ``url`` for the thumbnail URL. + - ``proxy_url`` for the proxied thumbnail URL. + - ``width`` for the thumbnail width. + - ``height`` for the thumbnail height. + - ``flags`` for the thumbnail's attachment flags. - If the attribute has no value then :attr:`Empty` is returned. + If the attribute has no value then ``None`` is returned. """ - return EmbedProxy(getattr(self, '_thumbnail', {})) + # Lying to the type checker for better developer UX. + return EmbedMediaProxy(getattr(self, '_thumbnail', {})) # type: ignore - def set_thumbnail(self, *, url): + def set_thumbnail(self, *, url: Optional[Any]) -> Self: """Sets the thumbnail for the embed content. This function returns the class instance to allow for fluent-style @@ -311,51 +481,64 @@ def set_thumbnail(self, *, url): Parameters ----------- - url: :class:`str` + url: Optional[:class:`str`] The source URL for the thumbnail. Only HTTP(S) is supported. + If ``None`` is passed, any existing thumbnail is removed. + Inline attachment URLs are also supported, see :ref:`local_image`. """ - self._thumbnail = { - 'url': str(url) - } + if url is None: + try: + del self._thumbnail + except AttributeError: + pass + else: + self._thumbnail = { + 'url': str(url), + } return self @property - def video(self): + def video(self) -> _EmbedMediaProxy: """Returns an ``EmbedProxy`` denoting the video contents. Possible attributes include: - ``url`` for the video URL. + - ``proxy_url`` for the proxied video URL. - ``height`` for the video height. - ``width`` for the video width. + - ``flags`` for the video's attachment flags. - If the attribute has no value then :attr:`Empty` is returned. + If the attribute has no value then ``None`` is returned. """ - return EmbedProxy(getattr(self, '_video', {})) + # Lying to the type checker for better developer UX. + return EmbedMediaProxy(getattr(self, '_video', {})) # type: ignore @property - def provider(self): + def provider(self) -> _EmbedProviderProxy: """Returns an ``EmbedProxy`` denoting the provider contents. The only attributes that might be accessed are ``name`` and ``url``. - If the attribute has no value then :attr:`Empty` is returned. + If the attribute has no value then ``None`` is returned. """ - return EmbedProxy(getattr(self, '_provider', {})) + # Lying to the type checker for better developer UX. + return EmbedProxy(getattr(self, '_provider', {})) # type: ignore @property - def author(self): + def author(self) -> _EmbedAuthorProxy: """Returns an ``EmbedProxy`` denoting the author contents. See :meth:`set_author` for possible values you can access. - If the attribute has no value then :attr:`Empty` is returned. + If the attribute has no value then ``None`` is returned. """ - return EmbedProxy(getattr(self, '_author', {})) + # Lying to the type checker for better developer UX. + return EmbedProxy(getattr(self, '_author', {})) # type: ignore - def set_author(self, *, name, url=EmptyEmbed, icon_url=EmptyEmbed): + def set_author(self, *, name: Any, url: Optional[Any] = None, icon_url: Optional[Any] = None) -> Self: """Sets the author for the embed content. This function returns the class instance to allow for fluent-style @@ -364,47 +547,64 @@ def set_author(self, *, name, url=EmptyEmbed, icon_url=EmptyEmbed): Parameters ----------- name: :class:`str` - The name of the author. + The name of the author. Can only be up to 256 characters. url: :class:`str` The URL for the author. icon_url: :class:`str` The URL of the author icon. Only HTTP(S) is supported. + Inline attachment URLs are also supported, see :ref:`local_image`. """ self._author = { - 'name': str(name) + 'name': str(name), } - if url is not EmptyEmbed: + if url is not None: self._author['url'] = str(url) - if icon_url is not EmptyEmbed: + if icon_url is not None: self._author['icon_url'] = str(icon_url) return self + def remove_author(self) -> Self: + """Clears embed's author information. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionadded:: 1.4 + """ + try: + del self._author + except AttributeError: + pass + + return self + @property - def fields(self): - """Returns a :class:`list` of ``EmbedProxy`` denoting the field contents. + def fields(self) -> List[_EmbedFieldProxy]: + """List[``EmbedProxy``]: Returns a :class:`list` of ``EmbedProxy`` denoting the field contents. See :meth:`add_field` for possible values you can access. - If the attribute has no value then :attr:`Empty` is returned. + If the attribute has no value then ``None`` is returned. """ - return [EmbedProxy(d) for d in getattr(self, '_fields', [])] + # Lying to the type checker for better developer UX. + return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore - def add_field(self, *, name, value, inline=True): + def add_field(self, *, name: Any, value: Any, inline: bool = True) -> Self: """Adds a field to the embed object. This function returns the class instance to allow for fluent-style - chaining. + chaining. Can only be up to 25 fields. Parameters ----------- name: :class:`str` - The name of the field. + The name of the field. Can only be up to 256 characters. value: :class:`str` - The value of the field. + The value of the field. Can only be up to 1024 characters. inline: :class:`bool` Whether the field should be displayed inline. """ @@ -412,7 +612,7 @@ def add_field(self, *, name, value, inline=True): field = { 'inline': inline, 'name': str(name), - 'value': str(value) + 'value': str(value), } try: @@ -421,23 +621,23 @@ def add_field(self, *, name, value, inline=True): self._fields = [field] return self - - def insert_field_at(self, index, *, name, value, inline=True): + + def insert_field_at(self, index: int, *, name: Any, value: Any, inline: bool = True) -> Self: """Inserts a field before a specified index to the embed. - + This function returns the class instance to allow for fluent-style - chaining. - - .. versionadded:: 1.2.0 - + chaining. Can only be up to 25 fields. + + .. versionadded:: 1.2 + Parameters ----------- index: :class:`int` The index of where to insert the field. name: :class:`str` - The name of the field. + The name of the field. Can only be up to 256 characters. value: :class:`str` - The value of the field. + The value of the field. Can only be up to 1024 characters. inline: :class:`bool` Whether the field should be displayed inline. """ @@ -445,7 +645,7 @@ def insert_field_at(self, index, *, name, value, inline=True): field = { 'inline': inline, 'name': str(name), - 'value': str(value) + 'value': str(value), } try: @@ -455,24 +655,39 @@ def insert_field_at(self, index, *, name, value, inline=True): return self - def clear_fields(self): - """Removes all fields from this embed.""" + def clear_fields(self) -> Self: + """Removes all fields from this embed. + + This function returns the class instance to allow for fluent-style + chaining. + + .. versionchanged:: 2.0 + This function now returns the class instance. + """ try: self._fields.clear() except AttributeError: self._fields = [] - def remove_field(self, index): + return self + + def remove_field(self, index: int) -> Self: """Removes a field at a specified index. If the index is invalid or out of bounds then the error is silently swallowed. + This function returns the class instance to allow for fluent-style + chaining. + .. note:: When deleting a field by index, the index of the other fields shift to fill the gap just like a regular list. + .. versionchanged:: 2.0 + This function now returns the class instance. + Parameters ----------- index: :class:`int` @@ -483,10 +698,12 @@ def remove_field(self, index): except (AttributeError, IndexError): pass - def set_field_at(self, index, *, name, value, inline=True): + return self + + def set_field_at(self, index: int, *, name: Any, value: Any, inline: bool = True) -> Self: """Modifies a field to the embed object. - The index must point to a valid pre-existing field. + The index must point to a valid pre-existing field. Can only be up to 25 fields. This function returns the class instance to allow for fluent-style chaining. @@ -496,9 +713,9 @@ def set_field_at(self, index, *, name, value, inline=True): index: :class:`int` The index of the field to modify. name: :class:`str` - The name of the field. + The name of the field. Can only be up to 256 characters. value: :class:`str` - The value of the field. + The value of the field. Can only be up to 1024 characters. inline: :class:`bool` Whether the field should be displayed inline. @@ -518,15 +735,17 @@ def set_field_at(self, index, *, name, value, inline=True): field['inline'] = inline return self - def to_dict(self): + def to_dict(self) -> EmbedData: """Converts this embed object into a dict.""" # add in the raw data into the dict + # fmt: off result = { key[1:]: getattr(self, key) - for key in self.__slots__ + for key in Embed.__slots__ if key[0] == '_' and hasattr(self, key) } + # fmt: on # deal with basic convenience wrappers @@ -562,4 +781,4 @@ def to_dict(self): if self.title: result['title'] = self.title - return result + return result # type: ignore # This payload is equivalent to the EmbedData type diff --git a/discord/emoji.py b/discord/emoji.py index 5e864b2816d0..efea38c756b7 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,108 +22,32 @@ DEALINGS IN THE SOFTWARE. """ -from .asset import Asset -from . import utils -from .user import User - -class PartialEmoji: - """Represents a "partial" emoji. - - This model will be given in two scenarios: - - - "Raw" data events such as :func:`on_raw_reaction_add` - - Custom emoji that the bot cannot see from e.g. :attr:`Message.reactions` - - .. container:: operations - - .. describe:: x == y - - Checks if two emoji are the same. - - .. describe:: x != y - - Checks if two emoji are not the same. - - .. describe:: hash(x) - - Return the emoji's hash. - - .. describe:: str(x) - - Returns the emoji rendered for discord. - - Attributes - ----------- - name: :class:`str` - The custom emoji name, if applicable, or the unicode codepoint - of the non-custom emoji. - animated: :class:`bool` - Whether the emoji is animated or not. - id: Optional[:class:`int`] - The ID of the custom emoji, if applicable. - """ - - __slots__ = ('animated', 'name', 'id', '_state') - - def __init__(self, *, animated, name, id=None): - self.animated = animated - self.name = name - self.id = id - self._state = None +from __future__ import annotations +from typing import Any, Collection, Iterator, List, Optional, TYPE_CHECKING, Tuple - @classmethod - def with_state(cls, state, *, animated, name, id=None): - self = cls(animated=animated, name=name, id=id) - self._state = state - return self - - def __str__(self): - if self.id is None: - return self.name - if self.animated: - return '' % (self.name, self.id) - return '<:%s:%s>' % (self.name, self.id) - - def __repr__(self): - return '<{0.__class__.__name__} animated={0.animated} name={0.name!r} id={0.id}>'.format(self) - - def __eq__(self, other): - if self.is_unicode_emoji(): - return isinstance(other, PartialEmoji) and self.name == other.name - - if isinstance(other, (PartialEmoji, Emoji)): - return self.id == other.id - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash((self.id, self.name)) - - def is_custom_emoji(self): - """Checks if this is a custom non-Unicode emoji.""" - return self.id is not None - - def is_unicode_emoji(self): - """Checks if this is a Unicode emoji.""" - return self.id is None +from .asset import Asset, AssetMixin +from .utils import SnowflakeList, snowflake_time, MISSING +from .partial_emoji import _EmojiTag, PartialEmoji +from .user import User +from .errors import MissingApplicationID +from .object import Object - def _as_reaction(self): - if self.id is None: - return self.name - return '%s:%s' % (self.name, self.id) +# fmt: off +__all__ = ( + 'Emoji', +) +# fmt: on - @property - def url(self): - """:class:`Asset`:Returns an asset of the emoji, if it is custom.""" - if self.is_unicode_emoji(): - return Asset(self._state) +if TYPE_CHECKING: + from .types.emoji import Emoji as EmojiPayload + from .guild import Guild + from .state import ConnectionState + from .abc import Snowflake + from .role import Role + from datetime import datetime - _format = 'gif' if self.animated else 'png' - url = "https://cdn.discordapp.com/emojis/{0.id}.{1}".format(self, _format) - return Asset(self._state, url) -class Emoji: +class Emoji(_EmojiTag, AssetMixin): """Represents a custom emoji. Depending on the way this object was created, some of the attributes can @@ -172,68 +94,82 @@ class Emoji: Whether the emoji is available for use. user: Optional[:class:`User`] The user that created the emoji. This can only be retrieved using :meth:`Guild.fetch_emoji` and - having the :attr:`~Permissions.manage_emojis` permission. + having :attr:`~Permissions.manage_emojis`. + + Or if :meth:`.is_application_owned` is ``True``, this is the team member that uploaded + the emoji, or the bot user if it was uploaded using the API and this can + only be retrieved using :meth:`~discord.Client.fetch_application_emoji` or :meth:`~discord.Client.fetch_application_emojis`. """ - __slots__ = ('require_colons', 'animated', 'managed', 'id', 'name', '_roles', 'guild_id', - '_state', 'user', 'available') - def __init__(self, *, guild, state, data): - self.guild_id = guild.id - self._state = state + __slots__: Tuple[str, ...] = ( + 'require_colons', + 'animated', + 'managed', + 'id', + 'name', + '_roles', + 'guild_id', + '_state', + 'user', + 'available', + ) + + def __init__(self, *, guild: Snowflake, state: ConnectionState, data: EmojiPayload) -> None: + self.guild_id: int = guild.id + self._state: ConnectionState = state self._from_data(data) - def _from_data(self, emoji): - self.require_colons = emoji['require_colons'] - self.managed = emoji['managed'] - self.id = int(emoji['id']) - self.name = emoji['name'] - self.animated = emoji.get('animated', False) - self.available = emoji.get('available', True) - self._roles = utils.SnowflakeList(map(int, emoji.get('roles', []))) + def _from_data(self, emoji: EmojiPayload) -> None: + self.require_colons: bool = emoji.get('require_colons', False) + self.managed: bool = emoji.get('managed', False) + self.id: int = int(emoji['id']) # type: ignore # This won't be None for full emoji objects. + self.name: str = emoji['name'] # type: ignore # This won't be None for full emoji objects. + self.animated: bool = emoji.get('animated', False) + self.available: bool = emoji.get('available', True) + self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', []))) user = emoji.get('user') - self.user = User(state=self._state, data=user) if user else None + self.user: Optional[User] = User(state=self._state, data=user) if user else None + + def _to_partial(self) -> PartialEmoji: + return PartialEmoji(name=self.name, animated=self.animated, id=self.id) - def _iterator(self): + def __iter__(self) -> Iterator[Tuple[str, Any]]: for attr in self.__slots__: if attr[0] != '_': value = getattr(self, attr, None) if value is not None: yield (attr, value) - def __iter__(self): - return self._iterator() - - def __str__(self): + def __str__(self) -> str: if self.animated: - return ''.format(self) - return "<:{0.name}:{0.id}>".format(self) + return f'' + return f'<:{self.name}:{self.id}>' - def __repr__(self): - return ''.format(self) + def __repr__(self) -> str: + return f'' - def __eq__(self, other): - return isinstance(other, (PartialEmoji, Emoji)) and self.id == other.id + def __eq__(self, other: object) -> bool: + return isinstance(other, _EmojiTag) and self.id == other.id - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return self.id >> 22 @property - def created_at(self): + def created_at(self) -> datetime: """:class:`datetime.datetime`: Returns the emoji's creation time in UTC.""" - return utils.snowflake_time(self.id) + return snowflake_time(self.id) @property - def url(self): - """:class:`Asset`: Returns the asset of the emoji.""" - _format = 'gif' if self.animated else 'png' - url = "https://cdn.discordapp.com/emojis/{0.id}.{1}".format(self, _format) - return Asset(self._state, url) + def url(self) -> str: + """:class:`str`: Returns the URL of the emoji.""" + end = 'webp?animated=true' if self.animated else 'png' + return f'{Asset.BASE}/emojis/{self.id}.{end}' @property - def roles(self): + def roles(self) -> List[Role]: """List[:class:`Role`]: A :class:`list` of roles that is allowed to use this emoji. If roles is empty, the emoji is unrestricted. @@ -245,58 +181,122 @@ def roles(self): return [role for role in guild.roles if self._roles.has(role.id)] @property - def guild(self): + def guild(self) -> Optional[Guild]: """:class:`Guild`: The guild this emoji belongs to.""" return self._state._get_guild(self.guild_id) - async def delete(self, *, reason=None): + def is_usable(self) -> bool: + """:class:`bool`: Whether the bot can use this emoji. + + .. versionadded:: 1.3 + """ + if not self.available or not self.guild or self.guild.unavailable: + return False + if not self._roles: + return True + emoji_roles, my_roles = self._roles, self.guild.me._roles + return any(my_roles.has(role_id) for role_id in emoji_roles) + + async def delete(self, *, reason: Optional[str] = None) -> None: """|coro| Deletes the custom emoji. - You must have :attr:`~Permissions.manage_emojis` permission to - do this. + You must have :attr:`~Permissions.manage_emojis` to do this if + :meth:`.is_application_owned` is ``False``. Parameters ----------- reason: Optional[:class:`str`] The reason for deleting this emoji. Shows up on the audit log. + This does not apply if :meth:`.is_application_owned` is ``True``. + Raises ------- Forbidden You are not allowed to delete emojis. HTTPException An error occurred deleting the emoji. + MissingApplicationID + The emoji is owned by an application but the application ID is missing. """ + if self.is_application_owned(): + application_id = self._state.application_id + if application_id is None: + raise MissingApplicationID + + await self._state.http.delete_application_emoji(application_id, self.id) + return - await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) + await self._state.http.delete_custom_emoji(self.guild_id, self.id, reason=reason) - async def edit(self, *, name, roles=None, reason=None): + async def edit( + self, *, name: str = MISSING, roles: Collection[Snowflake] = MISSING, reason: Optional[str] = None + ) -> Emoji: r"""|coro| Edits the custom emoji. - You must have :attr:`~Permissions.manage_emojis` permission to - do this. + You must have :attr:`~Permissions.manage_emojis` to do this. + + .. versionchanged:: 2.0 + The newly updated emoji is returned. Parameters ----------- name: :class:`str` The new emoji name. - roles: Optional[list[:class:`Role`]] - A :class:`list` of :class:`Role`\s that can use this emoji. Leave empty to make it available to everyone. + roles: List[:class:`~discord.abc.Snowflake`] + A list of roles that can use this emoji. An empty list can be passed to make it available to everyone. + + This does not apply if :meth:`.is_application_owned` is ``True``. + reason: Optional[:class:`str`] The reason for editing this emoji. Shows up on the audit log. + This does not apply if :meth:`.is_application_owned` is ``True``. + Raises ------- Forbidden You are not allowed to edit emojis. HTTPException An error occurred editing the emoji. + MissingApplicationID + The emoji is owned by an application but the application ID is missing + + Returns + -------- + :class:`Emoji` + The newly updated emoji. """ - if roles: - roles = [role.id for role in roles] - await self._state.http.edit_custom_emoji(self.guild.id, self.id, name=name, roles=roles, reason=reason) + payload = {} + if name is not MISSING: + payload['name'] = name + if roles is not MISSING: + payload['roles'] = [role.id for role in roles] + + if self.is_application_owned(): + application_id = self._state.application_id + if application_id is None: + raise MissingApplicationID + + payload.pop('roles', None) + data = await self._state.http.edit_application_emoji( + application_id, + self.id, + payload=payload, + ) + return Emoji(guild=Object(0), data=data, state=self._state) + + data = await self._state.http.edit_custom_emoji(self.guild_id, self.id, payload=payload, reason=reason) + return Emoji(guild=self.guild, data=data, state=self._state) # type: ignore # if guild is None, the http request would have failed + + def is_application_owned(self) -> bool: + """:class:`bool`: Whether the emoji is owned by an application. + + .. versionadded:: 2.5 + """ + return self.guild_id == 0 diff --git a/discord/enums.py b/discord/enums.py index 5fec3a02a651..260222894f5d 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,48 +22,112 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import types from collections import namedtuple +from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Iterator, Mapping __all__ = ( 'Enum', 'ChannelType', 'MessageType', - 'VoiceRegion', 'SpeakingState', 'VerificationLevel', 'ContentFilter', 'Status', 'DefaultAvatar', - 'RelationshipType', 'AuditLogAction', 'AuditLogActionCategory', 'UserFlags', 'ActivityType', - 'HypeSquadHouse', 'NotificationLevel', - 'PremiumType', - 'UserContentFilter', - 'FriendFlags', - 'Theme', + 'TeamMembershipState', + 'TeamMemberRole', + 'WebhookType', + 'ExpireBehaviour', + 'ExpireBehavior', + 'StickerType', + 'StickerFormatType', + 'InviteTarget', + 'VideoQualityMode', + 'ComponentType', + 'ButtonStyle', + 'TextStyle', + 'PrivacyLevel', + 'InteractionType', + 'InteractionResponseType', + 'NSFWLevel', + 'MFALevel', + 'Locale', + 'EntityType', + 'EventStatus', + 'AppCommandType', + 'AppCommandOptionType', + 'AppCommandPermissionType', + 'AutoModRuleTriggerType', + 'AutoModRuleEventType', + 'AutoModRuleActionType', + 'ForumLayoutType', + 'ForumOrderType', + 'SelectDefaultValueType', + 'SKUType', + 'EntitlementType', + 'EntitlementOwnerType', + 'PollLayoutType', + 'InviteType', + 'ReactionType', + 'VoiceChannelEffectAnimationType', + 'SubscriptionStatus', + 'MessageReferenceType', + 'StatusDisplayType', + 'OnboardingPromptType', + 'OnboardingMode', + 'SeparatorSpacing', + 'MediaItemLoadingState', + 'CollectibleType', + 'NameplatePalette', ) -def _create_value_cls(name): + +def _create_value_cls(name: str, comparable: bool): + # All the type ignores here are due to the type checker being unable to recognise + # Runtime type creation without exploding. cls = namedtuple('_EnumValue_' + name, 'name value') - cls.__repr__ = lambda self: '<%s.%s: %r>' % (name, self.name, self.value) - cls.__str__ = lambda self: '%s.%s' % (name, self.name) + cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>' + cls.__str__ = lambda self: f'{name}.{self.name}' + if comparable: + cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value + cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value + cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value + cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value return cls + def _is_descriptor(obj): return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__') + class EnumMeta(type): - def __new__(cls, name, bases, attrs): + if TYPE_CHECKING: + __name__: ClassVar[str] + _enum_member_names_: ClassVar[List[str]] + _enum_member_map_: ClassVar[Dict[str, Any]] + _enum_value_map_: ClassVar[Dict[Any, Any]] + + def __new__( + cls, + name: str, + bases: Tuple[type, ...], + attrs: Dict[str, Any], + *, + comparable: bool = False, + ) -> EnumMeta: value_mapping = {} member_mapping = {} member_names = [] - value_cls = _create_value_cls(name) + value_cls = _create_value_cls(name, comparable) for key, value in list(attrs.items()): is_descriptor = _is_descriptor(value) if key[0] == '_' and not is_descriptor: @@ -93,42 +155,43 @@ def __new__(cls, name, bases, attrs): attrs['_enum_value_map_'] = value_mapping attrs['_enum_member_map_'] = member_mapping attrs['_enum_member_names_'] = member_names + attrs['_enum_value_cls_'] = value_cls actual_cls = super().__new__(cls, name, bases, attrs) - value_cls._actual_enum_cls_ = actual_cls + value_cls._actual_enum_cls_ = actual_cls # type: ignore # Runtime attribute isn't understood return actual_cls - def __iter__(cls): + def __iter__(cls) -> Iterator[Any]: return (cls._enum_member_map_[name] for name in cls._enum_member_names_) - def __reversed__(cls): + def __reversed__(cls) -> Iterator[Any]: return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) - def __len__(cls): + def __len__(cls) -> int: return len(cls._enum_member_names_) - def __repr__(cls): - return '' % cls.__name__ + def __repr__(cls) -> str: + return f'' @property - def __members__(cls): + def __members__(cls) -> Mapping[str, Any]: return types.MappingProxyType(cls._enum_member_map_) - def __call__(cls, value): + def __call__(cls, value: str) -> Any: try: return cls._enum_value_map_[value] except (KeyError, TypeError): - raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + raise ValueError(f'{value!r} is not a valid {cls.__name__}') - def __getitem__(cls, key): + def __getitem__(cls, key: str) -> Any: return cls._enum_member_map_[key] - def __setattr__(cls, name, value): + def __setattr__(cls, name: str, value: Any) -> None: raise TypeError('Enums are immutable.') - def __delattr__(cls, attr): + def __delattr__(cls, attr: str) -> None: raise TypeError('Enums are immutable') - def __instancecheck__(self, instance): + def __instancecheck__(self, instance: Any) -> bool: # isinstance(x, Y) # -> __instancecheck__(Y, x) try: @@ -136,113 +199,129 @@ def __instancecheck__(self, instance): except AttributeError: return False -class Enum(metaclass=EnumMeta): - @classmethod - def try_value(cls, value): - try: - return cls._enum_value_map_[value] - except (KeyError, TypeError): - return value + +if TYPE_CHECKING: + from enum import Enum +else: + + class Enum(metaclass=EnumMeta): + @classmethod + def try_value(cls, value): + try: + return cls._enum_value_map_[value] + except (KeyError, TypeError): + return value class ChannelType(Enum): - text = 0 - private = 1 - voice = 2 - group = 3 + text = 0 + private = 1 + voice = 2 + group = 3 category = 4 - news = 5 - store = 6 - - def __str__(self): + news = 5 + news_thread = 10 + public_thread = 11 + private_thread = 12 + stage_voice = 13 + forum = 15 + media = 16 + + def __str__(self) -> str: return self.name + +class MessageReferenceType(Enum): + default = 0 + reply = 0 + forward = 1 + + class MessageType(Enum): - default = 0 - recipient_add = 1 - recipient_remove = 2 - call = 3 - channel_name_change = 4 - channel_icon_change = 5 - pins_add = 6 - new_member = 7 + default = 0 + recipient_add = 1 + recipient_remove = 2 + call = 3 + channel_name_change = 4 + channel_icon_change = 5 + pins_add = 6 + new_member = 7 premium_guild_subscription = 8 - premium_guild_tier_1 = 9 - premium_guild_tier_2 = 10 - premium_guild_tier_3 = 11 - -class VoiceRegion(Enum): - us_west = 'us-west' - us_east = 'us-east' - us_south = 'us-south' - us_central = 'us-central' - eu_west = 'eu-west' - eu_central = 'eu-central' - singapore = 'singapore' - london = 'london' - sydney = 'sydney' - amsterdam = 'amsterdam' - frankfurt = 'frankfurt' - brazil = 'brazil' - hongkong = 'hongkong' - russia = 'russia' - japan = 'japan' - southafrica = 'southafrica' - india = 'india' - vip_us_east = 'vip-us-east' - vip_us_west = 'vip-us-west' - vip_amsterdam = 'vip-amsterdam' - - def __str__(self): - return self.value + premium_guild_tier_1 = 9 + premium_guild_tier_2 = 10 + premium_guild_tier_3 = 11 + channel_follow_add = 12 + guild_stream = 13 + guild_discovery_disqualified = 14 + guild_discovery_requalified = 15 + guild_discovery_grace_period_initial_warning = 16 + guild_discovery_grace_period_final_warning = 17 + thread_created = 18 + reply = 19 + chat_input_command = 20 + thread_starter_message = 21 + guild_invite_reminder = 22 + context_menu_command = 23 + auto_moderation_action = 24 + role_subscription_purchase = 25 + interaction_premium_upsell = 26 + stage_start = 27 + stage_end = 28 + stage_speaker = 29 + stage_raise_hand = 30 + stage_topic = 31 + guild_application_premium_subscription = 32 + guild_incident_alert_mode_enabled = 36 + guild_incident_alert_mode_disabled = 37 + guild_incident_report_raid = 38 + guild_incident_report_false_alarm = 39 + purchase_notification = 44 + poll_result = 46 + emoji_added = 63 + + def is_deletable(self) -> bool: + return self not in { + MessageType.recipient_add, + MessageType.recipient_remove, + MessageType.call, + MessageType.channel_name_change, + MessageType.channel_icon_change, + MessageType.thread_starter_message, + } + class SpeakingState(Enum): - none = 0 - voice = 1 + none = 0 + voice = 1 soundshare = 2 - priority = 4 + priority = 4 - def __str__(self): + def __str__(self) -> str: return self.name - def __int__(self): + def __int__(self) -> int: return self.value -class VerificationLevel(Enum): - none = 0 - low = 1 - medium = 2 - high = 3 - table_flip = 3 - extreme = 4 - double_table_flip = 4 - def __str__(self): +class VerificationLevel(Enum, comparable=True): + none = 0 + low = 1 + medium = 2 + high = 3 + highest = 4 + + def __str__(self) -> str: return self.name -class ContentFilter(Enum): - disabled = 0 - no_role = 1 + +class ContentFilter(Enum, comparable=True): + disabled = 0 + no_role = 1 all_members = 2 - def __str__(self): + def __str__(self) -> str: return self.name -class UserContentFilter(Enum): - disabled = 0 - friends = 1 - all_messages = 2 - -class FriendFlags(Enum): - noone = 0 - mutual_guilds = 1 - mutual_friends = 2 - guild_and_friends = 3 - everyone = 4 - -class Theme(Enum): - light = 'light' - dark = 'dark' class Status(Enum): online = 'online' @@ -252,97 +331,182 @@ class Status(Enum): do_not_disturb = 'dnd' invisible = 'invisible' - def __str__(self): + def __str__(self) -> str: return self.value + class DefaultAvatar(Enum): blurple = 0 - grey = 1 - gray = 1 - green = 2 - orange = 3 - red = 4 - - def __str__(self): + grey = 1 + gray = 1 + green = 2 + orange = 3 + red = 4 + pink = 5 + + def __str__(self) -> str: return self.name -class RelationshipType(Enum): - friend = 1 - blocked = 2 - incoming_request = 3 - outgoing_request = 4 -class NotificationLevel(Enum): - all_messages = 0 +class NotificationLevel(Enum, comparable=True): + all_messages = 0 only_mentions = 1 + class AuditLogActionCategory(Enum): create = 1 delete = 2 update = 3 + class AuditLogAction(Enum): - guild_update = 1 - channel_create = 10 - channel_update = 11 - channel_delete = 12 - overwrite_create = 13 - overwrite_update = 14 - overwrite_delete = 15 - kick = 20 - member_prune = 21 - ban = 22 - unban = 23 - member_update = 24 - member_role_update = 25 - role_create = 30 - role_update = 31 - role_delete = 32 - invite_create = 40 - invite_update = 41 - invite_delete = 42 - webhook_create = 50 - webhook_update = 51 - webhook_delete = 52 - emoji_create = 60 - emoji_update = 61 - emoji_delete = 62 - message_delete = 72 + # fmt: off + guild_update = 1 + channel_create = 10 + channel_update = 11 + channel_delete = 12 + overwrite_create = 13 + overwrite_update = 14 + overwrite_delete = 15 + kick = 20 + member_prune = 21 + ban = 22 + unban = 23 + member_update = 24 + member_role_update = 25 + member_move = 26 + member_disconnect = 27 + bot_add = 28 + role_create = 30 + role_update = 31 + role_delete = 32 + invite_create = 40 + invite_update = 41 + invite_delete = 42 + webhook_create = 50 + webhook_update = 51 + webhook_delete = 52 + emoji_create = 60 + emoji_update = 61 + emoji_delete = 62 + message_delete = 72 + message_bulk_delete = 73 + message_pin = 74 + message_unpin = 75 + integration_create = 80 + integration_update = 81 + integration_delete = 82 + stage_instance_create = 83 + stage_instance_update = 84 + stage_instance_delete = 85 + sticker_create = 90 + sticker_update = 91 + sticker_delete = 92 + scheduled_event_create = 100 + scheduled_event_update = 101 + scheduled_event_delete = 102 + thread_create = 110 + thread_update = 111 + thread_delete = 112 + app_command_permission_update = 121 + soundboard_sound_create = 130 + soundboard_sound_update = 131 + soundboard_sound_delete = 132 + automod_rule_create = 140 + automod_rule_update = 141 + automod_rule_delete = 142 + automod_block_message = 143 + automod_flag_message = 144 + automod_timeout_member = 145 + automod_quarantine_user = 146 + creator_monetization_request_created = 150 + creator_monetization_terms_accepted = 151 + onboarding_prompt_create = 163 + onboarding_prompt_update = 164 + onboarding_prompt_delete = 165 + onboarding_create = 166 + onboarding_update = 167 + home_settings_create = 190 + home_settings_update = 191 + # fmt: on @property - def category(self): - lookup = { - AuditLogAction.guild_update: AuditLogActionCategory.update, - AuditLogAction.channel_create: AuditLogActionCategory.create, - AuditLogAction.channel_update: AuditLogActionCategory.update, - AuditLogAction.channel_delete: AuditLogActionCategory.delete, - AuditLogAction.overwrite_create: AuditLogActionCategory.create, - AuditLogAction.overwrite_update: AuditLogActionCategory.update, - AuditLogAction.overwrite_delete: AuditLogActionCategory.delete, - AuditLogAction.kick: None, - AuditLogAction.member_prune: None, - AuditLogAction.ban: None, - AuditLogAction.unban: None, - AuditLogAction.member_update: AuditLogActionCategory.update, - AuditLogAction.member_role_update: AuditLogActionCategory.update, - AuditLogAction.role_create: AuditLogActionCategory.create, - AuditLogAction.role_update: AuditLogActionCategory.update, - AuditLogAction.role_delete: AuditLogActionCategory.delete, - AuditLogAction.invite_create: AuditLogActionCategory.create, - AuditLogAction.invite_update: AuditLogActionCategory.update, - AuditLogAction.invite_delete: AuditLogActionCategory.delete, - AuditLogAction.webhook_create: AuditLogActionCategory.create, - AuditLogAction.webhook_update: AuditLogActionCategory.update, - AuditLogAction.webhook_delete: AuditLogActionCategory.delete, - AuditLogAction.emoji_create: AuditLogActionCategory.create, - AuditLogAction.emoji_update: AuditLogActionCategory.update, - AuditLogAction.emoji_delete: AuditLogActionCategory.delete, - AuditLogAction.message_delete: AuditLogActionCategory.delete, + def category(self) -> Optional[AuditLogActionCategory]: + # fmt: off + lookup: Dict[AuditLogAction, Optional[AuditLogActionCategory]] = { + AuditLogAction.guild_update: AuditLogActionCategory.update, + AuditLogAction.channel_create: AuditLogActionCategory.create, + AuditLogAction.channel_update: AuditLogActionCategory.update, + AuditLogAction.channel_delete: AuditLogActionCategory.delete, + AuditLogAction.overwrite_create: AuditLogActionCategory.create, + AuditLogAction.overwrite_update: AuditLogActionCategory.update, + AuditLogAction.overwrite_delete: AuditLogActionCategory.delete, + AuditLogAction.kick: None, + AuditLogAction.member_prune: None, + AuditLogAction.ban: None, + AuditLogAction.unban: None, + AuditLogAction.member_update: AuditLogActionCategory.update, + AuditLogAction.member_role_update: AuditLogActionCategory.update, + AuditLogAction.member_move: None, + AuditLogAction.member_disconnect: None, + AuditLogAction.bot_add: None, + AuditLogAction.role_create: AuditLogActionCategory.create, + AuditLogAction.role_update: AuditLogActionCategory.update, + AuditLogAction.role_delete: AuditLogActionCategory.delete, + AuditLogAction.invite_create: AuditLogActionCategory.create, + AuditLogAction.invite_update: AuditLogActionCategory.update, + AuditLogAction.invite_delete: AuditLogActionCategory.delete, + AuditLogAction.webhook_create: AuditLogActionCategory.create, + AuditLogAction.webhook_update: AuditLogActionCategory.update, + AuditLogAction.webhook_delete: AuditLogActionCategory.delete, + AuditLogAction.emoji_create: AuditLogActionCategory.create, + AuditLogAction.emoji_update: AuditLogActionCategory.update, + AuditLogAction.emoji_delete: AuditLogActionCategory.delete, + AuditLogAction.message_delete: AuditLogActionCategory.delete, + AuditLogAction.message_bulk_delete: AuditLogActionCategory.delete, + AuditLogAction.message_pin: None, + AuditLogAction.message_unpin: None, + AuditLogAction.integration_create: AuditLogActionCategory.create, + AuditLogAction.integration_update: AuditLogActionCategory.update, + AuditLogAction.integration_delete: AuditLogActionCategory.delete, + AuditLogAction.stage_instance_create: AuditLogActionCategory.create, + AuditLogAction.stage_instance_update: AuditLogActionCategory.update, + AuditLogAction.stage_instance_delete: AuditLogActionCategory.delete, + AuditLogAction.sticker_create: AuditLogActionCategory.create, + AuditLogAction.sticker_update: AuditLogActionCategory.update, + AuditLogAction.sticker_delete: AuditLogActionCategory.delete, + AuditLogAction.scheduled_event_create: AuditLogActionCategory.create, + AuditLogAction.scheduled_event_update: AuditLogActionCategory.update, + AuditLogAction.scheduled_event_delete: AuditLogActionCategory.delete, + AuditLogAction.thread_create: AuditLogActionCategory.create, + AuditLogAction.thread_delete: AuditLogActionCategory.delete, + AuditLogAction.thread_update: AuditLogActionCategory.update, + AuditLogAction.app_command_permission_update: AuditLogActionCategory.update, + AuditLogAction.automod_rule_create: AuditLogActionCategory.create, + AuditLogAction.automod_rule_update: AuditLogActionCategory.update, + AuditLogAction.automod_rule_delete: AuditLogActionCategory.delete, + AuditLogAction.automod_block_message: None, + AuditLogAction.automod_flag_message: None, + AuditLogAction.automod_timeout_member: None, + AuditLogAction.automod_quarantine_user: None, + AuditLogAction.creator_monetization_request_created: None, + AuditLogAction.creator_monetization_terms_accepted: None, + AuditLogAction.soundboard_sound_create: AuditLogActionCategory.create, + AuditLogAction.soundboard_sound_update: AuditLogActionCategory.update, + AuditLogAction.soundboard_sound_delete: AuditLogActionCategory.delete, + AuditLogAction.onboarding_prompt_create: AuditLogActionCategory.create, + AuditLogAction.onboarding_prompt_update: AuditLogActionCategory.update, + AuditLogAction.onboarding_prompt_delete: AuditLogActionCategory.delete, + AuditLogAction.onboarding_create: AuditLogActionCategory.create, + AuditLogAction.onboarding_update: AuditLogActionCategory.update, + AuditLogAction.home_settings_create: AuditLogActionCategory.create, + AuditLogAction.home_settings_update: AuditLogActionCategory.update, } - return lookup[self] + # fmt: on + return lookup.get(self, None) @property - def target_type(self): + def target_type(self) -> Optional[str]: v = self.value if v == -1: return 'all' @@ -360,18 +524,58 @@ def target_type(self): return 'webhook' elif v < 70: return 'emoji' + elif v == 73: + return 'channel' elif v < 80: return 'message' + elif v < 83: + return 'integration' + elif v < 90: + return 'stage_instance' + elif v < 93: + return 'sticker' + elif v < 103: + return 'guild_scheduled_event' + elif v < 113: + return 'thread' + elif v < 122: + return 'integration_or_app_command' + elif 139 < v < 143: + return 'auto_moderation' + elif v < 147: + return 'user' + elif v < 152: + return 'creator_monetization' + elif v < 166: + return 'onboarding_prompt' + elif v < 168: + return 'onboarding' + elif v < 192: + return 'home_settings' + class UserFlags(Enum): staff = 1 partner = 2 hypesquad = 4 bug_hunter = 8 + mfa_sms = 16 + premium_promo_dismissed = 32 hypesquad_bravery = 64 hypesquad_brilliance = 128 hypesquad_balance = 256 early_supporter = 512 + team_user = 1024 + system = 4096 + has_unread_urgent_messages = 8192 + bug_hunter_level_2 = 16384 + verified_bot = 65536 + verified_bot_developer = 131072 + discord_certified_moderator = 262144 + bot_http_interactions = 524288 + spammer = 1048576 + active_developer = 4194304 + class ActivityType(Enum): unknown = -1 @@ -379,26 +583,438 @@ class ActivityType(Enum): streaming = 1 listening = 2 watching = 3 + custom = 4 + competing = 5 + + def __int__(self) -> int: + return self.value + - def __int__(self): +class TeamMembershipState(Enum): + invited = 1 + accepted = 2 + + +class TeamMemberRole(Enum): + admin = 'admin' + developer = 'developer' + read_only = 'read_only' + + +class WebhookType(Enum): + incoming = 1 + channel_follower = 2 + application = 3 + + +class ExpireBehaviour(Enum): + remove_role = 0 + kick = 1 + + +ExpireBehavior = ExpireBehaviour + + +class StickerType(Enum): + standard = 1 + guild = 2 + + +class StickerFormatType(Enum): + png = 1 + apng = 2 + lottie = 3 + gif = 4 + + @property + def file_extension(self) -> str: + # fmt: off + lookup: Dict[StickerFormatType, str] = { + StickerFormatType.png: 'png', + StickerFormatType.apng: 'png', + StickerFormatType.lottie: 'json', + StickerFormatType.gif: 'gif', + } + # fmt: on + return lookup.get(self, 'png') + + +class InviteTarget(Enum): + unknown = 0 + stream = 1 + embedded_application = 2 + + +class InteractionType(Enum): + ping = 1 + application_command = 2 + component = 3 + autocomplete = 4 + modal_submit = 5 + + +class InteractionResponseType(Enum): + pong = 1 + # ack = 2 (deprecated) + # channel_message = 3 (deprecated) + channel_message = 4 # (with source) + deferred_channel_message = 5 # (with source) + deferred_message_update = 6 # for components + message_update = 7 # for components + autocomplete_result = 8 + modal = 9 # for modals + # premium_required = 10 (deprecated) + launch_activity = 12 + + +class VideoQualityMode(Enum): + auto = 1 + full = 2 + + def __int__(self) -> int: + return self.value + + +class ComponentType(Enum): + action_row = 1 + button = 2 + select = 3 + string_select = 3 + text_input = 4 + user_select = 5 + role_select = 6 + mentionable_select = 7 + channel_select = 8 + section = 9 + text_display = 10 + thumbnail = 11 + media_gallery = 12 + file = 13 + separator = 14 + container = 17 + label = 18 + file_upload = 19 + + def __int__(self) -> int: + return self.value + + +class ButtonStyle(Enum): + primary = 1 + secondary = 2 + success = 3 + danger = 4 + link = 5 + premium = 6 + + # Aliases + blurple = 1 + grey = 2 + gray = 2 + green = 3 + red = 4 + url = 5 + + def __int__(self) -> int: return self.value -class HypeSquadHouse(Enum): - bravery = 1 - brilliance = 2 - balance = 3 -class PremiumType(Enum): - nitro_classic = 1 - nitro = 2 +class TextStyle(Enum): + short = 1 + paragraph = 2 + + # Aliases + long = 2 + + def __int__(self) -> int: + return self.value + + +class PrivacyLevel(Enum): + guild_only = 2 + + +class NSFWLevel(Enum, comparable=True): + default = 0 + explicit = 1 + safe = 2 + age_restricted = 3 + + +class MFALevel(Enum, comparable=True): + disabled = 0 + require_2fa = 1 + + +_UNICODE_LANG_MAP: Dict[str, str] = { + 'bg': 'bg-BG', + 'zh-CN': 'zh-CN', + 'zh-TW': 'zh-TW', + 'hr': 'hr-HR', + 'cs': 'cs-CZ', + 'da': 'da-DK', + 'nl': 'nl-NL', + 'en-US': 'en-US', + 'en-GB': 'en-GB', + 'fi': 'fi-FI', + 'fr': 'fr-FR', + 'de': 'de-DE', + 'el': 'el-GR', + 'hi': 'hi-IN', + 'hu': 'hu-HU', + 'id': 'id-ID', + 'it': 'it-IT', + 'ja': 'ja-JP', + 'ko': 'ko-KR', + 'lt': 'lt-LT', + 'no': 'no-NO', + 'pl': 'pl-PL', + 'pt-BR': 'pt-BR', + 'ro': 'ro-RO', + 'ru': 'ru-RU', + 'es-ES': 'es-ES', + 'es-419': 'es-419', + 'sv-SE': 'sv-SE', + 'th': 'th-TH', + 'tr': 'tr-TR', + 'uk': 'uk-UA', + 'vi': 'vi-VN', +} + + +class Locale(Enum): + american_english = 'en-US' + british_english = 'en-GB' + bulgarian = 'bg' + chinese = 'zh-CN' + taiwan_chinese = 'zh-TW' + croatian = 'hr' + czech = 'cs' + indonesian = 'id' + danish = 'da' + dutch = 'nl' + finnish = 'fi' + french = 'fr' + german = 'de' + greek = 'el' + hindi = 'hi' + hungarian = 'hu' + italian = 'it' + japanese = 'ja' + korean = 'ko' + latin_american_spanish = 'es-419' + lithuanian = 'lt' + norwegian = 'no' + polish = 'pl' + brazil_portuguese = 'pt-BR' + romanian = 'ro' + russian = 'ru' + spain_spanish = 'es-ES' + swedish = 'sv-SE' + thai = 'th' + turkish = 'tr' + ukrainian = 'uk' + vietnamese = 'vi' + + def __str__(self) -> str: + return self.value + + @property + def language_code(self) -> str: + return _UNICODE_LANG_MAP.get(self.value, self.value) + + +E = TypeVar('E', bound='Enum') + + +class EntityType(Enum): + stage_instance = 1 + voice = 2 + external = 3 + + +class EventStatus(Enum): + scheduled = 1 + active = 2 + completed = 3 + canceled = 4 + + ended = 3 + cancelled = 4 + + +class AppCommandOptionType(Enum): + subcommand = 1 + subcommand_group = 2 + string = 3 + integer = 4 + boolean = 5 + user = 6 + channel = 7 + role = 8 + mentionable = 9 + number = 10 + attachment = 11 + + +class AppCommandType(Enum): + chat_input = 1 + user = 2 + message = 3 + + +class AppCommandPermissionType(Enum): + role = 1 + user = 2 + channel = 3 + + +class AutoModRuleTriggerType(Enum): + keyword = 1 + harmful_link = 2 + spam = 3 + keyword_preset = 4 + mention_spam = 5 + member_profile = 6 + + +class AutoModRuleEventType(Enum): + message_send = 1 + member_update = 2 + + +class AutoModRuleActionType(Enum): + block_message = 1 + send_alert_message = 2 + timeout = 3 + block_member_interactions = 4 + + +class ForumLayoutType(Enum): + not_set = 0 + list_view = 1 + gallery_view = 2 + + +class ForumOrderType(Enum): + latest_activity = 0 + creation_date = 1 + + +class SelectDefaultValueType(Enum): + user = 'user' + role = 'role' + channel = 'channel' + + +class SKUType(Enum): + durable = 2 + consumable = 3 + subscription = 5 + subscription_group = 6 + + +class EntitlementType(Enum): + purchase = 1 + premium_subscription = 2 + developer_gift = 3 + test_mode_purchase = 4 + free_purchase = 5 + user_gift = 6 + premium_purchase = 7 + application_subscription = 8 + + +class EntitlementOwnerType(Enum): + guild = 1 + user = 2 + + +class PollLayoutType(Enum): + default = 1 + + +class InviteType(Enum): + guild = 0 + group_dm = 1 + friend = 2 + + +class ReactionType(Enum): + normal = 0 + burst = 1 + + +class VoiceChannelEffectAnimationType(Enum): + premium = 0 + basic = 1 + + +class SubscriptionStatus(Enum): + active = 0 + ending = 1 + inactive = 2 + + +class StatusDisplayType(Enum): + name = 0 # pyright: ignore[reportAssignmentType] + state = 1 + details = 2 + + +class OnboardingPromptType(Enum): + multiple_choice = 0 + dropdown = 1 + + +class OnboardingMode(Enum): + default = 0 + advanced = 1 + + +class SeparatorSpacing(Enum): + small = 1 + large = 2 + + +class MediaItemLoadingState(Enum): + unknown = 0 + loading = 1 + loaded = 2 + not_found = 3 + + +class CollectibleType(Enum): + nameplate = 'nameplate' + + +class NameplatePalette(Enum): + crimson = 'crimson' + berry = 'berry' + sky = 'sky' + teal = 'teal' + forest = 'forest' + bubble_gum = 'bubble_gum' + violet = 'violet' + cobalt = 'cobalt' + clover = 'clover' + lemon = 'lemon' + white = 'white' + + +def create_unknown_value(cls: Type[E], val: Any) -> E: + value_cls = cls._enum_value_cls_ # type: ignore # This is narrowed below + name = f'unknown_{val}' + return value_cls(name=name, value=val) + -def try_enum(cls, val): +def try_enum(cls: Type[E], val: Any) -> E: """A function that tries to turn the value into enum ``cls``. - If it fails it returns the value instead. + If it fails it returns a proxy invalid value instead. """ try: - return cls._enum_value_map_[val] + return cls._enum_value_map_[val] # type: ignore # All errors are caught below except (KeyError, TypeError, AttributeError): - return val + return create_unknown_value(cls, val) diff --git a/discord/errors.py b/discord/errors.py index fd967a9ce0ad..c07a7ed152ff 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,42 +22,76 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations +from typing import Dict, List, Optional, TYPE_CHECKING, Any, Tuple, Union + +if TYPE_CHECKING: + from aiohttp import ClientResponse, ClientWebSocketResponse + from requests import Response + + _ResponseType = Union[ClientResponse, Response] + + from .interactions import Interaction + +__all__ = ( + 'DiscordException', + 'ClientException', + 'GatewayNotFound', + 'HTTPException', + 'RateLimited', + 'Forbidden', + 'NotFound', + 'DiscordServerError', + 'InvalidData', + 'LoginFailure', + 'ConnectionClosed', + 'PrivilegedIntentsRequired', + 'InteractionResponded', + 'MissingApplicationID', +) + +APP_ID_NOT_FOUND = ( + 'Client does not have an application_id set. Either the function was called before on_ready ' + 'was called or application_id was not passed to the Client constructor.' +) + + class DiscordException(Exception): """Base exception class for discord.py - Ideally speaking, this could be caught to handle any exceptions thrown from this library. + Ideally speaking, this could be caught to handle any exceptions raised from this library. """ + pass + class ClientException(DiscordException): - """Exception that's thrown when an operation in the :class:`Client` fails. + """Exception that's raised when an operation in the :class:`Client` fails. These are usually for exceptions that happened due to user input. """ - pass -class NoMoreItems(DiscordException): - """Exception that is thrown when an async iteration operation has no more - items.""" pass + class GatewayNotFound(DiscordException): - """An exception that is usually thrown when the gateway hub - for the :class:`Client` websocket is not found.""" + """An exception that is raised when the gateway for Discord could not be found""" + def __init__(self): message = 'The gateway to connect to discord was not found.' - super(GatewayNotFound, self).__init__(message) + super().__init__(message) + -def flatten_error_dict(d, key=''): - items = [] +def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]: + items: List[Tuple[str, str]] = [] for k, v in d.items(): new_key = key + '.' + k if key else k if isinstance(v, dict): try: - _errors = v['_errors'] + _errors: List[Dict[str, Any]] = v['_errors'] except KeyError: - items.extend(flatten_error_dict(v, new_key).items()) + items.extend(_flatten_error_dict(v, new_key).items()) else: items.append((new_key, ' '.join(x.get('message', '') for x in _errors))) else: @@ -67,8 +99,9 @@ def flatten_error_dict(d, key=''): return dict(items) + class HTTPException(DiscordException): - """Exception that's thrown when an HTTP request operation fails. + """Exception that's raised when an HTTP request operation fails. Attributes ------------ @@ -85,69 +118,105 @@ class HTTPException(DiscordException): The Discord specific error code for the failure. """ - def __init__(self, response, message): - self.response = response - self.status = response.status + def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]): + self.response: _ResponseType = response + self.status: int = response.status # type: ignore # This attribute is filled by the library even if using requests + self.code: int + self.text: str if isinstance(message, dict): self.code = message.get('code', 0) base = message.get('message', '') errors = message.get('errors') + self._errors: Optional[Dict[str, Any]] = errors if errors: - errors = flatten_error_dict(errors) + errors = _flatten_error_dict(errors) helpful = '\n'.join('In %s: %s' % t for t in errors.items()) self.text = base + '\n' + helpful else: self.text = base else: - self.text = message + self.text = message or '' self.code = 0 fmt = '{0.status} {0.reason} (error code: {1})' if len(self.text): - fmt = fmt + ': {2}' + fmt += ': {2}' super().__init__(fmt.format(self.response, self.code, self.text)) + +class RateLimited(DiscordException): + """Exception that's raised for when status code 429 occurs + and the timeout is greater than the configured maximum using + the ``max_ratelimit_timeout`` parameter in :class:`Client`. + + This is not raised during global ratelimits. + + Since sometimes requests are halted pre-emptively before they're + even made, this **does not** subclass :exc:`HTTPException`. + + .. versionadded:: 2.0 + + Attributes + ------------ + retry_after: :class:`float` + The amount of seconds that the client should wait before retrying + the request. + """ + + def __init__(self, retry_after: float): + self.retry_after = retry_after + super().__init__(f'Too many requests. Retry in {retry_after:.2f} seconds.') + + class Forbidden(HTTPException): - """Exception that's thrown for when status code 403 occurs. + """Exception that's raised for when status code 403 occurs. Subclass of :exc:`HTTPException` """ + pass + class NotFound(HTTPException): - """Exception that's thrown for when status code 404 occurs. + """Exception that's raised for when status code 404 occurs. Subclass of :exc:`HTTPException` """ + pass -class InvalidData(ClientException): - """Exception that's raised when the library encounters unknown - or invalid data from Discord. +class DiscordServerError(HTTPException): + """Exception that's raised for when a 500 range status code occurs. + + Subclass of :exc:`HTTPException`. + + .. versionadded:: 1.5 """ + pass -class InvalidArgument(ClientException): - """Exception that's thrown when an argument to a function - is invalid some way (e.g. wrong value or wrong type). - This could be considered the analogous of ``ValueError`` and - ``TypeError`` except inherited from :exc:`ClientException` and thus - :exc:`DiscordException`. +class InvalidData(ClientException): + """Exception that's raised when the library encounters unknown + or invalid data from Discord. """ + pass + class LoginFailure(ClientException): - """Exception that's thrown when the :meth:`Client.login` function + """Exception that's raised when the :meth:`Client.login` function fails to log you in from improper credentials or some other misc. failure. """ + pass + class ConnectionClosed(ClientException): - """Exception that's thrown when the gateway connection is + """Exception that's raised when the gateway connection is closed for reasons that could not be handled internally. Attributes @@ -159,10 +228,78 @@ class ConnectionClosed(ClientException): shard_id: Optional[:class:`int`] The shard ID that got closed if applicable. """ - def __init__(self, original, *, shard_id): + + def __init__(self, socket: ClientWebSocketResponse, *, shard_id: Optional[int], code: Optional[int] = None): # This exception is just the same exception except # reconfigured to subclass ClientException for users - self.code = original.code - self.reason = original.reason - self.shard_id = shard_id - super().__init__(str(original)) + self.code: int = code or socket.close_code or -1 + # aiohttp doesn't seem to consistently provide close reason + self.reason: str = '' + self.shard_id: Optional[int] = shard_id + super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}') + + +class PrivilegedIntentsRequired(ClientException): + """Exception that's raised when the gateway is requesting privileged intents + but they're not ticked in the developer page yet. + + Go to https://discord.com/developers/applications/ and enable the intents + that are required. Currently these are as follows: + + - :attr:`Intents.members` + - :attr:`Intents.presences` + - :attr:`Intents.message_content` + + Attributes + ----------- + shard_id: Optional[:class:`int`] + The shard ID that got closed if applicable. + """ + + def __init__(self, shard_id: Optional[int]): + self.shard_id: Optional[int] = shard_id + msg = ( + 'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the ' + 'developer portal. It is recommended to go to https://discord.com/developers/applications/ ' + "and explicitly enable the privileged intents within your application's page. If this is not " + 'possible, then consider disabling the privileged intents instead.' + ) + super().__init__(msg % shard_id) + + +class InteractionResponded(ClientException): + """Exception that's raised when sending another interaction response using + :class:`InteractionResponse` when one has already been done before. + + An interaction can only respond once. + + .. versionadded:: 2.0 + + Attributes + ----------- + interaction: :class:`Interaction` + The interaction that's already been responded to. + """ + + def __init__(self, interaction: Interaction): + self.interaction: Interaction = interaction + super().__init__('This interaction has already been responded to before') + + +class MissingApplicationID(ClientException): + """An exception raised when the client does not have an application ID set. + + An application ID is required for syncing application commands and various + other application tasks such as SKUs or application emojis. + + This inherits from :exc:`~discord.app_commands.AppCommandError` + and :class:`~discord.ClientException`. + + .. versionadded:: 2.0 + + .. versionchanged:: 2.5 + This is now exported to the ``discord`` namespace and now inherits from :class:`~discord.ClientException`. + """ + + def __init__(self, message: Optional[str] = None): + super().__init__(message or APP_ID_NOT_FOUND) diff --git a/discord/ext/commands/__init__.py b/discord/ext/commands/__init__.py index b14fd655953d..08dab54d3438 100644 --- a/discord/ext/commands/__init__.py +++ b/discord/ext/commands/__init__.py @@ -1,20 +1,21 @@ -# -*- coding: utf-8 -*- - """ discord.ext.commands ~~~~~~~~~~~~~~~~~~~~~ An extension module to facilitate creation of bot commands. -:copyright: (c) 2019 Rapptz +:copyright: (c) 2015-present Rapptz :license: MIT, see LICENSE for more details. """ -from .bot import Bot, AutoShardedBot, when_mentioned, when_mentioned_or -from .context import Context +from .bot import * +from .cog import * +from .context import * +from .converter import * +from .cooldowns import * from .core import * from .errors import * +from .flags import * from .help import * -from .converter import * -from .cooldowns import * -from .cog import * +from .parameters import * +from .hybrid import * diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index bd2544783891..d7801939c5d2 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,6 +22,48 @@ DEALINGS IN THE SOFTWARE. """ +from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, Protocol, TypeVar, Union, Tuple, Optional + + +T = TypeVar('T') + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from .bot import Bot, AutoShardedBot + from .context import Context + from .cog import Cog + from .errors import CommandError + + P = ParamSpec('P') + MaybeAwaitableFunc = Callable[P, 'MaybeAwaitable[T]'] +else: + P = TypeVar('P') + MaybeAwaitableFunc = Tuple[P, T] + +_Bot = Union['Bot', 'AutoShardedBot'] +Coro = Coroutine[Any, Any, T] +CoroFunc = Callable[..., Coro[Any]] +MaybeCoro = Union[T, Coro[T]] +MaybeAwaitable = Union[T, Awaitable[T]] + +CogT = TypeVar('CogT', bound='Optional[Cog]') +UserCheck = Callable[['ContextT'], MaybeCoro[bool]] +Hook = Union[Callable[['CogT', 'ContextT'], Coro[Any]], Callable[['ContextT'], Coro[Any]]] +Error = Union[Callable[['CogT', 'ContextT', 'CommandError'], Coro[Any]], Callable[['ContextT', 'CommandError'], Coro[Any]]] + +ContextT = TypeVar('ContextT', bound='Context[Any]') +BotT = TypeVar('BotT', bound=_Bot, covariant=True) + +ContextT_co = TypeVar('ContextT_co', bound='Context[Any]', covariant=True) + + +class Check(Protocol[ContextT_co]): # type: ignore # TypeVar is expected to be invariant + predicate: Callable[[ContextT_co], Coroutine[Any, Any, bool]] + + def __call__(self, coro_or_commands: T) -> T: ... + + # This is merely a tag type to avoid circular import issues. # Yes, this is a terrible solution but ultimately it is the only solution. class _BaseCommand: diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 66e245dc3f73..0bb4cf95f5ed 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,32 +22,108 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + + import asyncio import collections +import collections.abc import inspect -import importlib +import importlib.util import sys -import traceback -import re +import logging import types +from typing import ( + Any, + Callable, + Mapping, + List, + Dict, + TYPE_CHECKING, + Optional, + Sequence, + TypeVar, + Type, + Union, + Iterable, + Collection, + overload, +) import discord +from discord import app_commands +from discord.app_commands.tree import _retrieve_guild_ids +from discord.utils import MISSING, _is_submodule -from .core import GroupMixin, Command +from .core import GroupMixin from .view import StringView from .context import Context from . import errors from .help import HelpCommand, DefaultHelpCommand from .cog import Cog +from .hybrid import hybrid_command, hybrid_group, HybridCommand, HybridGroup + +if TYPE_CHECKING: + from typing_extensions import Self, Unpack + + import importlib.machinery + + from discord.message import Message + from discord.interactions import Interaction + from discord.abc import User, Snowflake + from ._types import ( + _Bot, + BotT, + UserCheck, + CoroFunc, + ContextT, + MaybeAwaitableFunc, + ) + from .core import Command + from .hybrid import CommandCallback, ContextT, P, _HybridCommandDecoratorKwargs, _HybridGroupDecoratorKwargs + from discord.client import _ClientOptions + from discord.shard import _AutoShardedClientOptions + + _Prefix = Union[Iterable[str], str] + _PrefixCallable = MaybeAwaitableFunc[[BotT, Message], _Prefix] + PrefixType = Union[_Prefix, _PrefixCallable[BotT]] + + class _BotOptions(_ClientOptions, total=False): + owner_id: Optional[int] + owner_ids: Optional[Collection[int]] + strip_after_prefix: bool + case_insensitive: bool + + class _AutoShardedBotOptions(_AutoShardedClientOptions, _BotOptions): ... + + +__all__ = ( + 'when_mentioned', + 'when_mentioned_or', + 'Bot', + 'AutoShardedBot', +) -def when_mentioned(bot, msg): +T = TypeVar('T') +CFT = TypeVar('CFT', bound='CoroFunc') + +_log = logging.getLogger(__name__) + + +def when_mentioned(bot: _Bot, msg: Message, /) -> List[str]: """A callable that implements a command prefix equivalent to being mentioned. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. + + .. versionchanged:: 2.0 + + ``bot`` and ``msg`` parameters are now positional-only. """ - return [bot.user.mention + ' ', '<@!%s> ' % bot.user.id] + # bot.user will never be None when this is called + return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore + -def when_mentioned_or(*prefixes): +def when_mentioned_or(*prefixes: str) -> Callable[[_Bot, Message], List[str]]: """A callable that implements when mentioned or other prefixes provided. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. @@ -78,6 +152,7 @@ async def get_prefix(bot, message): ---------- :func:`.when_mentioned` """ + def inner(bot, msg): r = list(prefixes) r = when_mentioned(bot, msg) + r @@ -85,34 +160,55 @@ def inner(bot, msg): return inner -def _is_submodule(parent, child): - return parent == child or child.startswith(parent + ".") class _DefaultRepr: def __repr__(self): return '' -_default = _DefaultRepr() - -class BotBase(GroupMixin): - def __init__(self, command_prefix, help_command=_default, description=None, **options): - super().__init__(**options) - self.command_prefix = command_prefix - self.extra_events = {} - self.__cogs = {} - self.__extensions = {} - self._checks = [] - self._check_once = [] - self._before_invoke = None - self._after_invoke = None - self._help_command = None - self.description = inspect.cleandoc(description) if description else '' - self.owner_id = options.get('owner_id') - - if options.pop('self_bot', False): - self._skip_check = lambda x, y: x != y - else: - self._skip_check = lambda x, y: x == y + +_default: Any = _DefaultRepr() + + +class BotBase(GroupMixin[None]): + def __init__( + self, + command_prefix: PrefixType[BotT], + *, + help_command: Optional[HelpCommand] = _default, + tree_cls: Type[app_commands.CommandTree[Any]] = app_commands.CommandTree, + description: Optional[str] = None, + allowed_contexts: app_commands.AppCommandContext = MISSING, + allowed_installs: app_commands.AppInstallationType = MISSING, + intents: discord.Intents, + **options: Unpack[_BotOptions], + ) -> None: + super().__init__(intents=intents, **options) + self.command_prefix: PrefixType[BotT] = command_prefix # type: ignore + self.extra_events: Dict[str, List[CoroFunc]] = {} + # Self doesn't have the ClientT bound, but since this is a mixin it technically does + self.__tree: app_commands.CommandTree[Self] = tree_cls(self) # type: ignore + if allowed_contexts is not MISSING: + self.__tree.allowed_contexts = allowed_contexts + if allowed_installs is not MISSING: + self.__tree.allowed_installs = allowed_installs + + self.__cogs: Dict[str, Cog] = {} + self.__extensions: Dict[str, types.ModuleType] = {} + self._checks: List[UserCheck] = [] + self._check_once: List[UserCheck] = [] + self._before_invoke: Optional[CoroFunc] = None + self._after_invoke: Optional[CoroFunc] = None + self._help_command: Optional[HelpCommand] = None + self.description: str = inspect.cleandoc(description) if description else '' + self.owner_id: Optional[int] = options.get('owner_id') + self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set()) + self.strip_after_prefix: bool = options.get('strip_after_prefix', False) + + if self.owner_id and self.owner_ids: + raise TypeError('Both owner_id and owner_ids are set.') + + if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): + raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__.__name__}') if help_command is _default: self.help_command = DefaultHelpCommand() @@ -121,54 +217,157 @@ def __init__(self, command_prefix, help_command=_default, description=None, **op # internal helpers - def dispatch(self, event_name, *args, **kwargs): - super().dispatch(event_name, *args, **kwargs) + async def _async_setup_hook(self) -> None: + # self/super() resolves to Client/AutoShardedClient + await super()._async_setup_hook() # type: ignore + prefix = self.command_prefix + + # This has to be here because for the default logging set up to capture + # the logging calls, they have to come after the `Client.run` call. + # The best place to do this is in an async init scenario + if not self.intents.message_content: # type: ignore + trigger_warning = ( + (callable(prefix) and prefix is not when_mentioned) + or isinstance(prefix, str) + or (isinstance(prefix, collections.abc.Iterable) and len(list(prefix)) >= 1) + ) + if trigger_warning: + _log.warning('Privileged message content intent is missing, commands may not work as expected.') + + def dispatch(self, event_name: str, /, *args: Any, **kwargs: Any) -> None: + # super() will resolve to Client + super().dispatch(event_name, *args, **kwargs) # type: ignore ev = 'on_' + event_name for event in self.extra_events.get(ev, []): - self._schedule_event(event, ev, *args, **kwargs) + self._schedule_event(event, ev, *args, **kwargs) # type: ignore - async def close(self): + @discord.utils.copy_doc(discord.Client.close) + async def close(self) -> None: for extension in tuple(self.__extensions): try: - self.unload_extension(extension) + await self.unload_extension(extension) except Exception: pass for cog in tuple(self.__cogs): try: - self.remove_cog(cog) + await self.remove_cog(cog) except Exception: pass - await super().close() + await super().close() # type: ignore + + # GroupMixin overrides + + @discord.utils.copy_doc(GroupMixin.add_command) + def add_command(self, command: Command[Any, ..., Any], /) -> None: + super().add_command(command) + if isinstance(command, (HybridCommand, HybridGroup)) and command.app_command: + # If a cog is also inheriting from app_commands.Group then it'll also + # add the hybrid commands as text commands, which would recursively add the + # hybrid commands as slash commands. This check just terminates that recursion + # from happening + if command.cog is None or not command.cog.__cog_is_app_commands_group__: + self.tree.add_command(command.app_command) + + @discord.utils.copy_doc(GroupMixin.remove_command) + def remove_command(self, name: str, /) -> Optional[Command[Any, ..., Any]]: + cmd: Optional[Command[Any, ..., Any]] = super().remove_command(name) + if isinstance(cmd, (HybridCommand, HybridGroup)) and cmd.app_command: + # See above + if cmd.cog is not None and cmd.cog.__cog_is_app_commands_group__: + return cmd + + guild_ids: Optional[List[int]] = cmd.app_command._guild_ids + if guild_ids is None: + self.__tree.remove_command(name) + else: + for guild_id in guild_ids: + self.__tree.remove_command(name, guild=discord.Object(id=guild_id)) + + return cmd - async def on_command_error(self, context, exception): + def hybrid_command( + self, + name: Union[str, app_commands.locale_str] = MISSING, + with_app_command: bool = True, + *args: Any, + **kwargs: Unpack[_HybridCommandDecoratorKwargs], # type: ignore # name, with_app_command + ) -> Callable[[CommandCallback[Any, ContextT, P, T]], HybridCommand[Any, P, T]]: + """A shortcut decorator that invokes :func:`~discord.ext.commands.hybrid_command` and adds it to + the internal command list via :meth:`add_command`. + + Returns + -------- + Callable[..., :class:`HybridCommand`] + A decorator that converts the provided method into a Command, adds it to the bot, then returns it. + """ + + def decorator(func: CommandCallback[Any, ContextT, P, T]): + kwargs.setdefault('parent', self) # type: ignore # parent is not for the user to set + result = hybrid_command(name=name, *args, with_app_command=with_app_command, **kwargs)(func) # type: ignore # name, with_app_command + self.add_command(result) + return result + + return decorator + + def hybrid_group( + self, + name: Union[str, app_commands.locale_str] = MISSING, + with_app_command: bool = True, + *args: Any, + **kwargs: Unpack[_HybridGroupDecoratorKwargs], # type: ignore # name, with_app_command + ) -> Callable[[CommandCallback[Any, ContextT, P, T]], HybridGroup[Any, P, T]]: + """A shortcut decorator that invokes :func:`~discord.ext.commands.hybrid_group` and adds it to + the internal command list via :meth:`add_command`. + + Returns + -------- + Callable[..., :class:`HybridGroup`] + A decorator that converts the provided method into a Group, adds it to the bot, then returns it. + """ + + def decorator(func: CommandCallback[Any, ContextT, P, T]): + kwargs.setdefault('parent', self) # type: ignore # parent is not for the user to set + result = hybrid_group(name=name, *args, with_app_command=with_app_command, **kwargs)(func) # type: ignore # name, with_app_command + self.add_command(result) + return result + + return decorator + + # Error handler + + async def on_command_error(self, context: Context[BotT], exception: errors.CommandError, /) -> None: """|coro| The default command error handler provided by the bot. - By default this prints to :data:`sys.stderr` however it could be + By default this logs to the library logger, however it could be overridden to have a different implementation. This only fires if you do not specify any listeners for command error. + + .. versionchanged:: 2.0 + + ``context`` and ``exception`` parameters are now positional-only. + Instead of writing to ``sys.stderr`` this now uses the library logger. """ if self.extra_events.get('on_command_error', None): return - if hasattr(context.command, 'on_error'): + command = context.command + if command and command.has_error_handler(): return cog = context.cog - if cog: - if Cog._get_overridden_method(cog.cog_command_error) is not None: - return + if cog and cog.has_error_handler(): + return - print('Ignoring exception in command {}:'.format(context.command), file=sys.stderr) - traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) + _log.error('Ignoring exception in command %s', command, exc_info=exception) # global check registration - def check(self, func): + def check(self, func: T, /) -> T: r"""A decorator that adds a global check to the bot. A global check is similar to a :func:`.check` that is applied @@ -192,23 +391,33 @@ def check(self, func): def check_commands(ctx): return ctx.command.qualified_name in allowed_commands + .. versionchanged:: 2.0 + + ``func`` parameter is now positional-only. """ - self.add_check(func) + # T was used instead of Check to ensure the type matches on return + self.add_check(func) # type: ignore return func - def add_check(self, func, *, call_once=False): + def add_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None: """Adds a global check to the bot. This is the non-decorator interface to :meth:`.check` and :meth:`.check_once`. + .. versionchanged:: 2.0 + + ``func`` parameter is now positional-only. + + .. seealso:: The :func:`~discord.ext.commands.check` decorator + Parameters ----------- func The function that was used as a global check. call_once: :class:`bool` If the function should only be called once per - :meth:`Command.invoke` call. + :meth:`.invoke` call. """ if call_once: @@ -216,12 +425,16 @@ def add_check(self, func, *, call_once=False): else: self._checks.append(func) - def remove_check(self, func, *, call_once=False): + def remove_check(self, func: UserCheck[ContextT], /, *, call_once: bool = False) -> None: """Removes a global check from the bot. This function is idempotent and will not raise an exception if the function is not in the global checks. + .. versionchanged:: 2.0 + + ``func`` parameter is now positional-only. + Parameters ----------- func @@ -237,17 +450,23 @@ def remove_check(self, func, *, call_once=False): except ValueError: pass - def check_once(self, func): + def check_once(self, func: CFT, /) -> CFT: r"""A decorator that adds a "call once" global check to the bot. Unlike regular global checks, this one is called only once - per :meth:`Command.invoke` call. + per :meth:`.invoke` call. Regular global checks are called whenever a command is called or :meth:`.Command.can_run` is called. This type of check bypasses that and ensures that it's called only once, even inside the default help command. + .. note:: + + When using this function the :class:`.Context` sent to a group subcommand + may only parse the parent command and not the subcommands due to it + being invoked once per :meth:`.Bot.invoke` call. + .. note:: This function can either be a regular function or a coroutine. @@ -265,25 +484,45 @@ def check_once(self, func): def whitelist(ctx): return ctx.message.author.id in my_whitelist + .. versionchanged:: 2.0 + + ``func`` parameter is now positional-only. + """ self.add_check(func, call_once=True) return func - async def can_run(self, ctx, *, call_once=False): + async def can_run(self, ctx: Context[BotT], /, *, call_once: bool = False) -> bool: data = self._check_once if call_once else self._checks if len(data) == 0: return True - return await discord.utils.async_all(f(ctx) for f in data) + return await discord.utils.async_all(f(ctx) for f in data) # type: ignore + + async def is_owner(self, user: User, /) -> bool: + """|coro| - async def is_owner(self, user): - """Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of + Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of this bot. If an :attr:`owner_id` is not set, it is fetched automatically through the use of :meth:`~.Bot.application_info`. + .. versionchanged:: 1.3 + The function also checks if the application is team-owned if + :attr:`owner_ids` is not set. + + .. versionchanged:: 2.0 + + ``user`` parameter is now positional-only. + + .. versionchanged:: 2.4 + + This function now respects the team member roles if the bot is team-owned. + In order to be considered an owner, they must be either an admin or + a developer. + Parameters ----------- user: :class:`.abc.User` @@ -295,13 +534,24 @@ async def is_owner(self, user): Whether the user is the owner. """ - if self.owner_id is None: - app = await self.application_info() - self.owner_id = owner_id = app.owner.id - return user.id == owner_id - return user.id == self.owner_id + if self.owner_id: + return user.id == self.owner_id + elif self.owner_ids: + return user.id in self.owner_ids + else: + app: discord.AppInfo = await self.application_info() # type: ignore + if app.team: + self.owner_ids = ids = { + m.id + for m in app.team.members + if m.role in (discord.TeamMemberRole.admin, discord.TeamMemberRole.developer) + } + return user.id in ids + else: + self.owner_id = owner_id = app.owner.id + return user.id == owner_id - def before_invoke(self, coro): + def before_invoke(self, coro: CFT, /) -> CFT: """A decorator that registers a coroutine as a pre-invoke hook. A pre-invoke hook is called directly before the command is @@ -317,6 +567,10 @@ def before_invoke(self, coro): without error. If any check or argument parsing procedures fail then the hooks are not called. + .. versionchanged:: 2.0 + + ``coro`` parameter is now positional-only. + Parameters ----------- coro: :ref:`coroutine ` @@ -333,7 +587,7 @@ def before_invoke(self, coro): self._before_invoke = coro return coro - def after_invoke(self, coro): + def after_invoke(self, coro: CFT, /) -> CFT: r"""A decorator that registers a coroutine as a post-invoke hook. A post-invoke hook is called directly after the command is @@ -350,6 +604,10 @@ def after_invoke(self, coro): callback raising an error (i.e. :exc:`.CommandInvokeError`\). This makes it ideal for clean-up scenarios. + .. versionchanged:: 2.0 + + ``coro`` parameter is now positional-only. + Parameters ----------- coro: :ref:`coroutine ` @@ -368,14 +626,18 @@ def after_invoke(self, coro): # listener registration - def add_listener(self, func, name=None): + def add_listener(self, func: CoroFunc, /, name: str = MISSING) -> None: """The non decorator alternative to :meth:`.listen`. + .. versionchanged:: 2.0 + + ``func`` parameter is now positional-only. + Parameters ----------- func: :ref:`coroutine ` The function to call. - name: Optional[:class:`str`] + name: :class:`str` The name of the event to listen for. Defaults to ``func.__name__``. Example @@ -390,7 +652,7 @@ async def my_message(message): pass bot.add_listener(my_message, 'on_message') """ - name = func.__name__ if name is None else name + name = func.__name__ if name is MISSING else name if not asyncio.iscoroutinefunction(func): raise TypeError('Listeners must be coroutines') @@ -400,9 +662,13 @@ async def my_message(message): pass else: self.extra_events[name] = [func] - def remove_listener(self, func, name=None): + def remove_listener(self, func: CoroFunc, /, name: str = MISSING) -> None: """Removes a listener from the pool of listeners. + .. versionchanged:: 2.0 + + ``func`` parameter is now positional-only. + Parameters ----------- func @@ -412,7 +678,7 @@ def remove_listener(self, func, name=None): ``func.__name__``. """ - name = func.__name__ if name is None else name + name = func.__name__ if name is MISSING else name if name in self.extra_events: try: @@ -420,7 +686,7 @@ def remove_listener(self, func, name=None): except ValueError: pass - def listen(self, name=None): + def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]: """A decorator that registers another function as an external event listener. Basically this allows you to listen to multiple events from different places e.g. such as :func:`.on_ready` @@ -450,7 +716,7 @@ async def my_message(message): The function being listened to is not a coroutine. """ - def decorator(func): + def decorator(func: CFT) -> CFT: self.add_listener(func, name) return func @@ -458,15 +724,64 @@ def decorator(func): # cogs - def add_cog(self, cog): - """Adds a "cog" to the bot. + async def add_cog( + self, + cog: Cog, + /, + *, + override: bool = False, + guild: Optional[Snowflake] = MISSING, + guilds: Sequence[Snowflake] = MISSING, + ) -> None: + """|coro| + + Adds a "cog" to the bot. A cog is a class that has its own event listeners and commands. + If the cog is a :class:`.app_commands.Group` then it is added to + the bot's :class:`~discord.app_commands.CommandTree` as well. + + .. note:: + + Exceptions raised inside a :class:`.Cog`'s :meth:`~.Cog.cog_load` method will be + propagated to the caller. + + .. versionchanged:: 2.0 + + :exc:`.ClientException` is raised when a cog with the same name + is already loaded. + + .. versionchanged:: 2.0 + + ``cog`` parameter is now positional-only. + + .. versionchanged:: 2.0 + + This method is now a :term:`coroutine`. + Parameters ----------- cog: :class:`.Cog` The cog to register to the bot. + override: :class:`bool` + If a previously loaded cog with the same name should be ejected + instead of raising an error. + + .. versionadded:: 2.0 + guild: Optional[:class:`~discord.abc.Snowflake`] + If the cog is an application command group, then this would be the + guild where the cog group would be added to. If not given then + it becomes a global command instead. + + .. versionadded:: 2.0 + guilds: List[:class:`~discord.abc.Snowflake`] + If the cog is an application command group, then this would be the + guilds where the cog group would be added to. If not given then + it becomes a global command instead. Cannot be mixed with + ``guild``. + + .. versionadded:: 2.0 Raises ------- @@ -474,40 +789,97 @@ def add_cog(self, cog): The cog does not inherit from :class:`.Cog`. CommandError An error happened during loading. + ClientException + A cog with the same name is already loaded. """ if not isinstance(cog, Cog): raise TypeError('cogs must derive from Cog') - cog = cog._inject(self) - self.__cogs[cog.__cog_name__] = cog + cog_name = cog.__cog_name__ + existing = self.__cogs.get(cog_name) + + if existing is not None: + if not override: + raise discord.ClientException(f'Cog named {cog_name!r} already loaded') + await self.remove_cog(cog_name, guild=guild, guilds=guilds) - def get_cog(self, name): + if cog.__cog_app_commands_group__: + self.__tree.add_command(cog.__cog_app_commands_group__, override=override, guild=guild, guilds=guilds) + + cog = await cog._inject(self, override=override, guild=guild, guilds=guilds) + self.__cogs[cog_name] = cog + + def get_cog(self, name: str, /) -> Optional[Cog]: """Gets the cog instance requested. If the cog is not found, ``None`` is returned instead. + .. versionchanged:: 2.0 + + ``name`` parameter is now positional-only. + Parameters ----------- name: :class:`str` The name of the cog you are requesting. This is equivalent to the name passed via keyword argument in class creation or the class name if unspecified. + + Returns + -------- + Optional[:class:`Cog`] + The cog that was requested. If not found, returns ``None``. """ return self.__cogs.get(name) - def remove_cog(self, name): - """Removes a cog from the bot. + async def remove_cog( + self, + name: str, + /, + *, + guild: Optional[Snowflake] = MISSING, + guilds: Sequence[Snowflake] = MISSING, + ) -> Optional[Cog]: + """|coro| + + Removes a cog from the bot and returns it. All registered commands and event listeners that the cog has registered will be removed as well. If no cog is found then this method has no effect. + .. versionchanged:: 2.0 + + ``name`` parameter is now positional-only. + + .. versionchanged:: 2.0 + + This method is now a :term:`coroutine`. + Parameters ----------- name: :class:`str` The name of the cog to remove. + guild: Optional[:class:`~discord.abc.Snowflake`] + If the cog is an application command group, then this would be the + guild where the cog group would be removed from. If not given then + a global command is removed instead instead. + + .. versionadded:: 2.0 + guilds: List[:class:`~discord.abc.Snowflake`] + If the cog is an application command group, then this would be the + guilds where the cog group would be removed from. If not given then + a global command is removed instead instead. Cannot be mixed with + ``guild``. + + .. versionadded:: 2.0 + + Returns + ------- + Optional[:class:`.Cog`] + The cog that was removed. ``None`` if not found. """ cog = self.__cogs.pop(name, None) @@ -517,21 +889,32 @@ def remove_cog(self, name): help_command = self._help_command if help_command and help_command.cog is cog: help_command.cog = None - cog._eject(self) + + guild_ids = _retrieve_guild_ids(cog, guild, guilds) + if cog.__cog_app_commands_group__: + if guild_ids is None: + self.__tree.remove_command(name) + else: + for guild_id in guild_ids: + self.__tree.remove_command(name, guild=discord.Object(guild_id)) + + await cog._eject(self, guild_ids=guild_ids) + + return cog @property - def cogs(self): + def cogs(self) -> Mapping[str, Cog]: """Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog.""" return types.MappingProxyType(self.__cogs) # extensions - def _remove_module_references(self, name): + async def _remove_module_references(self, name: str) -> None: # find all references to the module # remove the cogs registered from the module for cogname, cog in self.__cogs.copy().items(): if _is_submodule(name, cog.__module__): - self.remove_cog(cogname) + await self.remove_cog(cogname) # remove all the commands from the module for cmd in self.all_commands.copy().values(): @@ -550,14 +933,17 @@ def _remove_module_references(self, name): for index in reversed(remove): del event_list[index] - def _call_module_finalizers(self, lib, key): + # remove all relevant application commands from the tree + self.__tree._remove_with_module(name) + + async def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: try: func = getattr(lib, 'teardown') except AttributeError: pass else: try: - func(self) + await func(self) except Exception: pass finally: @@ -568,8 +954,16 @@ def _call_module_finalizers(self, lib, key): if _is_submodule(name, module): del sys.modules[module] - def _load_from_module_spec(self, lib, key): + async def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: # precondition: key not in self.__extensions + lib = importlib.util.module_from_spec(spec) + sys.modules[key] = lib + try: + spec.loader.exec_module(lib) # type: ignore + except Exception as e: + del sys.modules[key] + raise errors.ExtensionFailed(key, e) from e + try: setup = getattr(lib, 'setup') except AttributeError: @@ -577,16 +971,25 @@ def _load_from_module_spec(self, lib, key): raise errors.NoEntryPointError(key) try: - setup(self) + await setup(self) except Exception as e: - self._remove_module_references(lib.__name__) - self._call_module_finalizers(lib, key) + del sys.modules[key] + await self._remove_module_references(lib.__name__) + await self._call_module_finalizers(lib, key) raise errors.ExtensionFailed(key, e) from e else: self.__extensions[key] = lib - def load_extension(self, name): - """Loads an extension. + def _resolve_name(self, name: str, package: Optional[str]) -> str: + try: + return importlib.util.resolve_name(name, package) + except ImportError: + raise errors.ExtensionNotFound(name) + + async def load_extension(self, name: str, *, package: Optional[str] = None) -> None: + """|coro| + + Loads an extension. An extension is a python module that contains commands, cogs, or listeners. @@ -595,37 +998,51 @@ def load_extension(self, name): the entry point on what to do when the extension is loaded. This entry point must have a single argument, the ``bot``. + .. versionchanged:: 2.0 + + This method is now a :term:`coroutine`. + Parameters ------------ name: :class:`str` The extension name to load. It must be dot separated like regular Python imports if accessing a sub-module. e.g. ``foo.test`` if you want to import ``foo/test.py``. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when loading an extension using a relative path, e.g ``.foo.test``. + Defaults to ``None``. + + .. versionadded:: 1.7 Raises -------- ExtensionNotFound The extension could not be imported. + This is also raised if the name of the extension could not + be resolved using the provided ``package`` parameter. ExtensionAlreadyLoaded The extension is already loaded. NoEntryPointError The extension does not have a setup function. ExtensionFailed - The extension setup function had an execution error. + The extension or its setup function had an execution error. """ + name = self._resolve_name(name, package) if name in self.__extensions: raise errors.ExtensionAlreadyLoaded(name) - try: - lib = importlib.import_module(name) - except ImportError as e: - raise errors.ExtensionNotFound(name, e) from e - else: - self._load_from_module_spec(lib, name) + spec = importlib.util.find_spec(name) + if spec is None: + raise errors.ExtensionNotFound(name) + + await self._load_from_module_spec(spec, name) - def unload_extension(self, name): - """Unloads an extension. + async def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: + """|coro| + + Unloads an extension. When the extension is unloaded, all commands, listeners, and cogs are removed from the bot and the module is un-imported. @@ -635,28 +1052,44 @@ def unload_extension(self, name): parameter, the ``bot``, similar to ``setup`` from :meth:`~.Bot.load_extension`. + .. versionchanged:: 2.0 + + This method is now a :term:`coroutine`. + Parameters ------------ name: :class:`str` The extension name to unload. It must be dot separated like regular Python imports if accessing a sub-module. e.g. ``foo.test`` if you want to import ``foo/test.py``. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when unloading an extension using a relative path, e.g ``.foo.test``. + Defaults to ``None``. + + .. versionadded:: 1.7 Raises ------- + ExtensionNotFound + The name of the extension could not + be resolved using the provided ``package`` parameter. ExtensionNotLoaded The extension was not loaded. """ + name = self._resolve_name(name, package) lib = self.__extensions.get(name) if lib is None: raise errors.ExtensionNotLoaded(name) - self._remove_module_references(lib.__name__) - self._call_module_finalizers(lib, name) + await self._remove_module_references(lib.__name__) + await self._call_module_finalizers(lib, name) - def reload_extension(self, name): - """Atomically reloads an extension. + async def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: + """|coro| + + Atomically reloads an extension. This replaces the extension with the same extension, only refreshed. This is equivalent to a :meth:`unload_extension` followed by a :meth:`load_extension` @@ -669,6 +1102,12 @@ def reload_extension(self, name): The extension name to reload. It must be dot separated like regular Python imports if accessing a sub-module. e.g. ``foo.test`` if you want to import ``foo/test.py``. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when reloading an extension using a relative path, e.g ``.foo.test``. + Defaults to ``None``. + + .. versionadded:: 1.7 Raises ------- @@ -676,51 +1115,57 @@ def reload_extension(self, name): The extension was not loaded. ExtensionNotFound The extension could not be imported. + This is also raised if the name of the extension could not + be resolved using the provided ``package`` parameter. NoEntryPointError The extension does not have a setup function. ExtensionFailed The extension setup function had an execution error. """ + name = self._resolve_name(name, package) lib = self.__extensions.get(name) if lib is None: raise errors.ExtensionNotLoaded(name) # get the previous module states from sys modules + # fmt: off modules = { name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name) } + # fmt: on try: # Unload and then load the module... - self._remove_module_references(lib.__name__) - self._call_module_finalizers(lib, name) - self.load_extension(name) - except Exception as e: + await self._remove_module_references(lib.__name__) + await self._call_module_finalizers(lib, name) + await self.load_extension(name) + except Exception: # if the load failed, the remnants should have been # cleaned from the load_extension function call # so let's load it from our old compiled library. - self._load_from_module_spec(lib, name) + await lib.setup(self) + self.__extensions[name] = lib # revert sys.modules back to normal and raise back to caller sys.modules.update(modules) raise @property - def extensions(self): + def extensions(self) -> Mapping[str, types.ModuleType]: """Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension.""" return types.MappingProxyType(self.__extensions) # help command stuff @property - def help_command(self): + def help_command(self) -> Optional[HelpCommand]: return self._help_command @help_command.setter - def help_command(self, value): + def help_command(self, value: Optional[HelpCommand]) -> None: if value is not None: if not isinstance(value, HelpCommand): raise TypeError('help_command must be a subclass of HelpCommand') @@ -734,14 +1179,32 @@ def help_command(self, value): else: self._help_command = None + # application command interop + + # As mentioned above, this is a mixin so the Self type hint fails here. + # However, since the only classes that can use this are subclasses of Client + # anyway, then this is sound. + @property + def tree(self) -> app_commands.CommandTree[Self]: # type: ignore + """:class:`~discord.app_commands.CommandTree`: The command tree responsible for handling the application commands + in this bot. + + .. versionadded:: 2.0 + """ + return self.__tree + # command processing - async def get_prefix(self, message): + async def get_prefix(self, message: Message, /) -> Union[List[str], str]: """|coro| Retrieves the prefix the bot is listening to with the message as a context. + .. versionchanged:: 2.0 + + ``message`` parameter is now positional-only. + Parameters ----------- message: :class:`discord.Message` @@ -754,30 +1217,54 @@ async def get_prefix(self, message): listening for. """ prefix = ret = self.command_prefix + if callable(prefix): - ret = await discord.utils.maybe_coroutine(prefix, self, message) + # self will be a Bot or AutoShardedBot + ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore if not isinstance(ret, str): try: - ret = list(ret) + ret = list(ret) # type: ignore except TypeError: # It's possible that a generator raised this exception. Don't # replace it with our own error if that's the case. - if isinstance(ret, collections.Iterable): + if isinstance(ret, collections.abc.Iterable): raise - raise TypeError("command_prefix must be plain string, iterable of strings, or callable " - "returning either of these, not {}".format(ret.__class__.__name__)) - - if not ret: - raise ValueError("Iterable command_prefix must contain at least one prefix") + raise TypeError( + 'command_prefix must be plain string, iterable of strings, or callable ' + f'returning either of these, not {ret.__class__.__name__}' + ) return ret - async def get_context(self, message, *, cls=Context): + @overload + async def get_context( + self, + origin: Union[Message, Interaction], + /, + ) -> Context[Self]: # type: ignore + ... + + @overload + async def get_context( + self, + origin: Union[Message, Interaction], + /, + *, + cls: Type[ContextT], + ) -> ContextT: ... + + async def get_context( + self, + origin: Union[Message, Interaction], + /, + *, + cls: Type[ContextT] = MISSING, + ) -> Any: r"""|coro| - Returns the invocation context from the message. + Returns the invocation context from the message or interaction. This is a more low-level counter-part for :meth:`.process_commands` to allow users more fine grained control over the processing. @@ -787,10 +1274,20 @@ async def get_context(self, message, *, cls=Context): If the context is not valid then it is not a valid candidate to be invoked under :meth:`~.Bot.invoke`. + .. note:: + + In order for the custom context to be used inside an interaction-based + context (such as :class:`HybridCommand`) then this method must be + overridden to return that class. + + .. versionchanged:: 2.0 + + ``message`` parameter is now positional-only and renamed to ``origin``. + Parameters ----------- - message: :class:`discord.Message` - The message to get the invocation context from. + origin: Union[:class:`discord.Message`, :class:`discord.Interaction`] + The message or interaction to get the invocation context from. cls The factory class that will be used to create the context. By default, this is :class:`.Context`. Should a custom @@ -803,14 +1300,19 @@ class be provided, it must be similar enough to :class:`.Context`\'s The invocation context. The type of this can change via the ``cls`` parameter. """ + if cls is MISSING: + cls = Context # type: ignore - view = StringView(message.content) - ctx = cls(prefix=None, view=view, bot=self, message=message) + if isinstance(origin, discord.Interaction): + return await cls.from_interaction(origin) - if self._skip_check(message.author.id, self.user.id): + view = StringView(origin.content) + ctx = cls(prefix=None, view=view, bot=self, message=origin) + + if origin.author.id == self.user.id: # type: ignore return ctx - prefix = await self.get_prefix(message) + prefix = await self.get_prefix(origin) invoked_prefix = prefix if isinstance(prefix, str): @@ -820,37 +1322,48 @@ class be provided, it must be similar enough to :class:`.Context`\'s try: # if the context class' __init__ consumes something from the view this # will be wrong. That seems unreasonable though. - if message.content.startswith(tuple(prefix)): + if origin.content.startswith(tuple(prefix)): invoked_prefix = discord.utils.find(view.skip_string, prefix) else: return ctx except TypeError: if not isinstance(prefix, list): - raise TypeError("get_prefix must return either a string or a list of string, " - "not {}".format(prefix.__class__.__name__)) + raise TypeError( + f'get_prefix must return either a string or a list of string, not {prefix.__class__.__name__}' + ) # It's possible a bad command_prefix got us here. for value in prefix: if not isinstance(value, str): - raise TypeError("Iterable command_prefix or list returned from get_prefix must " - "contain only strings, not {}".format(value.__class__.__name__)) + raise TypeError( + 'Iterable command_prefix or list returned from get_prefix must ' + f'contain only strings, not {value.__class__.__name__}' + ) # Getting here shouldn't happen raise + if self.strip_after_prefix: + view.skip_ws() + invoker = view.get_word() ctx.invoked_with = invoker - ctx.prefix = invoked_prefix + # type-checker fails to narrow invoked_prefix type. + ctx.prefix = invoked_prefix # type: ignore ctx.command = self.all_commands.get(invoker) return ctx - async def invoke(self, ctx): + async def invoke(self, ctx: Context[BotT], /) -> None: """|coro| Invokes the command given under the invocation context and handles all the internal event dispatch mechanisms. + .. versionchanged:: 2.0 + + ``ctx`` parameter is now positional-only. + Parameters ----------- ctx: :class:`.Context` @@ -861,15 +1374,17 @@ async def invoke(self, ctx): try: if await self.can_run(ctx, call_once=True): await ctx.command.invoke(ctx) + else: + raise errors.CheckFailure('The global check once functions failed.') except errors.CommandError as exc: await ctx.command.dispatch_error(ctx, exc) else: self.dispatch('command_completion', ctx) elif ctx.invoked_with: - exc = errors.CommandNotFound('Command "{}" is not found'.format(ctx.invoked_with)) + exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') self.dispatch('command_error', ctx, exc) - async def process_commands(self, message): + async def process_commands(self, message: Message, /) -> None: """|coro| This function processes the commands that have been registered @@ -886,6 +1401,10 @@ async def process_commands(self, message): This also checks if the message's author is a bot and doesn't call :meth:`~.Bot.get_context` or :meth:`~.Bot.invoke` if so. + .. versionchanged:: 2.0 + + ``message`` parameter is now positional-only. + Parameters ----------- message: :class:`discord.Message` @@ -895,13 +1414,15 @@ async def process_commands(self, message): return ctx = await self.get_context(message) - await self.invoke(ctx) + # the type of the invocation context's bot attribute will be correct + await self.invoke(ctx) # type: ignore - async def on_message(self, message): + async def on_message(self, message: Message, /) -> None: await self.process_commands(message) + class Bot(BotBase, discord.Client): - """Represents a discord bot. + """Represents a Discord bot. This class is a subclass of :class:`discord.Client` and as a result anything that you can do with a :class:`discord.Client` you can do with @@ -910,6 +1431,18 @@ class Bot(BotBase, discord.Client): This class also subclasses :class:`.GroupMixin` to provide the functionality to manage commands. + Unlike :class:`discord.Client`, this class does not require manually setting + a :class:`~discord.app_commands.CommandTree` and is automatically set upon + instantiating the class. + + .. container:: operations + + .. describe:: async with x + + Asynchronously initialises the bot and automatically cleans up. + + .. versionadded:: 2.0 + Attributes ----------- command_prefix @@ -929,8 +1462,7 @@ class Bot(BotBase, discord.Client): The command prefix could also be an iterable of strings indicating that multiple checks for the prefix should be used and the first one to match will be the invocation prefix. You can get this prefix via - :attr:`.Context.prefix`. To avoid confusion empty iterables are not - allowed. + :attr:`.Context.prefix`. .. note:: @@ -947,23 +1479,75 @@ class Bot(BotBase, discord.Client): you require group commands to be case insensitive as well. description: :class:`str` The content prefixed into the default help message. - self_bot: :class:`bool` - If ``True``, the bot will only listen to commands invoked by itself rather - than ignoring itself. If ``False`` (the default) then the bot will ignore - itself. This cannot be changed once initialised. help_command: Optional[:class:`.HelpCommand`] The help command implementation to use. This can be dynamically set at runtime. To remove the help command pass ``None``. For more information on implementing a help command, see :ref:`ext_commands_help_command`. owner_id: Optional[:class:`int`] - The ID that owns the bot. If this is not set and is then queried via + The user ID that owns the bot. If this is not set and is then queried via :meth:`.is_owner` then it is fetched automatically using :meth:`~.Bot.application_info`. + owner_ids: Optional[Collection[:class:`int`]] + The user IDs that owns the bot. This is similar to :attr:`owner_id`. + If this is not set and the application is team based, then it is + fetched automatically using :meth:`~.Bot.application_info`. + For performance reasons it is recommended to use a :class:`set` + for the collection. You cannot set both ``owner_id`` and ``owner_ids``. + + .. versionadded:: 1.3 + strip_after_prefix: :class:`bool` + Whether to strip whitespace characters after encountering the command + prefix. This allows for ``! hello`` and ``!hello`` to both work if + the ``command_prefix`` is set to ``!``. Defaults to ``False``. + + .. versionadded:: 1.7 + tree_cls: Type[:class:`~discord.app_commands.CommandTree`] + The type of application command tree to use. Defaults to :class:`~discord.app_commands.CommandTree`. + + .. versionadded:: 2.0 + allowed_contexts: :class:`~discord.app_commands.AppCommandContext` + The default allowed contexts that applies to all application commands + in the application command tree. + + Note that you can override this on a per command basis. + + .. versionadded:: 2.4 + allowed_installs: :class:`~discord.app_commands.AppInstallationType` + The default allowed install locations that apply to all application commands + in the application command tree. + + Note that you can override this on a per command basis. + + .. versionadded:: 2.4 """ + pass + class AutoShardedBot(BotBase, discord.AutoShardedClient): """This is similar to :class:`.Bot` except that it is inherited from :class:`discord.AutoShardedClient` instead. + + .. container:: operations + + .. describe:: async with x + + Asynchronously initialises the bot and automatically cleans. + + .. versionadded:: 2.0 """ - pass + + if TYPE_CHECKING: + + def __init__( + self, + command_prefix: PrefixType[BotT], + *, + help_command: Optional[HelpCommand] = _default, + tree_cls: Type[app_commands.CommandTree[Any]] = app_commands.CommandTree, + description: Optional[str] = None, + allowed_contexts: app_commands.AppCommandContext = MISSING, + allowed_installs: app_commands.AppInstallationType = MISSING, + intents: discord.Intents, + **kwargs: Unpack[_AutoShardedBotOptions], + ) -> None: ... diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 9f3c745ce595..b6d2ab0c1805 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,15 +22,55 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import inspect -import copy -from ._types import _BaseCommand +import discord +import logging +from discord import app_commands +from discord.utils import maybe_coroutine, _to_kebab_case + +from typing import ( + Any, + Callable, + ClassVar, + Coroutine, + Dict, + Generator, + Iterable, + List, + Optional, + TYPE_CHECKING, + Sequence, + Tuple, + TypeVar, + Union, +) + +from ._types import _BaseCommand, BotT + +if TYPE_CHECKING: + from typing_extensions import Self + from discord.abc import Snowflake + from discord._types import ClientT + + from .bot import BotBase + from .context import Context + from .core import Command + __all__ = ( 'CogMeta', 'Cog', + 'GroupCog', ) +FuncT = TypeVar('FuncT', bound=Callable[..., Any]) + +MISSING: Any = discord.utils.MISSING +_log = logging.getLogger(__name__) + + class CogMeta(type): """A metaclass for defining a cog. @@ -70,9 +108,14 @@ class MyCog(commands.Cog, name='My Cog'): ----------- name: :class:`str` The cog name. By default, it is the name of the class with no modification. + description: :class:`str` + The cog description. By default, it is the cleaned docstring of the class. + + .. versionadded:: 1.6 + command_attrs: :class:`dict` A list of attributes to apply to every command inside this cog. The dictionary - is passed into the :class:`Command` (or its subclass) options at ``__init__``. + is passed into the :class:`Command` options at ``__init__``. If you specify attributes inside the command attribute in the class, it will override the one specified inside this attribute. For example: @@ -86,14 +129,84 @@ async def foo(self, ctx): @commands.command(hidden=False) async def bar(self, ctx): pass # hidden -> False + + group_name: Union[:class:`str`, :class:`~discord.app_commands.locale_str`] + The group name of a cog. This is only applicable for :class:`GroupCog` instances. + By default, it's the same value as :attr:`name`. + + .. versionadded:: 2.0 + group_description: Union[:class:`str`, :class:`~discord.app_commands.locale_str`] + The group description of a cog. This is only applicable for :class:`GroupCog` instances. + By default, it's the same value as :attr:`description`. + + .. versionadded:: 2.0 + group_nsfw: :class:`bool` + Whether the application command group is NSFW. This is only applicable for :class:`GroupCog` instances. + By default, it's ``False``. + + .. versionadded:: 2.0 + group_auto_locale_strings: :class:`bool` + If this is set to ``True``, then all translatable strings will implicitly + be wrapped into :class:`~discord.app_commands.locale_str` rather + than :class:`str`. Defaults to ``True``. + + .. versionadded:: 2.0 + group_extras: :class:`dict` + A dictionary that can be used to store extraneous data. + This is only applicable for :class:`GroupCog` instances. + The library will not touch any values or keys within this dictionary. + + .. versionadded:: 2.1 """ - def __new__(cls, *args, **kwargs): + __cog_name__: str + __cog_description__: str + __cog_group_name__: Union[str, app_commands.locale_str] + __cog_group_description__: Union[str, app_commands.locale_str] + __cog_group_nsfw__: bool + __cog_group_auto_locale_strings__: bool + __cog_group_extras__: Dict[Any, Any] + __cog_settings__: Dict[str, Any] + __cog_commands__: List[Command[Any, ..., Any]] + __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]] + __cog_listeners__: List[Tuple[str, str]] + + def __new__(cls, *args: Any, **kwargs: Any) -> CogMeta: name, bases, attrs = args - attrs['__cog_name__'] = kwargs.pop('name', name) - attrs['__cog_settings__'] = command_attrs = kwargs.pop('command_attrs', {}) + if any(issubclass(base, app_commands.Group) for base in bases): + raise TypeError( + 'Cannot inherit from app_commands.Group with commands.Cog, consider using commands.GroupCog instead' + ) + + # If name='...' is given but not group_name='...' then name='...' is used for both. + # If neither is given then cog name is the class name but group name is kebab case + try: + cog_name = kwargs.pop('name') + except KeyError: + cog_name = name + try: + group_name = kwargs.pop('group_name') + except KeyError: + group_name = _to_kebab_case(name) + else: + group_name = kwargs.pop('group_name', cog_name) + + attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) + attrs['__cog_name__'] = cog_name + attrs['__cog_group_name__'] = group_name + attrs['__cog_group_nsfw__'] = kwargs.pop('group_nsfw', False) + attrs['__cog_group_auto_locale_strings__'] = kwargs.pop('group_auto_locale_strings', True) + attrs['__cog_group_extras__'] = kwargs.pop('group_extras', {}) + + description = kwargs.pop('description', None) + if description is None: + description = inspect.cleandoc(attrs.get('__doc__', '')) + + attrs['__cog_description__'] = description + attrs['__cog_group_description__'] = kwargs.pop('group_description', description or '\u2026') commands = {} + cog_app_commands = {} listeners = {} no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})' @@ -110,13 +223,19 @@ def __new__(cls, *args, **kwargs): value = value.__func__ if isinstance(value, _BaseCommand): if is_static_method: - raise TypeError('Command in method {0}.{1!r} must not be staticmethod.'.format(base, elem)) + raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.') if elem.startswith(('cog_', 'bot_')): raise TypeError(no_bot_cog.format(base, elem)) commands[elem] = value + elif isinstance(value, (app_commands.Group, app_commands.Command)) and value.parent is None: + if is_static_method: + raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.') + if elem.startswith(('cog_', 'bot_')): + raise TypeError(no_bot_cog.format(base, elem)) + cog_app_commands[elem] = value elif inspect.iscoroutinefunction(value): try: - is_listener = getattr(value, '__cog_listener__') + getattr(value, '__cog_listener__') except AttributeError: continue else: @@ -124,7 +243,8 @@ def __new__(cls, *args, **kwargs): raise TypeError(no_bot_cog.format(base, elem)) listeners[elem] = value - new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ + new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ + new_cls.__cog_app_commands__ = list(cog_app_commands.values()) listeners_as_list = [] for listener in listeners.values(): @@ -136,17 +256,19 @@ def __new__(cls, *args, **kwargs): new_cls.__cog_listeners__ = listeners_as_list return new_cls - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args) @classmethod - def qualified_name(cls): + def qualified_name(cls) -> str: return cls.__cog_name__ -def _cog_special_method(func): + +def _cog_special_method(func: FuncT) -> FuncT: func.__cog_special_method__ = None return func + class Cog(metaclass=CogMeta): """The base class that all cogs must inherit from. @@ -158,7 +280,21 @@ class Cog(metaclass=CogMeta): are equally valid here. """ - def __new__(cls, *args, **kwargs): + __cog_name__: str + __cog_description__: str + __cog_group_name__: Union[str, app_commands.locale_str] + __cog_group_description__: Union[str, app_commands.locale_str] + __cog_settings__: Dict[str, Any] + __cog_commands__: List[Command[Self, ..., Any]] + __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] + __cog_listeners__: List[Tuple[str, str]] + __cog_is_app_commands_group__: ClassVar[bool] = False + __cog_app_commands_group__: Optional[app_commands.Group] + __discord_app_commands_error_handler__: Optional[ + Callable[[discord.Interaction, app_commands.AppCommandError], Coroutine[Any, Any, None]] + ] + + def __new__(cls, *args: Any, **kwargs: Any) -> Self: # For issue 426, we need to store a copy of the command objects # since we modify them to inject `self` to them. # To do this, we need to interfere with the Cog creation process. @@ -166,12 +302,33 @@ def __new__(cls, *args, **kwargs): cmd_attrs = cls.__cog_settings__ # Either update the command with the cog provided defaults or copy it. - self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) - - lookup = { - cmd.qualified_name: cmd - for cmd in self.__cog_commands__ - } + # r.e type ignore, type-checker complains about overriding a ClassVar + self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore + + lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__} + + # Register the application commands + children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = [] + app_command_refs: Dict[str, Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = {} + + if cls.__cog_is_app_commands_group__: + group = app_commands.Group( + name=cls.__cog_group_name__, + description=cls.__cog_group_description__, + nsfw=cls.__cog_group_nsfw__, + auto_locale_strings=cls.__cog_group_auto_locale_strings__, + parent=None, + guild_ids=getattr(cls, '__discord_app_commands_default_guilds__', None), + guild_only=getattr(cls, '__discord_app_commands_guild_only__', False), + allowed_contexts=getattr(cls, '__discord_app_commands_contexts__', None), + allowed_installs=getattr(cls, '__discord_app_commands_installation_types__', None), + default_permissions=getattr(cls, '__discord_app_commands_default_permissions__', None), + extras=cls.__cog_group_extras__, + ) + else: + group = None + + self.__cog_app_commands_group__ = group # Update the Command instances dynamically as well for command in self.__cog_commands__: @@ -179,58 +336,168 @@ def __new__(cls, *args, **kwargs): parent = command.parent if parent is not None: # Get the latest parent reference - parent = lookup[parent.qualified_name] + parent = lookup[parent.qualified_name] # type: ignore + + # Hybrid commands already deal with updating the reference + # Due to the copy below, so we need to handle them specially + if hasattr(parent, '__commands_is_hybrid__') and hasattr(command, '__commands_is_hybrid__'): + current: Optional[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = getattr( + command, 'app_command', None + ) + updated = app_command_refs.get(command.qualified_name) + if current and updated: + command.app_command = updated # type: ignore # Safe attribute access # Update our parent's reference to our self - removed = parent.remove_command(command.name) - parent.add_command(command) + parent.remove_command(command.name) # type: ignore + parent.add_command(command) # type: ignore + + if hasattr(command, '__commands_is_hybrid__') and parent is None: + app_command: Optional[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = getattr( + command, 'app_command', None + ) + if app_command: + group_parent = self.__cog_app_commands_group__ + app_command = app_command._copy_with(parent=group_parent, binding=self) + # The type checker does not see the app_command attribute even though it exists + command.app_command = app_command # type: ignore + + # Update all the references to point to the new copy + if isinstance(app_command, app_commands.Group): + for child in app_command.walk_commands(): + app_command_refs[child.qualified_name] = child + if hasattr(child, '__commands_is_hybrid_app_command__') and child.qualified_name in lookup: + child.wrapped = lookup[child.qualified_name] # type: ignore + + if self.__cog_app_commands_group__: + children.append(app_command) + + if Cog._get_overridden_method(self.cog_app_command_error) is not None: + error_handler = self.cog_app_command_error + else: + error_handler = None + + self.__discord_app_commands_error_handler__ = error_handler + + for command in cls.__cog_app_commands__: + copy = command._copy_with(parent=self.__cog_app_commands_group__, binding=self) + + # Update set bindings + if copy._attr: + setattr(self, copy._attr, copy) + + if isinstance(copy, app_commands.Group): + copy.__discord_app_commands_error_handler__ = error_handler + for command in copy._children.values(): + if isinstance(command, app_commands.Group): + command.__discord_app_commands_error_handler__ = error_handler + + children.append(copy) + + self.__cog_app_commands__ = children + if self.__cog_app_commands_group__: + self.__cog_app_commands_group__.module = cls.__module__ + mapping = {cmd.name: cmd for cmd in children} + if len(mapping) > 25: + raise TypeError('maximum number of application command children exceeded') + + self.__cog_app_commands_group__._children = mapping return self - def get_commands(self): - r"""Returns a :class:`list` of :class:`.Command`\s that are - defined inside this cog. + def get_commands(self) -> List[Command[Self, ..., Any]]: + r"""Returns the commands that are defined inside this cog. - .. note:: + This does *not* include :class:`discord.app_commands.Command` or :class:`discord.app_commands.Group` + instances. - This does not include subcommands. + Returns + -------- + List[:class:`.Command`] + A :class:`list` of :class:`.Command`\s that are + defined inside this cog, not including subcommands. """ return [c for c in self.__cog_commands__ if c.parent is None] + def get_app_commands(self) -> List[Union[app_commands.Command[Self, ..., Any], app_commands.Group]]: + r"""Returns the app commands that are defined inside this cog. + + Returns + -------- + List[Union[:class:`discord.app_commands.Command`, :class:`discord.app_commands.Group`]] + A :class:`list` of :class:`discord.app_commands.Command`\s and :class:`discord.app_commands.Group`\s that are + defined inside this cog, not including subcommands. + """ + return [c for c in self.__cog_app_commands__ if c.parent is None] + @property - def qualified_name(self): + def qualified_name(self) -> str: """:class:`str`: Returns the cog's specified name, not the class name.""" return self.__cog_name__ @property - def description(self): + def description(self) -> str: """:class:`str`: Returns the cog's description, typically the cleaned docstring.""" - try: - return self.__cog_cleaned_doc__ - except AttributeError: - self.__cog_cleaned_doc__ = cleaned = inspect.getdoc(self) - return cleaned + return self.__cog_description__ + + @description.setter + def description(self, description: str) -> None: + self.__cog_description__ = description + + def walk_commands(self) -> Generator[Command[Self, ..., Any], None, None]: + """An iterator that recursively walks through this cog's commands and subcommands. - def walk_commands(self): - """An iterator that recursively walks through this cog's commands and subcommands.""" + Yields + ------ + Union[:class:`.Command`, :class:`.Group`] + A command or group from the cog. + """ from .core import GroupMixin + for command in self.__cog_commands__: if command.parent is None: yield command if isinstance(command, GroupMixin): yield from command.walk_commands() - def get_listeners(self): - """Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.""" + def walk_app_commands(self) -> Generator[Union[app_commands.Command[Self, ..., Any], app_commands.Group], None, None]: + """An iterator that recursively walks through this cog's app commands and subcommands. + + Yields + ------ + Union[:class:`discord.app_commands.Command`, :class:`discord.app_commands.Group`] + An app command or group from the cog. + """ + for command in self.__cog_app_commands__: + yield command + if isinstance(command, app_commands.Group): + yield from command.walk_commands() + + @property + def app_command(self) -> Optional[app_commands.Group]: + """Optional[:class:`discord.app_commands.Group`]: Returns the associated group with this cog. + + This is only available if inheriting from :class:`GroupCog`. + """ + return self.__cog_app_commands_group__ + + def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]: + """Returns a :class:`list` of (name, function) listener pairs that are defined in this cog. + + Returns + -------- + List[Tuple[:class:`str`, :ref:`coroutine `]] + The listeners defined in this cog. + """ return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__] @classmethod - def _get_overridden_method(cls, method): + def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]: """Return None if the method is not overridden. Otherwise returns the overridden method.""" return getattr(method.__func__, '__cog_special_method__', method) @classmethod - def listener(cls, name=None): + def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: """A decorator that marks a function as a listener. This is the cog equivalent of :meth:`.Bot.listen`. @@ -248,10 +515,10 @@ def listener(cls, name=None): the name. """ - if name is not None and not isinstance(name, str): - raise TypeError('Cog.listener expected str but received {0.__class__.__name__!r} instead.'.format(name)) + if name is not MISSING and not isinstance(name, str): + raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__} instead.') - def decorator(func): + def decorator(func: FuncT) -> FuncT: actual = func if isinstance(actual, staticmethod): actual = actual.__func__ @@ -268,21 +535,55 @@ def decorator(func): # to pick it up but the metaclass unfurls the function and # thus the assignments need to be on the actual function return func + return decorator + def has_error_handler(self) -> bool: + """:class:`bool`: Checks whether the cog has an error handler. + + .. versionadded:: 1.7 + """ + return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') + + def has_app_command_error_handler(self) -> bool: + """:class:`bool`: Checks whether the cog has an app error handler. + + .. versionadded:: 2.1 + """ + return not hasattr(self.cog_app_command_error.__func__, '__cog_special_method__') + @_cog_special_method - def cog_unload(self): - """A special method that is called when the cog gets removed. + async def cog_load(self) -> None: + """|maybecoro| + + A special method that is called when the cog gets loaded. - This function **cannot** be a coroutine. It must be a regular - function. + Subclasses must replace this if they want special asynchronous loading behaviour. + Note that the ``__init__`` special method does not allow asynchronous code to run + inside it, thus this is helpful for setting up code that needs to be asynchronous. + + .. versionadded:: 2.0 + """ + pass + + @_cog_special_method + async def cog_unload(self) -> None: + """|maybecoro| + + A special method that is called when the cog gets removed. Subclasses must replace this if they want special unloading behaviour. + + Exceptions raised in this method are ignored during extension unloading. + + .. versionchanged:: 2.0 + + This method can now be a :term:`coroutine`. """ pass @_cog_special_method - def bot_check_once(self, ctx): + def bot_check_once(self, ctx: Context[BotT]) -> bool: """A special method that registers as a :meth:`.Bot.check_once` check. @@ -292,7 +593,7 @@ def bot_check_once(self, ctx): return True @_cog_special_method - def bot_check(self, ctx): + def bot_check(self, ctx: Context[BotT]) -> bool: """A special method that registers as a :meth:`.Bot.check` check. @@ -302,8 +603,8 @@ def bot_check(self, ctx): return True @_cog_special_method - def cog_check(self, ctx): - """A special method that registers as a :func:`commands.check` + def cog_check(self, ctx: Context[BotT]) -> bool: + """A special method that registers as a :func:`~discord.ext.commands.check` for every command and subcommand in this cog. This function **can** be a coroutine and must take a sole parameter, @@ -312,14 +613,28 @@ def cog_check(self, ctx): return True @_cog_special_method - def cog_command_error(self, ctx, error): - """A special method that is called whenever an error + def interaction_check(self, interaction: discord.Interaction[ClientT], /) -> bool: + """A special method that registers as a :func:`discord.app_commands.check` + for every app command and subcommand in this cog. + + This function **can** be a coroutine and must take a sole parameter, + ``interaction``, to represent the :class:`~discord.Interaction`. + + .. versionadded:: 2.0 + """ + return True + + @_cog_special_method + async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None: + """|coro| + + A special method that is called whenever an error is dispatched inside this cog. This is similar to :func:`.on_command_error` except only applying to the commands inside this cog. - This function **can** be a coroutine. + This **must** be a coroutine. Parameters ----------- @@ -331,8 +646,31 @@ def cog_command_error(self, ctx, error): pass @_cog_special_method - async def cog_before_invoke(self, ctx): - """A special method that acts as a cog local pre-invoke hook. + async def cog_app_command_error(self, interaction: discord.Interaction, error: app_commands.AppCommandError) -> None: + """|coro| + + A special method that is called whenever an error within + an application command is dispatched inside this cog. + + This is similar to :func:`discord.app_commands.CommandTree.on_error` except + only applying to the application commands inside this cog. + + This **must** be a coroutine. + + Parameters + ----------- + interaction: :class:`~discord.Interaction` + The interaction that is being handled. + error: :exc:`~discord.app_commands.AppCommandError` + The exception that was raised. + """ + pass + + @_cog_special_method + async def cog_before_invoke(self, ctx: Context[BotT]) -> None: + """|coro| + + A special method that acts as a cog local pre-invoke hook. This is similar to :meth:`.Command.before_invoke`. @@ -346,8 +684,10 @@ async def cog_before_invoke(self, ctx): pass @_cog_special_method - async def cog_after_invoke(self, ctx): - """A special method that acts as a cog local post-invoke hook. + async def cog_after_invoke(self, ctx: Context[BotT]) -> None: + """|coro| + + A special method that acts as a cog local post-invoke hook. This is similar to :meth:`.Command.after_invoke`. @@ -360,9 +700,13 @@ async def cog_after_invoke(self, ctx): """ pass - def _inject(self, bot): + async def _inject(self, bot: BotBase, override: bool, guild: Optional[Snowflake], guilds: Sequence[Snowflake]) -> Self: cls = self.__class__ + # we'll call this first so that errors can propagate without + # having to worry about undoing anything + await maybe_coroutine(self.cog_load) + # realistically, the only thing that can cause loading errors # is essentially just the command loading, which raises if there are # duplicates. When this condition is met, we want to undo all what @@ -375,8 +719,12 @@ def _inject(self, bot): except Exception as e: # undo our additions for to_undo in self.__cog_commands__[:index]: - bot.remove_command(to_undo) - raise e + if to_undo.parent is None: + bot.remove_command(to_undo.name) + try: + await maybe_coroutine(self.cog_unload) + finally: + raise e # check if we're overriding the default if cls.bot_check is not Cog.bot_check: @@ -392,9 +740,15 @@ def _inject(self, bot): for name, method_name in self.__cog_listeners__: bot.add_listener(getattr(self, method_name), name) + # Only do this if these are "top level" commands + if not self.__cog_app_commands_group__: + for command in self.__cog_app_commands__: + # This is already atomic + bot.tree.add_command(command, override=override, guild=guild, guilds=guilds) + return self - def _eject(self, bot): + async def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None: cls = self.__class__ try: @@ -402,8 +756,17 @@ def _eject(self, bot): if command.parent is None: bot.remove_command(command.name) - for _, method_name in self.__cog_listeners__: - bot.remove_listener(getattr(self, method_name)) + if not self.__cog_app_commands_group__: + for command in self.__cog_app_commands__: + guild_ids = guild_ids or command._guild_ids + if guild_ids is None: + bot.tree.remove_command(command.name) + else: + for guild_id in guild_ids: + bot.tree.remove_command(command.name, guild=discord.Object(id=guild_id)) + + for name, method_name in self.__cog_listeners__: + bot.remove_listener(getattr(self, method_name), name) if cls.bot_check is not Cog.bot_check: bot.remove_check(self.bot_check) @@ -411,4 +774,38 @@ def _eject(self, bot): if cls.bot_check_once is not Cog.bot_check_once: bot.remove_check(self.bot_check_once, call_once=True) finally: - self.cog_unload() + try: + await maybe_coroutine(self.cog_unload) + except Exception: + _log.exception('Ignoring exception in cog unload for Cog %r (%r)', cls, self.qualified_name) + + +class GroupCog(Cog): + """Represents a cog that also doubles as a parent :class:`discord.app_commands.Group` for + the application commands defined within it. + + This inherits from :class:`Cog` and the options in :class:`CogMeta` also apply to this. + See the :class:`Cog` documentation for methods. + + Decorators such as :func:`~discord.app_commands.guild_only`, :func:`~discord.app_commands.guilds`, + and :func:`~discord.app_commands.default_permissions` will apply to the group if used on top of the + cog. + + Hybrid commands will also be added to the Group, giving the ability to categorize slash commands into + groups, while keeping the prefix-style command as a root-level command. + + For example: + + .. code-block:: python3 + + from discord import app_commands + from discord.ext import commands + + @app_commands.guild_only() + class MyCog(commands.GroupCog, group_name='my-cog'): + pass + + .. versionadded:: 2.0 + """ + + __cog_is_app_commands_group__: ClassVar[bool] = True diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index fb9dcfc25805..968fec419130 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,10 +22,92 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any, Dict, Generator, Generic, List, Optional, TypeVar, Union, Sequence, Type, overload + import discord.abc import discord.utils +from discord import Interaction, Message, Attachment, MessageType, User, PartialMessageable, Permissions, ChannelType, Thread +from discord.context_managers import Typing +from .view import StringView + +from ._types import BotT + +if TYPE_CHECKING: + from typing_extensions import Self, ParamSpec, TypeGuard + + from discord.abc import MessageableChannel + from discord.guild import Guild + from discord.member import Member + from discord.state import ConnectionState + from discord.user import ClientUser + from discord.voice_client import VoiceProtocol + from discord.embeds import Embed + from discord.file import File + from discord.mentions import AllowedMentions + from discord.sticker import GuildSticker, StickerItem + from discord.message import MessageReference, PartialMessage + from discord.ui.view import BaseView, View, LayoutView + from discord.types.interactions import ApplicationCommandInteractionData + from discord.poll import Poll + + from .cog import Cog + from .core import Command + from .parameters import Parameter + + from types import TracebackType + + BE = TypeVar('BE', bound=BaseException) + +# fmt: off +__all__ = ( + 'Context', +) +# fmt: on + +MISSING: Any = discord.utils.MISSING + + +T = TypeVar('T') +CogT = TypeVar('CogT', bound='Cog') + +if TYPE_CHECKING: + P = ParamSpec('P') +else: + P = TypeVar('P') + + +def is_cog(obj: Any) -> TypeGuard[Cog]: + return hasattr(obj, '__cog_commands__') + + +class DeferTyping(Generic[BotT]): + def __init__(self, ctx: Context[BotT], *, ephemeral: bool): + self.ctx: Context[BotT] = ctx + self.ephemeral: bool = ephemeral + + async def do_defer(self) -> None: + if self.ctx.interaction and not self.ctx.interaction.response.is_done(): + await self.ctx.interaction.response.defer(ephemeral=self.ephemeral) + + def __await__(self) -> Generator[Any, None, None]: + return self.do_defer().__await__() + + async def __aenter__(self) -> None: + await self.do_defer() + + async def __aexit__( + self, + exc_type: Optional[Type[BE]], + exc: Optional[BE], + traceback: Optional[TracebackType], + ) -> None: + pass -class Context(discord.abc.Messageable): + +class Context(discord.abc.Messageable, Generic[BotT]): r"""Represents the context in which a command is being invoked under. This class contains a lot of meta data to help you understand more about @@ -40,28 +120,55 @@ class Context(discord.abc.Messageable): ----------- message: :class:`.Message` The message that triggered the command being executed. + + .. note:: + + In the case of an interaction based context, this message is "synthetic" + and does not actually exist. Therefore, the ID on it is invalid similar + to ephemeral messages. bot: :class:`.Bot` The bot that contains the command being executed. args: :class:`list` The list of transformed arguments that were passed into the command. - If this is accessed during the :func:`on_command_error` event + If this is accessed during the :func:`.on_command_error` event then this list could be incomplete. kwargs: :class:`dict` A dictionary of transformed arguments that were passed into the command. Similar to :attr:`args`\, if this is accessed in the - :func:`on_command_error` event then this dict could be incomplete. - prefix: :class:`str` - The prefix that was used to invoke the command. - command - The command (i.e. :class:`.Command` or its subclasses) that is being - invoked currently. - invoked_with: :class:`str` + :func:`.on_command_error` event then this dict could be incomplete. + current_parameter: Optional[:class:`Parameter`] + The parameter that is currently being inspected and converted. + This is only of use for within converters. + + .. versionadded:: 2.0 + current_argument: Optional[:class:`str`] + The argument string of the :attr:`current_parameter` that is currently being converted. + This is only of use for within converters. + + .. versionadded:: 2.0 + interaction: Optional[:class:`~discord.Interaction`] + The interaction associated with this context. + + .. versionadded:: 2.0 + prefix: Optional[:class:`str`] + The prefix that was used to invoke the command. For interaction based contexts, + this is ``/`` for slash commands and ``\u200b`` for context menu commands. + command: Optional[:class:`Command`] + The command that is being invoked currently. + invoked_with: Optional[:class:`str`] The command name that triggered this invocation. Useful for finding out which alias called the command. - invoked_subcommand - The subcommand (i.e. :class:`.Command` or its subclasses) that was - invoked. If no valid subcommand was invoked then this is equal to - ``None``. + invoked_parents: List[:class:`str`] + The command names of the parents that triggered this invocation. Useful for + finding out which aliases called the command. + + For example in commands ``?a b c test``, the invoked parents are ``['a', 'b', 'c']``. + + .. versionadded:: 1.7 + + invoked_subcommand: Optional[:class:`Command`] + The subcommand that was invoked. + If no valid subcommand was invoked then this is equal to ``None``. subcommand_passed: Optional[:class:`str`] The string that was attempted to call a subcommand. This does not have to point to a valid registered subcommand and could just point to a @@ -72,21 +179,132 @@ class Context(discord.abc.Messageable): or invoked. """ - def __init__(self, **attrs): - self.message = attrs.pop('message', None) - self.bot = attrs.pop('bot', None) - self.args = attrs.pop('args', []) - self.kwargs = attrs.pop('kwargs', {}) - self.prefix = attrs.pop('prefix') - self.command = attrs.pop('command', None) - self.view = attrs.pop('view', None) - self.invoked_with = attrs.pop('invoked_with', None) - self.invoked_subcommand = attrs.pop('invoked_subcommand', None) - self.subcommand_passed = attrs.pop('subcommand_passed', None) - self.command_failed = attrs.pop('command_failed', False) - self._state = self.message._state - - async def invoke(self, *args, **kwargs): + def __init__( + self, + *, + message: Message, + bot: BotT, + view: StringView, + args: List[Any] = MISSING, + kwargs: Dict[str, Any] = MISSING, + prefix: Optional[str] = None, + command: Optional[Command[Any, ..., Any]] = None, + invoked_with: Optional[str] = None, + invoked_parents: List[str] = MISSING, + invoked_subcommand: Optional[Command[Any, ..., Any]] = None, + subcommand_passed: Optional[str] = None, + command_failed: bool = False, + current_parameter: Optional[Parameter] = None, + current_argument: Optional[str] = None, + interaction: Optional[Interaction[BotT]] = None, + ): + self.message: Message = message + self.bot: BotT = bot + self.args: List[Any] = args or [] + self.kwargs: Dict[str, Any] = kwargs or {} + self.prefix: Optional[str] = prefix + self.command: Optional[Command[Any, ..., Any]] = command + self.view: StringView = view + self.invoked_with: Optional[str] = invoked_with + self.invoked_parents: List[str] = invoked_parents or [] + self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand + self.subcommand_passed: Optional[str] = subcommand_passed + self.command_failed: bool = command_failed + self.current_parameter: Optional[Parameter] = current_parameter + self.current_argument: Optional[str] = current_argument + self.interaction: Optional[Interaction[BotT]] = interaction + self._state: ConnectionState = self.message._state + + @classmethod + async def from_interaction(cls, interaction: Interaction[BotT], /) -> Self: + """|coro| + + Creates a context from a :class:`discord.Interaction`. This only + works on application command based interactions, such as slash commands + or context menus. + + On slash command based interactions this creates a synthetic :class:`~discord.Message` + that points to an ephemeral message that the command invoker has executed. This means + that :attr:`Context.author` returns the member that invoked the command. + + In a message context menu based interaction, the :attr:`Context.message` attribute + is the message that the command is being executed on. This means that :attr:`Context.author` + returns the author of the message being targetted. To get the member that invoked + the command then :attr:`discord.Interaction.user` should be used instead. + + .. versionadded:: 2.0 + + Parameters + ----------- + interaction: :class:`discord.Interaction` + The interaction to create a context with. + + Raises + ------- + ValueError + The interaction does not have a valid command. + TypeError + The interaction client is not derived from :class:`Bot` or :class:`AutoShardedBot`. + """ + + # Circular import + from .bot import BotBase + + if not isinstance(interaction.client, BotBase): + raise TypeError('Interaction client is not derived from commands.Bot or commands.AutoShardedBot') + + command = interaction.command + if command is None: + raise ValueError('interaction does not have command data') + + bot: BotT = interaction.client + data: ApplicationCommandInteractionData = interaction.data # type: ignore + if interaction.message is None: + synthetic_payload = { + 'id': interaction.id, + 'reactions': [], + 'embeds': [], + 'mention_everyone': False, + 'tts': False, + 'pinned': False, + 'edited_timestamp': None, + 'type': MessageType.chat_input_command if data.get('type', 1) == 1 else MessageType.context_menu_command, + 'flags': 64, + 'content': '', + 'mentions': [], + 'mention_roles': [], + 'attachments': [], + } + + if interaction.channel_id is None: + raise RuntimeError('interaction channel ID is null, this is probably a Discord bug') + + channel = interaction.channel or PartialMessageable( + state=interaction._state, guild_id=interaction.guild_id, id=interaction.channel_id + ) + message = Message(state=interaction._state, channel=channel, data=synthetic_payload) # type: ignore + message.author = interaction.user + message.attachments = [a for _, a in interaction.namespace if isinstance(a, Attachment)] + else: + message = interaction.message + + prefix = '/' if data.get('type', 1) == 1 else '\u200b' # Mock the prefix + ctx = cls( + message=message, + bot=bot, + view=StringView(''), + args=[], + kwargs={}, + prefix=prefix, + interaction=interaction, + invoked_with=command.name, + command=command, # type: ignore # this will be a hybrid command, technically + ) + interaction._baton = ctx + ctx.command_failed = interaction.command_failed + return ctx + + async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: r"""|coro| Calls a command with the arguments given. @@ -103,36 +321,27 @@ async def invoke(self, *args, **kwargs): You must take care in passing the proper arguments when using this function. - .. warning:: + .. versionchanged:: 2.0 - The first parameter passed **must** be the command being invoked. + ``command`` parameter is now positional-only. Parameters ----------- command: :class:`.Command` - A command or subclass of a command that is going to be called. + The command that is going to be called. \*args - The arguments to to use. + The arguments to use. \*\*kwargs The keyword arguments to use. - """ - - try: - command = args[0] - except IndexError: - raise TypeError('Missing command to invoke.') from None - arguments = [] - if command.cog is not None: - arguments.append(command.cog) - - arguments.append(self) - arguments.extend(args[1:]) - - ret = await command.callback(*arguments, **kwargs) - return ret + Raises + ------- + TypeError + The command argument to invoke is missing. + """ + return await command(self, *args, **kwargs) - async def reinvoke(self, *, call_hooks=False, restart=True): + async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None: """|coro| Calls the command again. @@ -156,6 +365,11 @@ async def reinvoke(self, *, call_hooks=False, restart=True): Whether to start the call chain from the very beginning or where we left off (i.e. the command that caused the error). The default is to start where we left off. + + Raises + ------- + ValueError + The context to reinvoke is not valid. """ cmd = self.command view = self.view @@ -166,13 +380,15 @@ async def reinvoke(self, *, call_hooks=False, restart=True): index, previous = view.index, view.previous invoked_with = self.invoked_with invoked_subcommand = self.invoked_subcommand + invoked_parents = self.invoked_parents subcommand_passed = self.subcommand_passed if restart: to_call = cmd.root_parent or cmd - view.index = len(self.prefix) + view.index = len(self.prefix or '') view.previous = 0 - view.get_word() # advance to get the root command + self.invoked_parents = [] + self.invoked_with = view.get_word() # advance to get the root command else: to_call = cmd @@ -184,51 +400,154 @@ async def reinvoke(self, *, call_hooks=False, restart=True): view.previous = previous self.invoked_with = invoked_with self.invoked_subcommand = invoked_subcommand + self.invoked_parents = invoked_parents self.subcommand_passed = subcommand_passed @property - def valid(self): - """Checks if the invocation context is valid to be invoked with.""" + def valid(self) -> bool: + """:class:`bool`: Checks if the invocation context is valid to be invoked with.""" return self.prefix is not None and self.command is not None - async def _get_channel(self): + async def _get_channel(self) -> discord.abc.Messageable: return self.channel @property - def cog(self): - """Returns the cog associated with this context's command. None if it does not exist.""" + def clean_prefix(self) -> str: + """:class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``. + + .. versionadded:: 2.0 + """ + if self.prefix is None: + return '' + + user = self.me + # this breaks if the prefix mention is not the bot itself but I + # consider this to be an *incredibly* strange use case. I'd rather go + # for this common use case rather than waste performance for the + # odd one. + pattern = re.compile(r'<@!?%s>' % user.id) + return pattern.sub('@%s' % user.display_name.replace('\\', r'\\'), self.prefix) + + @property + def cog(self) -> Optional[Cog]: + """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist.""" if self.command is None: return None return self.command.cog + @property + def filesize_limit(self) -> int: + """:class:`int`: Returns the maximum number of bytes files can have when uploaded to this guild or DM channel associated with this context. + + .. versionadded:: 2.3 + """ + return self.guild.filesize_limit if self.guild is not None else discord.utils.DEFAULT_FILE_SIZE_LIMIT_BYTES + @discord.utils.cached_property - def guild(self): - """Returns the guild associated with this context's command. None if not available.""" + def guild(self) -> Optional[Guild]: + """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available.""" return self.message.guild @discord.utils.cached_property - def channel(self): - """Returns the channel associated with this context's command. Shorthand for :attr:`.Message.channel`.""" + def channel(self) -> MessageableChannel: + """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. + Shorthand for :attr:`.Message.channel`. + """ return self.message.channel @discord.utils.cached_property - def author(self): - """Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`""" + def author(self) -> Union[User, Member]: + """Union[:class:`~discord.User`, :class:`.Member`]: + Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` + """ return self.message.author @discord.utils.cached_property - def me(self): - """Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.""" - return self.guild.me if self.guild is not None else self.bot.user + def me(self) -> Union[Member, ClientUser]: + """Union[:class:`.Member`, :class:`.ClientUser`]: + Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts. + """ + # bot.user will never be None at this point. + return self.guild.me if self.guild is not None else self.bot.user # type: ignore + + @discord.utils.cached_property + def permissions(self) -> Permissions: + """:class:`.Permissions`: Returns the resolved permissions for the invoking user in this channel. + Shorthand for :meth:`.abc.GuildChannel.permissions_for` or :attr:`.Interaction.permissions`. + + .. versionadded:: 2.0 + """ + if self.interaction is None and self.channel.type is ChannelType.private: + return Permissions._dm_permissions() + if not self.interaction: + # channel and author will always match relevant types here + return self.channel.permissions_for(self.author) # type: ignore + base = self.interaction.permissions + if self.channel.type in (ChannelType.voice, ChannelType.stage_voice): + if not base.connect: + # voice channels cannot be edited by people who can't connect to them + # It also implicitly denies all other voice perms + denied = Permissions.voice() + denied.update(manage_channels=True, manage_roles=True) + base.value &= ~denied.value + else: + # text channels do not have voice related permissions + denied = Permissions.voice() + base.value &= ~denied.value + return base + + @discord.utils.cached_property + def bot_permissions(self) -> Permissions: + """:class:`.Permissions`: Returns the resolved permissions for the bot in this channel. + Shorthand for :meth:`.abc.GuildChannel.permissions_for` or :attr:`.Interaction.app_permissions`. + + For interaction-based commands, this will reflect the effective permissions + for :class:`Context` calls, which may differ from calls through + other :class:`.abc.Messageable` endpoints, like :attr:`channel`. + + Notably, sending messages, embedding links, and attaching files are always + permitted, while reading messages might not be. + + .. versionadded:: 2.0 + """ + channel = self.channel + if self.interaction is None and channel.type == ChannelType.private: + return Permissions._dm_permissions() + if not self.interaction: + # channel and me will always match relevant types here + return channel.permissions_for(self.me) # type: ignore + guild = channel.guild + base = self.interaction.app_permissions + if self.channel.type in (ChannelType.voice, ChannelType.stage_voice): + if not base.connect: + # voice channels cannot be edited by people who can't connect to them + # It also implicitly denies all other voice perms + denied = Permissions.voice() + denied.update(manage_channels=True, manage_roles=True) + base.value &= ~denied.value + else: + # text channels do not have voice related permissions + denied = Permissions.voice() + base.value &= ~denied.value + base.update( + embed_links=True, + attach_files=True, + send_tts_messages=False, + ) + if isinstance(channel, Thread): + base.send_messages_in_threads = True + else: + base.send_messages = True + return base @property - def voice_client(self): - r"""Optional[:class:`.VoiceClient`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" + def voice_client(self) -> Optional[VoiceProtocol]: + r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" g = self.guild return g.voice_client if g else None - async def send_help(self, *args): + async def send_help(self, *args: Any) -> Any: """send_help(entity=) |coro| @@ -246,7 +565,7 @@ async def send_help(self, *args): Due to the way this function works, instead of returning something similar to :meth:`~.commands.HelpCommand.command_not_found` - this returns :class:`None` on bad input or no help command. + this returns ``None`` on bad input or no help command. Parameters ------------ @@ -258,7 +577,8 @@ async def send_help(self, *args): Any The result of the help command, if any. """ - from .core import Group, Command + from .core import Command, Group, wrap_callback + from .errors import CommandError bot = self.bot cmd = bot.help_command @@ -268,31 +588,561 @@ async def send_help(self, *args): cmd = cmd.copy() cmd.context = self + if len(args) == 0: await cmd.prepare_help_command(self, None) mapping = cmd.get_bot_mapping() - return await cmd.send_bot_help(mapping) + injected = wrap_callback(cmd.send_bot_help) + try: + return await injected(mapping) + except CommandError as e: + await cmd.on_help_command_error(self, e) + return None entity = args[0] - if entity is None: - return None - if isinstance(entity, str): entity = bot.get_cog(entity) or bot.get_command(entity) + if entity is None: + return None + try: - qualified_name = entity.qualified_name + entity.qualified_name except AttributeError: # if we're here then it's not a cog, group, or command. return None await cmd.prepare_help_command(self, entity.qualified_name) - if hasattr(entity, '__cog_commands__'): - return await cmd.send_cog_help(entity) - elif isinstance(entity, Group): - return await cmd.send_group_help(entity) - elif isinstance(entity, Command): - return await cmd.send_command_help(entity) + try: + if is_cog(entity): + injected = wrap_callback(cmd.send_cog_help) + return await injected(entity) + elif isinstance(entity, Group): + injected = wrap_callback(cmd.send_group_help) + return await injected(entity) + elif isinstance(entity, Command): + injected = wrap_callback(cmd.send_command_help) + return await injected(entity) + else: + return None + except CommandError as e: + await cmd.on_help_command_error(self, e) + + @overload + async def reply( + self, + *, + file: File = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: LayoutView, + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + ) -> Message: ... + + @overload + async def reply( + self, + *, + files: Sequence[File] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: LayoutView, + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + ) -> Message: ... + + @overload + async def reply( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def reply( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def reply( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def reply( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message: + """|coro| + + A shortcut method to :meth:`send` to reply to the + :class:`~discord.Message` referenced by this context. + + For interaction based contexts, this is the same as :meth:`send`. + + .. versionadded:: 1.6 + + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` or + :exc:`ValueError` instead of ``InvalidArgument``. + + Raises + -------- + ~discord.HTTPException + Sending the message failed. + ~discord.Forbidden + You do not have the proper permissions to send the message. + ValueError + The ``files`` list is not of the appropriate size + TypeError + You specified both ``file`` and ``files``. + + Returns + --------- + :class:`~discord.Message` + The message that was sent. + """ + if self.interaction is None: + return await self.send(content, reference=self.message, **kwargs) else: - return None + return await self.send(content, **kwargs) + + def typing(self, *, ephemeral: bool = False) -> Union[Typing, DeferTyping[BotT]]: + """Returns an asynchronous context manager that allows you to send a typing indicator to + the destination for an indefinite period of time, or 10 seconds if the context manager + is called using ``await``. + + In an interaction based context, this is equivalent to a :meth:`defer` call and + does not do any typing calls. + + Example Usage: :: + + async with channel.typing(): + # simulate something heavy + await asyncio.sleep(20) + + await channel.send('Done!') + + Example Usage: :: + + await channel.typing() + # Do some computational magic for about 10 seconds + await channel.send('Done!') + + .. versionchanged:: 2.0 + This no longer works with the ``with`` syntax, ``async with`` must be used instead. + + .. versionchanged:: 2.0 + Added functionality to ``await`` the context manager to send a typing indicator for 10 seconds. + + Parameters + ----------- + ephemeral: :class:`bool` + Indicates whether the deferred message will eventually be ephemeral. + Only valid for interaction based contexts. + + .. versionadded:: 2.0 + """ + if self.interaction is None: + return Typing(self) + return DeferTyping(self, ephemeral=ephemeral) + + async def defer(self, *, ephemeral: bool = False) -> None: + """|coro| + + Defers the interaction based contexts. + + This is typically used when the interaction is acknowledged + and a secondary action will be done later. + + If this isn't an interaction based context then it does nothing. + + Parameters + ----------- + ephemeral: :class:`bool` + Indicates whether the deferred message will eventually be ephemeral. + + Raises + ------- + HTTPException + Deferring the interaction failed. + InteractionResponded + This interaction has already been responded to before. + """ + + if self.interaction: + await self.interaction.response.defer(ephemeral=ephemeral) + + @overload + async def send( + self, + *, + file: File = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: LayoutView, + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + ) -> Message: ... + + @overload + async def send( + self, + *, + files: Sequence[File] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: LayoutView, + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + ) -> Message: ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + file: File = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: Sequence[Embed] = ..., + files: Sequence[File] = ..., + stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference, PartialMessage] = ..., + mention_author: bool = ..., + view: View = ..., + suppress_embeds: bool = ..., + ephemeral: bool = ..., + silent: bool = ..., + poll: Poll = ..., + ) -> Message: ... + + async def send( + self, + content: Optional[str] = None, + *, + tts: bool = False, + embed: Optional[Embed] = None, + embeds: Optional[Sequence[Embed]] = None, + file: Optional[File] = None, + files: Optional[Sequence[File]] = None, + stickers: Optional[Sequence[Union[GuildSticker, StickerItem]]] = None, + delete_after: Optional[float] = None, + nonce: Optional[Union[str, int]] = None, + allowed_mentions: Optional[AllowedMentions] = None, + reference: Optional[Union[Message, MessageReference, PartialMessage]] = None, + mention_author: Optional[bool] = None, + view: Optional[BaseView] = None, + suppress_embeds: bool = False, + ephemeral: bool = False, + silent: bool = False, + poll: Optional[Poll] = None, + ) -> Message: + """|coro| + + Sends a message to the destination with the content given. + + This works similarly to :meth:`~discord.abc.Messageable.send` for non-interaction contexts. + + For interaction based contexts this does one of the following: + + - :meth:`discord.InteractionResponse.send_message` if no response has been given. + - A followup message if a response has been given. + - Regular send if the interaction has expired + + .. versionchanged:: 2.0 + This function will now raise :exc:`TypeError` or + :exc:`ValueError` instead of ``InvalidArgument``. + + Parameters + ------------ + content: Optional[:class:`str`] + The content of the message to send. + tts: :class:`bool` + Indicates if the message should be sent using text-to-speech. + embed: :class:`~discord.Embed` + The rich embed for the content. + file: :class:`~discord.File` + The file to upload. + files: List[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. + nonce: :class:`int` + The nonce to use for sending this message. If the message was successfully sent, + then the message will have a nonce with this value. + delete_after: :class:`float` + If provided, the number of seconds to wait in the background + before deleting the message we just sent. If the deletion fails, + then it is silently ignored. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. + + .. versionadded:: 1.4 + + reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`, :class:`~discord.PartialMessage`] + A reference to the :class:`~discord.Message` to which you are replying, this can be created using + :meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. You can control + whether this mentions the author of the referenced message using the :attr:`~discord.AllowedMentions.replied_user` + attribute of ``allowed_mentions`` or by setting ``mention_author``. + + This is ignored for interaction based contexts. + + .. versionadded:: 1.6 + + mention_author: Optional[:class:`bool`] + If set, overrides the :attr:`~discord.AllowedMentions.replied_user` attribute of ``allowed_mentions``. + This is ignored for interaction based contexts. + + .. versionadded:: 1.6 + view: Union[:class:`discord.ui.View`, :class:`discord.ui.LayoutView`] + A Discord UI View to add to the message. + + .. versionadded:: 2.0 + embeds: List[:class:`~discord.Embed`] + A list of embeds to upload. Must be a maximum of 10. + + .. versionadded:: 2.0 + stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]] + A list of stickers to upload. Must be a maximum of 3. This is ignored for interaction based contexts. + + .. versionadded:: 2.0 + suppress_embeds: :class:`bool` + Whether to suppress embeds for the message. This sends the message without any embeds if set to ``True``. + + .. versionadded:: 2.0 + ephemeral: :class:`bool` + Indicates if the message should only be visible to the user who started the interaction. + If a view is sent with an ephemeral message and it has no timeout set then the timeout + is set to 15 minutes. **This is only applicable in contexts with an interaction**. + + .. versionadded:: 2.0 + silent: :class:`bool` + Whether to suppress push and desktop notifications for the message. This will increment the mention counter + in the UI, but will not actually send a notification. + + .. versionadded:: 2.2 + + poll: Optional[:class:`~discord.Poll`] + The poll to send with this message. + + .. versionadded:: 2.4 + .. versionchanged:: 2.6 + This can now be ``None`` and defaults to ``None`` instead of ``MISSING``. + + Raises + -------- + ~discord.HTTPException + Sending the message failed. + ~discord.Forbidden + You do not have the proper permissions to send the message. + ValueError + The ``files`` list is not of the appropriate size. + TypeError + You specified both ``file`` and ``files``, + or you specified both ``embed`` and ``embeds``, + or the ``reference`` object is not a :class:`~discord.Message`, + :class:`~discord.MessageReference` or :class:`~discord.PartialMessage`. + + Returns + --------- + :class:`~discord.Message` + The message that was sent. + """ + + if self.interaction is None or self.interaction.is_expired(): + return await super().send( + content=content, + tts=tts, + embed=embed, + embeds=embeds, + file=file, + files=files, + stickers=stickers, + delete_after=delete_after, + nonce=nonce, + allowed_mentions=allowed_mentions, + reference=reference, + mention_author=mention_author, + view=view, + suppress_embeds=suppress_embeds, + silent=silent, + poll=poll, + ) # type: ignore # The overloads don't support Optional but the implementation does + + # Convert the kwargs from None to MISSING to appease the remaining implementations + kwargs = { + 'content': content, + 'tts': tts, + 'embed': MISSING if embed is None else embed, + 'embeds': MISSING if embeds is None else embeds, + 'file': MISSING if file is None else file, + 'files': MISSING if files is None else files, + 'allowed_mentions': MISSING if allowed_mentions is None else allowed_mentions, + 'view': MISSING if view is None else view, + 'suppress_embeds': suppress_embeds, + 'ephemeral': ephemeral, + 'silent': silent, + 'poll': MISSING if poll is None else poll, + } + + if self.interaction.response.is_done(): + msg = await self.interaction.followup.send(**kwargs, wait=True) + else: + response = await self.interaction.response.send_message(**kwargs) + if not isinstance(response.resource, discord.InteractionMessage): + msg = await self.interaction.original_response() + else: + msg = response.resource + + if delete_after is not None: + await msg.delete(delay=delete_after) + return msg diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 86a3db45a1ca..baf22c6263bc 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,33 +22,75 @@ DEALINGS IN THE SOFTWARE. """ -import re +from __future__ import annotations + import inspect +import re +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + Iterable, + List, + Literal, + Optional, + overload, + Protocol, + Tuple, + Type, + TypeVar, + Union, + runtime_checkable, +) +import types import discord -from .errors import BadArgument, NoPrivateMessage +from .errors import * + +if TYPE_CHECKING: + from discord.state import Channel + from discord.threads import Thread + + from .parameters import Parameter + from ._types import BotT, _Bot + from .context import Context __all__ = ( 'Converter', + 'ObjectConverter', 'MemberConverter', 'UserConverter', 'MessageConverter', + 'PartialMessageConverter', 'TextChannelConverter', 'InviteConverter', + 'GuildConverter', 'RoleConverter', 'GameConverter', 'ColourConverter', + 'ColorConverter', 'VoiceChannelConverter', + 'StageChannelConverter', 'EmojiConverter', 'PartialEmojiConverter', 'CategoryChannelConverter', + 'ForumChannelConverter', 'IDConverter', + 'ThreadConverter', + 'GuildChannelConverter', + 'GuildStickerConverter', + 'ScheduledEventConverter', + 'SoundboardSoundConverter', 'clean_content', 'Greedy', + 'Range', + 'run_converters', ) -def _get_from_guilds(bot, getter, argument): + +def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any: result = None for guild in bot.guilds: result = getattr(guild, getter)(argument) @@ -58,7 +98,16 @@ def _get_from_guilds(bot, getter, argument): return result return result -class Converter: + +_utils_get = discord.utils.get +T = TypeVar('T') +T_co = TypeVar('T_co', covariant=True) +CT = TypeVar('CT', bound=discord.abc.GuildChannel) +TT = TypeVar('TT', bound=discord.Thread) + + +@runtime_checkable +class Converter(Protocol[T_co]): """The base class of custom converters that require the :class:`.Context` to be passed to be useful. @@ -69,7 +118,7 @@ class Converter: method to do its conversion logic. This method must be a :ref:`coroutine `. """ - async def convert(self, ctx, argument): + async def convert(self, ctx: Context[BotT], argument: str) -> T_co: """|coro| The method to override to do conversion logic. @@ -78,24 +127,61 @@ async def convert(self, ctx, argument): raise a :exc:`.CommandError` derived exception as it will properly propagate to the error handlers. + Note that if this method is called manually, :exc:`Exception` + should be caught to handle the cases where a subclass does + not explicitly inherit from :exc:`.CommandError`. + Parameters ----------- ctx: :class:`.Context` The invocation context that the argument is being used in. argument: :class:`str` The argument that is being converted. + + Raises + ------- + CommandError + A generic exception occurred when converting the argument. + BadArgument + The converter failed to convert the argument. """ raise NotImplementedError('Derived classes need to implement this.') -class IDConverter(Converter): - def __init__(self): - self._id_regex = re.compile(r'([0-9]{15,21})$') - super().__init__() - def _get_id_match(self, argument): - return self._id_regex.match(argument) +_ID_REGEX = re.compile(r'([0-9]{15,20})$') + + +class IDConverter(Converter[T_co]): + @staticmethod + def _get_id_match(argument): + return _ID_REGEX.match(argument) + + +class ObjectConverter(IDConverter[discord.Object]): + """Converts to a :class:`~discord.Object`. + + The argument must follow the valid ID or mention formats (e.g. ``<@80088516616269824>``). + + .. versionadded:: 2.0 + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by member, role, or channel mention. + """ -class MemberConverter(IDConverter): + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object: + match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument) + + if match is None: + raise ObjectNotFound(argument) + + result = int(match.group(1)) + + return discord.Object(id=result) + + +class MemberConverter(IDConverter[discord.Member]): """Converts to a :class:`~discord.Member`. All lookups are via the local guild. If in a DM context, then the lookup @@ -105,16 +191,70 @@ class MemberConverter(IDConverter): 1. Lookup by ID. 2. Lookup by mention. - 3. Lookup by name#discrim - 4. Lookup by name - 5. Lookup by nickname + 3. Lookup by username#discriminator (deprecated). + 4. Lookup by username#0 (deprecated, only gets users that migrated from their discriminator). + 5. Lookup by user name. + 6. Lookup by global name. + 7. Lookup by guild nickname. + + .. versionchanged:: 1.5 + Raise :exc:`.MemberNotFound` instead of generic :exc:`.BadArgument` + + .. versionchanged:: 1.5.1 + This converter now lazily fetches members from the gateway and HTTP APIs, + optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled. + + .. deprecated:: 2.3 + Looking up users by discriminator will be removed in a future version due to + the removal of discriminators in an API change. """ - async def convert(self, ctx, argument): + async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]: + cache = guild._state.member_cache_flags.joined + username, _, discriminator = argument.rpartition('#') + + # If # isn't found then "discriminator" actually has the username + if not username: + discriminator, username = username, discriminator + + if discriminator == '0' or (len(discriminator) == 4 and discriminator.isdigit()): + lookup = username + predicate = lambda m: m.name == username and m.discriminator == discriminator + else: + lookup = argument + predicate = lambda m: m.name == argument or m.global_name == argument or m.nick == argument + + members = await guild.query_members(lookup, limit=100, cache=cache) + return discord.utils.find(predicate, members) + + async def query_member_by_id(self, bot: _Bot, guild: discord.Guild, user_id: int) -> Optional[discord.Member]: + ws = bot._get_websocket(shard_id=guild.shard_id) + cache = guild._state.member_cache_flags.joined + if ws.is_ratelimited(): + # If we're being rate limited on the WS, then fall back to using the HTTP API + # So we don't have to wait ~60 seconds for the query to finish + try: + member = await guild.fetch_member(user_id) + except discord.HTTPException: + return None + + if cache: + guild._add_member(member) + return member + + # If we're not being rate limited then we can use the websocket to actually query + members = await guild.query_members(limit=1, user_ids=[user_id], cache=cache) + if not members: + return None + return members[0] + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member: bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) + match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) guild = ctx.guild result = None + user_id = None + if match is None: # not a mention... if guild: @@ -124,16 +264,26 @@ async def convert(self, ctx, argument): else: user_id = int(match.group(1)) if guild: - result = guild.get_member(user_id) + result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id) else: result = _get_from_guilds(bot, 'get_member', user_id) - if result is None: - raise BadArgument('Member "{}" not found'.format(argument)) + if not isinstance(result, discord.Member): + if guild is None: + raise MemberNotFound(argument) + + if user_id is not None: + result = await self.query_member_by_id(bot, guild, user_id) + else: + result = await self.query_member_named(guild, argument) + + if not result: + raise MemberNotFound(argument) return result -class UserConverter(IDConverter): + +class UserConverter(IDConverter[discord.User]): """Converts to a :class:`~discord.User`. All lookups are via the global user cache. @@ -142,75 +292,149 @@ class UserConverter(IDConverter): 1. Lookup by ID. 2. Lookup by mention. - 3. Lookup by name#discrim - 4. Lookup by name + 3. Lookup by username#discriminator (deprecated). + 4. Lookup by username#0 (deprecated, only gets users that migrated from their discriminator). + 5. Lookup by user name. + 6. Lookup by global name. + + .. versionchanged:: 1.5 + Raise :exc:`.UserNotFound` instead of generic :exc:`.BadArgument` + + .. versionchanged:: 1.6 + This converter now lazily fetches users from the HTTP APIs if an ID is passed + and it's not available in cache. + + .. deprecated:: 2.3 + Looking up users by discriminator will be removed in a future version due to + the removal of discriminators in an API change. """ - async def convert(self, ctx, argument): - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument) + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.User: + match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) result = None state = ctx._state if match is not None: user_id = int(match.group(1)) - result = ctx.bot.get_user(user_id) + result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id) + if result is None: + try: + result = await ctx.bot.fetch_user(user_id) + except discord.HTTPException: + raise UserNotFound(argument) from None + + return result # type: ignore + + username, _, discriminator = argument.rpartition('#') + + # If # isn't found then "discriminator" actually has the username + if not username: + discriminator, username = username, discriminator + + if discriminator == '0' or (len(discriminator) == 4 and discriminator.isdigit()): + predicate = lambda u: u.name == username and u.discriminator == discriminator else: - arg = argument - # check for discriminator if it exists - if len(arg) > 5 and arg[-5] == '#': - discrim = arg[-4:] - name = arg[:-5] - predicate = lambda u: u.name == name and u.discriminator == discrim - result = discord.utils.find(predicate, state._users.values()) - if result is not None: - return result - - predicate = lambda u: u.name == arg - result = discord.utils.find(predicate, state._users.values()) + predicate = lambda u: u.name == argument or u.global_name == argument + result = discord.utils.find(predicate, state._users.values()) if result is None: - raise BadArgument('User "{}" not found'.format(argument)) + raise UserNotFound(argument) return result -class MessageConverter(Converter): + +class PartialMessageConverter(Converter[discord.PartialMessage]): + """Converts to a :class:`discord.PartialMessage`. + + .. versionadded:: 1.7 + + The creation strategy is as follows (in order): + + 1. By "{channel ID}-{message ID}" (retrieved by shift-clicking on "Copy ID") + 2. By message ID (The message is assumed to be in the context channel.) + 3. By message URL + """ + + @staticmethod + def _get_id_matches(ctx: Context[BotT], argument: str) -> Tuple[Optional[int], int, int]: + id_regex = re.compile(r'(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$') + link_regex = re.compile( + r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/' + r'(?P[0-9]{15,20}|@me)' + r'/(?P[0-9]{15,20})/(?P[0-9]{15,20})/?$' + ) + match = id_regex.match(argument) or link_regex.match(argument) + if not match: + raise MessageNotFound(argument) + data = match.groupdict() + channel_id = discord.utils._get_as_snowflake(data, 'channel_id') or ctx.channel.id + message_id = int(data['message_id']) + guild_id = data.get('guild_id') + if guild_id is None: + guild_id = ctx.guild and ctx.guild.id + elif guild_id == '@me': + guild_id = None + else: + guild_id = int(guild_id) + return guild_id, message_id, channel_id + + @staticmethod + def _resolve_channel( + ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int] + ) -> Optional[Union[Channel, Thread]]: + if channel_id is None: + # we were passed just a message id so we can assume the channel is the current context channel + return ctx.channel + + if guild_id is not None: + guild = ctx.bot.get_guild(guild_id) + if guild is None: + return None + return guild._resolve_channel(channel_id) + + return ctx.bot.get_channel(channel_id) + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage: + guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) + channel = self._resolve_channel(ctx, guild_id, channel_id) + if not channel or not isinstance(channel, discord.abc.Messageable): + raise ChannelNotFound(channel_id) + return discord.PartialMessage(channel=channel, id=message_id) + + +class MessageConverter(IDConverter[discord.Message]): """Converts to a :class:`discord.Message`. - .. versionadded:: 1.1.0 + .. versionadded:: 1.1 The lookup strategy is as follows (in order): 1. Lookup by "{channel ID}-{message ID}" (retrieved by shift-clicking on "Copy ID") 2. Lookup by message ID (the message **must** be in the context channel) 3. Lookup by message URL + + .. versionchanged:: 1.5 + Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx, argument): - id_regex = re.compile(r'^(?:(?P[0-9]{15,21})-)?(?P[0-9]{15,21})$') - link_regex = re.compile( - r'^https?://(?:(ptb|canary)\.)?discordapp\.com/channels/' - r'(?:([0-9]{15,21})|(@me))' - r'/(?P[0-9]{15,21})/(?P[0-9]{15,21})/?$' - ) - match = id_regex.match(argument) or link_regex.match(argument) - if not match: - raise BadArgument('Message "{msg}" not found.'.format(msg=argument)) - message_id = int(match.group("message_id")) - channel_id = match.group("channel_id") + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message: + guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) message = ctx.bot._connection._get_message(message_id) if message: return message - channel = ctx.bot.get_channel(int(channel_id)) if channel_id else ctx.channel - if not channel: - raise BadArgument('Channel "{channel}" not found.'.format(channel=channel_id)) + channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id) + if not channel or not isinstance(channel, discord.abc.Messageable): + raise ChannelNotFound(channel_id) try: return await channel.fetch_message(message_id) except discord.NotFound: - raise BadArgument('Message "{msg}" not found.'.format(msg=argument)) + raise MessageNotFound(argument) except discord.Forbidden: - raise BadArgument("Can't read messages in {channel}".format(channel=channel.mention)) + raise ChannelNotReadable(channel) # type: ignore # type-checker thinks channel could be a DMChannel at this point -class TextChannelConverter(IDConverter): - """Converts to a :class:`~discord.TextChannel`. + +class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): + """Converts to a :class:`~discord.abc.GuildChannel`. All lookups are via the local guild. If in a DM context, then the lookup is done by the global cache. @@ -219,36 +443,114 @@ class TextChannelConverter(IDConverter): 1. Lookup by ID. 2. Lookup by mention. - 3. Lookup by name + 3. Lookup by channel URL. + 4. Lookup by name. + + .. versionadded:: 2.0 + + .. versionchanged:: 2.4 + Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels. """ - async def convert(self, ctx, argument): + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel: + return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel) + + @staticmethod + def _parse_from_url(argument: str) -> Optional[re.Match[str]]: + link_regex = re.compile( + r'https?://(?:(?:ptb|canary|www)\.)?discord(?:app)?\.com/channels/' + r'(?:[0-9]{15,20}|@me)' + r'/([0-9]{15,20})(?:/(?:[0-9]{15,20})/?)?$' + ) + return link_regex.match(argument) + + @staticmethod + def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT: bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) + match = ( + IDConverter._get_id_match(argument) + or re.match(r'<#([0-9]{15,20})>$', argument) + or GuildChannelConverter._parse_from_url(argument) + ) result = None guild = ctx.guild if match is None: # not a mention if guild: - result = discord.utils.get(guild.text_channels, name=argument) + iterable: Iterable[CT] = getattr(guild, attribute) + result: Optional[CT] = discord.utils.get(iterable, name=argument) else: + def check(c): - return isinstance(c, discord.TextChannel) and c.name == argument - result = discord.utils.find(check, bot.get_all_channels()) + return isinstance(c, type) and c.name == argument + + result = discord.utils.find(check, bot.get_all_channels()) # type: ignore else: channel_id = int(match.group(1)) if guild: - result = guild.get_channel(channel_id) + # guild.get_channel returns an explicit union instead of the base class + result = guild.get_channel(channel_id) # type: ignore else: result = _get_from_guilds(bot, 'get_channel', channel_id) - if not isinstance(result, discord.TextChannel): - raise BadArgument('Channel "{}" not found.'.format(argument)) + if not isinstance(result, type): + raise ChannelNotFound(argument) return result -class VoiceChannelConverter(IDConverter): + @staticmethod + def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT: + match = ( + IDConverter._get_id_match(argument) + or re.match(r'<#([0-9]{15,20})>$', argument) + or GuildChannelConverter._parse_from_url(argument) + ) + result = None + guild = ctx.guild + + if match is None: + # not a mention + if guild: + iterable: Iterable[TT] = getattr(guild, attribute) + result: Optional[TT] = discord.utils.get(iterable, name=argument) + else: + thread_id = int(match.group(1)) + if guild: + result = guild.get_thread(thread_id) # type: ignore + + if not result or not isinstance(result, type): + raise ThreadNotFound(argument) + + return result + + +class TextChannelConverter(IDConverter[discord.TextChannel]): + """Converts to a :class:`~discord.TextChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by channel URL. + 4. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` + + .. versionchanged:: 2.4 + Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels. + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel: + return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel) + + +class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """Converts to a :class:`~discord.VoiceChannel`. All lookups are via the local guild. If in a DM context, then the lookup @@ -258,35 +560,44 @@ class VoiceChannelConverter(IDConverter): 1. Lookup by ID. 2. Lookup by mention. - 3. Lookup by name + 3. Lookup by channel URL. + 4. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` + + .. versionchanged:: 2.4 + Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels. """ - async def convert(self, ctx, argument): - bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) - result = None - guild = ctx.guild - if match is None: - # not a mention - if guild: - result = discord.utils.get(guild.voice_channels, name=argument) - else: - def check(c): - return isinstance(c, discord.VoiceChannel) and c.name == argument - result = discord.utils.find(check, bot.get_all_channels()) - else: - channel_id = int(match.group(1)) - if guild: - result = guild.get_channel(channel_id) - else: - result = _get_from_guilds(bot, 'get_channel', channel_id) + async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel: + return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel) - if not isinstance(result, discord.VoiceChannel): - raise BadArgument('Channel "{}" not found.'.format(argument)) - return result +class StageChannelConverter(IDConverter[discord.StageChannel]): + """Converts to a :class:`~discord.StageChannel`. + + .. versionadded:: 1.7 + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by channel URL. + 4. Lookup by name + + .. versionchanged:: 2.4 + Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels. + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel: + return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel) + -class CategoryChannelConverter(IDConverter): +class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """Converts to a :class:`~discord.CategoryChannel`. All lookups are via the local guild. If in a DM context, then the lookup @@ -296,109 +607,189 @@ class CategoryChannelConverter(IDConverter): 1. Lookup by ID. 2. Lookup by mention. - 3. Lookup by name + 3. Lookup by channel URL. + 4. Lookup by name + + .. versionchanged:: 2.4 + Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels. + + .. versionchanged:: 1.5 + Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx, argument): - bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument) - result = None - guild = ctx.guild + async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel: + return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) - if match is None: - # not a mention - if guild: - result = discord.utils.get(guild.categories, name=argument) - else: - def check(c): - return isinstance(c, discord.CategoryChannel) and c.name == argument - result = discord.utils.find(check, bot.get_all_channels()) - else: - channel_id = int(match.group(1)) - if guild: - result = guild.get_channel(channel_id) - else: - result = _get_from_guilds(bot, 'get_channel', channel_id) - if not isinstance(result, discord.CategoryChannel): - raise BadArgument('Channel "{}" not found.'.format(argument)) +class ThreadConverter(IDConverter[discord.Thread]): + """Converts to a :class:`~discord.Thread`. - return result + All lookups are via the local guild. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by channel URL. + 4. Lookup by name. + + .. versionadded: 2.0 -class ColourConverter(Converter): + .. versionchanged:: 2.4 + Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels. + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread: + return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) + + +class ForumChannelConverter(IDConverter[discord.ForumChannel]): + """Converts to a :class:`~discord.ForumChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by channel URL. + 4. Lookup by name + + .. versionadded:: 2.0 + + .. versionchanged:: 2.4 + Add lookup by channel URL, accessed via "Copy Link" in the Discord client within channels. + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.ForumChannel: + return GuildChannelConverter._resolve_channel(ctx, argument, 'forums', discord.ForumChannel) + + +class ColourConverter(Converter[discord.Colour]): """Converts to a :class:`~discord.Colour`. + .. versionchanged:: 1.5 + Add an alias named ColorConverter + The following formats are accepted: - ``0x`` - ``#`` - ``0x#`` - - Any of the ``classmethod`` in :class:`Colour` + - ``rgb(, , )`` + - Any of the ``classmethod`` in :class:`~discord.Colour` - The ``_`` in the name can be optionally replaced with spaces. + + Like CSS, ```` can be either 0-255 or 0-100% and ```` can be + either a 6 digit hex number or a 3 digit hex shortcut (e.g. #fff). + + .. versionchanged:: 1.5 + Raise :exc:`.BadColourArgument` instead of generic :exc:`.BadArgument` + + .. versionchanged:: 1.7 + Added support for ``rgb`` function and 3-digit hex shortcuts """ - async def convert(self, ctx, argument): - arg = argument.replace('0x', '').lower() - if arg[0] == '#': - arg = arg[1:] + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour: try: - value = int(arg, base=16) - if not (0 <= value <= 0xFFFFFF): - raise BadArgument('Colour "{}" is invalid.'.format(arg)) - return discord.Colour(value=value) + return discord.Colour.from_str(argument) except ValueError: - arg = arg.replace(' ', '_') + arg = argument.lower().replace(' ', '_') method = getattr(discord.Colour, arg, None) if arg.startswith('from_') or method is None or not inspect.ismethod(method): - raise BadArgument('Colour "{}" is invalid.'.format(arg)) + raise BadColourArgument(arg) return method() -class RoleConverter(IDConverter): + +ColorConverter = ColourConverter + + +class RoleConverter(IDConverter[discord.Role]): """Converts to a :class:`~discord.Role`. - All lookups are via the local guild. If in a DM context, then the lookup - is done by the global cache. + All lookups are via the local guild. If in a DM context, the converter raises + :exc:`.NoPrivateMessage` exception. The lookup strategy is as follows (in order): 1. Lookup by ID. 2. Lookup by mention. 3. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx, argument): + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role: guild = ctx.guild if not guild: raise NoPrivateMessage() - match = self._get_id_match(argument) or re.match(r'<@&([0-9]+)>$', argument) + match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument) if match: result = guild.get_role(int(match.group(1))) else: result = discord.utils.get(guild._roles.values(), name=argument) if result is None: - raise BadArgument('Role "{}" not found.'.format(argument)) + raise RoleNotFound(argument) return result -class GameConverter(Converter): - """Converts to :class:`~discord.Game`.""" - async def convert(self, ctx, argument): + +class GameConverter(Converter[discord.Game]): + """Converts to a :class:`~discord.Game`.""" + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game: return discord.Game(name=argument) -class InviteConverter(Converter): + +class InviteConverter(Converter[discord.Invite]): """Converts to a :class:`~discord.Invite`. This is done via an HTTP request using :meth:`.Bot.fetch_invite`. + + .. versionchanged:: 1.5 + Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx, argument): + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite: try: invite = await ctx.bot.fetch_invite(argument) return invite except Exception as exc: - raise BadArgument('Invite is invalid or expired') from exc + raise BadInviteArgument(argument) from exc + + +class GuildConverter(IDConverter[discord.Guild]): + """Converts to a :class:`~discord.Guild`. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by name. (There is no disambiguation for Guilds with multiple matching names). + + .. versionadded:: 1.7 + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild: + match = self._get_id_match(argument) + result = None + + if match is not None: + guild_id = int(match.group(1)) + result = ctx.bot.get_guild(guild_id) -class EmojiConverter(IDConverter): + if result is None: + result = discord.utils.get(ctx.bot.guilds, name=argument) + + if result is None: + raise GuildNotFound(argument) + return result + + +class EmojiConverter(IDConverter[discord.Emoji]): """Converts to a :class:`~discord.Emoji`. All lookups are done for the local guild first, if available. If that lookup @@ -409,9 +800,13 @@ class EmojiConverter(IDConverter): 1. Lookup by ID. 2. Lookup by extracting ID from the emoji. 3. Lookup by name + + .. versionchanged:: 1.5 + Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx, argument): - match = self._get_id_match(argument) or re.match(r'$', argument) + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji: + match = self._get_id_match(argument) or re.match(r'$', argument) result = None bot = ctx.bot guild = ctx.guild @@ -427,36 +822,175 @@ async def convert(self, ctx, argument): emoji_id = int(match.group(1)) # Try to look up emoji by id. - if guild: - result = discord.utils.get(guild.emojis, id=emoji_id) - - if result is None: - result = discord.utils.get(bot.emojis, id=emoji_id) + result = bot.get_emoji(emoji_id) if result is None: - raise BadArgument('Emoji "{}" not found.'.format(argument)) + raise EmojiNotFound(argument) return result -class PartialEmojiConverter(Converter): + +class PartialEmojiConverter(Converter[discord.PartialEmoji]): """Converts to a :class:`~discord.PartialEmoji`. This is done by extracting the animated flag, name and ID from the emoji. + + .. versionchanged:: 1.5 + Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` """ - async def convert(self, ctx, argument): - match = re.match(r'<(a?):([a-zA-Z0-9\_]+):([0-9]+)>$', argument) + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji: + match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument) if match: emoji_animated = bool(match.group(1)) emoji_name = match.group(2) emoji_id = int(match.group(3)) - return discord.PartialEmoji.with_state(ctx.bot._connection, animated=emoji_animated, name=emoji_name, - id=emoji_id) + return discord.PartialEmoji.with_state( + ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id + ) + + raise PartialEmojiConversionFailure(argument) + + +class GuildStickerConverter(IDConverter[discord.GuildSticker]): + """Converts to a :class:`~discord.GuildSticker`. + + All lookups are done for the local guild first, if available. If that lookup + fails, then it checks the client's global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by name. + + .. versionadded:: 2.0 + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker: + match = self._get_id_match(argument) + result = None + bot = ctx.bot + guild = ctx.guild + + if match is None: + # Try to get the sticker by name. Try local guild first. + if guild: + result = discord.utils.get(guild.stickers, name=argument) + + if result is None: + result = discord.utils.get(bot.stickers, name=argument) + else: + sticker_id = int(match.group(1)) + + # Try to look up sticker by id. + result = bot.get_sticker(sticker_id) + + if result is None: + raise GuildStickerNotFound(argument) + + return result + + +class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]): + """Converts to a :class:`~discord.ScheduledEvent`. + + Lookups are done for the local guild if available. Otherwise, for a DM context, + lookup is done by the global cache. + + The lookup strategy is as follows (in order): - raise BadArgument('Couldn\'t convert "{}" to PartialEmoji.'.format(argument)) + 1. Lookup by ID. + 2. Lookup by url. + 3. Lookup by name. + + .. versionadded:: 2.0 + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent: + guild = ctx.guild + match = self._get_id_match(argument) + result = None + + if match: + # ID match + event_id = int(match.group(1)) + if guild: + result = guild.get_scheduled_event(event_id) + else: + for guild in ctx.bot.guilds: + result = guild.get_scheduled_event(event_id) + if result: + break + else: + pattern = ( + r'https?://(?:(ptb|canary|www)\.)?discord\.com/events/' + r'(?P[0-9]{15,20})/' + r'(?P[0-9]{15,20})$' + ) + match = re.match(pattern, argument, flags=re.I) + if match: + # URL match + guild = ctx.bot.get_guild(int(match.group('guild_id'))) + + if guild: + event_id = int(match.group('event_id')) + result = guild.get_scheduled_event(event_id) + else: + # lookup by name + if guild: + result = discord.utils.get(guild.scheduled_events, name=argument) + else: + for guild in ctx.bot.guilds: + result = discord.utils.get(guild.scheduled_events, name=argument) + if result: + break + if result is None: + raise ScheduledEventNotFound(argument) + + return result + + +class SoundboardSoundConverter(IDConverter[discord.SoundboardSound]): + """Converts to a :class:`~discord.SoundboardSound`. + + Lookups are done for the local guild if available. Otherwise, for a DM context, + lookup is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by name. -class clean_content(Converter): + .. versionadded:: 2.5 + """ + + async def convert(self, ctx: Context[BotT], argument: str) -> discord.SoundboardSound: + guild = ctx.guild + match = self._get_id_match(argument) + result = None + + if match: + # ID match + sound_id = int(match.group(1)) + if guild: + result = guild.get_soundboard_sound(sound_id) + else: + result = ctx.bot.get_soundboard_sound(sound_id) + else: + # lookup by name + if guild: + result = discord.utils.get(guild.soundboard_sounds, name=argument) + else: + result = discord.utils.get(ctx.bot.soundboard_sounds, name=argument) + if result is None: + raise SoundboardSoundNotFound(argument) + + return result + + +class clean_content(Converter[str]): """Converts the argument to mention scrubbed version of said content. @@ -470,84 +1004,439 @@ class clean_content(Converter): Whether to use nicknames when transforming mentions. escape_markdown: :class:`bool` Whether to also escape special markdown characters. + remove_markdown: :class:`bool` + Whether to also remove special markdown characters. This option is not supported with ``escape_markdown`` + + .. versionadded:: 1.7 """ - def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False): + + def __init__( + self, + *, + fix_channel_mentions: bool = False, + use_nicknames: bool = True, + escape_markdown: bool = False, + remove_markdown: bool = False, + ) -> None: self.fix_channel_mentions = fix_channel_mentions self.use_nicknames = use_nicknames self.escape_markdown = escape_markdown + self.remove_markdown = remove_markdown - async def convert(self, ctx, argument): - message = ctx.message - transformations = {} + async def convert(self, ctx: Context[BotT], argument: str) -> str: + msg = ctx.message - if self.fix_channel_mentions and ctx.guild: - def resolve_channel(id, *, _get=ctx.guild.get_channel): - ch = _get(id) - return ('<#%s>' % id), ('#' + ch.name if ch else '#deleted-channel') + if ctx.guild: + + def resolve_member(id: int) -> str: + m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) # type: ignore + return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' - transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions) + def resolve_role(id: int) -> str: + r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) # type: ignore + return f'@{r.name}' if r else '@deleted-role' - if self.use_nicknames and ctx.guild: - def resolve_member(id, *, _get=ctx.guild.get_member): - m = _get(id) - return '@' + m.display_name if m else '@deleted-user' else: - def resolve_member(id, *, _get=ctx.bot.get_user): - m = _get(id) - return '@' + m.name if m else '@deleted-user' + def resolve_member(id: int) -> str: + m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) + return f'@{m.display_name}' if m else '@deleted-user' - transformations.update( - ('<@%s>' % member_id, resolve_member(member_id)) - for member_id in message.raw_mentions - ) + def resolve_role(id: int) -> str: + return '@deleted-role' - transformations.update( - ('<@!%s>' % member_id, resolve_member(member_id)) - for member_id in message.raw_mentions - ) + if self.fix_channel_mentions and ctx.guild: - if ctx.guild: - def resolve_role(_id, *, _find=ctx.guild.get_role): - r = _find(_id) - return '@' + r.name if r else '@deleted-role' + def resolve_channel(id: int) -> str: + c = ctx.guild._resolve_channel(id) # type: ignore + return f'#{c.name}' if c else '#deleted-channel' - transformations.update( - ('<@&%s>' % role_id, resolve_role(role_id)) - for role_id in message.raw_role_mentions - ) + else: - def repl(obj): - return transformations.get(obj.group(0), '') + def resolve_channel(id: int) -> str: + return f'<#{id}>' - pattern = re.compile('|'.join(transformations.keys())) - result = pattern.sub(repl, argument) + transforms = { + '@': resolve_member, + '@!': resolve_member, + '#': resolve_channel, + '@&': resolve_role, + } + def repl(match: re.Match) -> str: + type = match[1] + id = int(match[2]) + transformed = transforms[type](id) + return transformed + + result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) if self.escape_markdown: result = discord.utils.escape_markdown(result) + elif self.remove_markdown: + result = discord.utils.remove_markdown(result) # Completely ensure no mentions escape: return discord.utils.escape_mentions(result) -class _Greedy: + +class Greedy(List[T]): + r"""A special converter that greedily consumes arguments until it can't. + As a consequence of this behaviour, most input errors are silently discarded, + since it is used as an indicator of when to stop parsing. + + When a parser error is met the greedy converter stops converting, undoes the + internal string parsing routine, and continues parsing regularly. + + For example, in the following code: + + .. code-block:: python3 + + @commands.command() + async def test(ctx, numbers: Greedy[int], reason: str): + await ctx.send("numbers: {}, reason: {}".format(numbers, reason)) + + An invocation of ``[p]test 1 2 3 4 5 6 hello`` would pass ``numbers`` with + ``[1, 2, 3, 4, 5, 6]`` and ``reason`` with ``hello``\. + + For more information, check :ref:`ext_commands_special_converters`. + + .. note:: + + For interaction based contexts the conversion error is propagated + rather than swallowed due to the difference in user experience with + application commands. + """ + __slots__ = ('converter',) - def __init__(self, *, converter=None): - self.converter = converter + def __init__(self, *, converter: T) -> None: + self.converter: T = converter - def __getitem__(self, params): + def __repr__(self) -> str: + converter = getattr(self.converter, '__name__', repr(self.converter)) + return f'Greedy[{converter}]' + + def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: if not isinstance(params, tuple): params = (params,) if len(params) != 1: raise TypeError('Greedy[...] only takes a single argument') converter = params[0] - if not (callable(converter) or isinstance(converter, Converter) or hasattr(converter, '__origin__')): + args = getattr(converter, '__args__', ()) + if discord.utils.PY_310 and converter.__class__ is types.UnionType: # type: ignore + converter = Union[args] + + origin = getattr(converter, '__origin__', None) + + if not (callable(converter) or isinstance(converter, Converter) or origin is not None): raise TypeError('Greedy[...] expects a type or a Converter instance.') - if converter is str or converter is type(None) or converter is _Greedy: - raise TypeError('Greedy[%s] is invalid.' % converter.__name__) + if converter in (str, type(None)) or origin is Greedy: + raise TypeError(f'Greedy[{converter.__name__}] is invalid.') # type: ignore + + if origin is Union and type(None) in args: + raise TypeError(f'Greedy[{converter!r}] is invalid.') + + return cls(converter=converter) # type: ignore + + @property + def constructed_converter(self) -> Any: + # Only construct a converter once in order to maintain state between convert calls + if ( + inspect.isclass(self.converter) + and issubclass(self.converter, Converter) + and not inspect.ismethod(self.converter.convert) + ): + return self.converter() + return self.converter + + +if TYPE_CHECKING: + from typing_extensions import Annotated as Range +else: - return self.__class__(converter=converter) + class Range: + """A special converter that can be applied to a parameter to require a numeric + or string type to fit within the range provided. -Greedy = _Greedy() + During type checking time this is equivalent to :obj:`typing.Annotated` so type checkers understand + the intent of the code. + + Some example ranges: + + - ``Range[int, 10]`` means the minimum is 10 with no maximum. + - ``Range[int, None, 10]`` means the maximum is 10 with no minimum. + - ``Range[int, 1, 10]`` means the minimum is 1 and the maximum is 10. + - ``Range[float, 1.0, 5.0]`` means the minimum is 1.0 and the maximum is 5.0. + - ``Range[str, 1, 10]`` means the minimum length is 1 and the maximum length is 10. + + Inside a :class:`HybridCommand` this functions equivalently to :class:`discord.app_commands.Range`. + + If the value cannot be converted to the provided type or is outside the given range, + :class:`~.ext.commands.BadArgument` or :class:`~.ext.commands.RangeError` is raised to + the appropriate error handlers respectively. + + .. versionadded:: 2.0 + + Examples + ---------- + + .. code-block:: python3 + + @bot.command() + async def range(ctx: commands.Context, value: commands.Range[int, 10, 12]): + await ctx.send(f'Your value is {value}') + """ + + def __init__( + self, + *, + annotation: Any, + min: Optional[Union[int, float]] = None, + max: Optional[Union[int, float]] = None, + ) -> None: + self.annotation: Any = annotation + self.min: Optional[Union[int, float]] = min + self.max: Optional[Union[int, float]] = max + + if min and max and min > max: + raise TypeError('minimum cannot be larger than maximum') + + async def convert(self, ctx: Context[BotT], value: str) -> Union[int, float]: + try: + count = converted = self.annotation(value) + except ValueError: + raise BadArgument( + f'Converting to "{self.annotation.__name__}" failed for parameter "{ctx.current_parameter.name}".' + ) + + if self.annotation is str: + count = len(value) + + if (self.min is not None and count < self.min) or (self.max is not None and count > self.max): + raise RangeError(converted, minimum=self.min, maximum=self.max) + + return converted + + def __call__(self) -> None: + # Trick to allow it inside typing.Union + pass + + def __or__(self, rhs) -> Any: + return Union[self, rhs] + + def __repr__(self) -> str: + return f'{self.__class__.__name__}[{self.annotation.__name__}, {self.min}, {self.max}]' + + def __class_getitem__(cls, obj) -> Range: + if not isinstance(obj, tuple): + raise TypeError(f'expected tuple for arguments, received {obj.__class__.__name__} instead') + + if len(obj) == 2: + obj = (*obj, None) + elif len(obj) != 3: + raise TypeError('Range accepts either two or three arguments with the first being the type of range.') + + annotation, min, max = obj + + if min is None and max is None: + raise TypeError('Range must not be empty') + + if min is not None and max is not None: + # At this point max and min are both not none + if type(min) != type(max): + raise TypeError('Both min and max in Range must be the same type') + + if annotation not in (int, float, str): + raise TypeError(f'expected int, float, or str as range type, received {annotation!r} instead') + + if annotation in (str, int): + cast = int + else: + cast = float + + return cls( + annotation=annotation, + min=cast(min) if min is not None else None, + max=cast(max) if max is not None else None, + ) + + +def _convert_to_bool(argument: str) -> bool: + lowered = argument.lower() + if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): + return True + elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): + return False + else: + raise BadBoolArgument(lowered) + + +_GenericAlias = type(List[T]) # type: ignore + + +def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool: + return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) + + +CONVERTER_MAPPING: Dict[type, Any] = { + discord.Object: ObjectConverter, + discord.Member: MemberConverter, + discord.User: UserConverter, + discord.Message: MessageConverter, + discord.PartialMessage: PartialMessageConverter, + discord.TextChannel: TextChannelConverter, + discord.Invite: InviteConverter, + discord.Guild: GuildConverter, + discord.Role: RoleConverter, + discord.Game: GameConverter, + discord.Colour: ColourConverter, + discord.VoiceChannel: VoiceChannelConverter, + discord.StageChannel: StageChannelConverter, + discord.Emoji: EmojiConverter, + discord.PartialEmoji: PartialEmojiConverter, + discord.CategoryChannel: CategoryChannelConverter, + discord.Thread: ThreadConverter, + discord.abc.GuildChannel: GuildChannelConverter, + discord.GuildSticker: GuildStickerConverter, + discord.ScheduledEvent: ScheduledEventConverter, + discord.ForumChannel: ForumChannelConverter, + discord.SoundboardSound: SoundboardSoundConverter, +} + + +async def _actual_conversion(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter): + if converter is bool: + return _convert_to_bool(argument) + + try: + module = converter.__module__ + except AttributeError: + pass + else: + if module is not None and (module.startswith('discord.') and not module.endswith('converter')): + converter = CONVERTER_MAPPING.get(converter, converter) + + try: + if inspect.isclass(converter) and issubclass(converter, Converter): + if inspect.ismethod(converter.convert): + return await converter.convert(ctx, argument) + else: + return await converter().convert(ctx, argument) + elif isinstance(converter, Converter): + return await converter.convert(ctx, argument) + except CommandError: + raise + except Exception as exc: + raise ConversionError(converter, exc) from exc # type: ignore + + try: + return converter(argument) + except CommandError: + raise + except Exception as exc: + try: + name = converter.__name__ + except AttributeError: + name = converter.__class__.__name__ + + raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc + + +@overload +async def run_converters( + ctx: Context[BotT], converter: Union[Type[Converter[T]], Converter[T]], argument: str, param: Parameter +) -> T: ... + + +@overload +async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: Parameter) -> Any: ... + + +async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: Parameter) -> Any: + """|coro| + + Runs converters for a given converter, argument, and parameter. + + This function does the same work that the library does under the hood. + + .. versionadded:: 2.0 + + Parameters + ------------ + ctx: :class:`Context` + The invocation context to run the converters under. + converter: Any + The converter to run, this corresponds to the annotation in the function. + argument: :class:`str` + The argument to convert to. + param: :class:`Parameter` + The parameter being converted. This is mainly for error reporting. + + Raises + ------- + CommandError + The converter failed to convert. + + Returns + -------- + Any + The resulting conversion. + """ + origin = getattr(converter, '__origin__', None) + + if origin is Union: + errors = [] + _NoneType = type(None) + union_args = converter.__args__ + for conv in union_args: + # if we got to this part in the code, then the previous conversions have failed + # so we should just undo the view, return the default, and allow parsing to continue + # with the other parameters + if conv is _NoneType and param.kind != param.VAR_POSITIONAL: + ctx.view.undo() + return None if param.required else await param.get_default(ctx) + + try: + value = await run_converters(ctx, conv, argument, param) + except CommandError as exc: + errors.append(exc) + else: + return value + + # if we're here, then we failed all the converters + raise BadUnionArgument(param, union_args, errors) + + if origin is Literal: + errors = [] + conversions = {} + literal_args = converter.__args__ + for literal in literal_args: + literal_type = type(literal) + try: + value = conversions[literal_type] + except KeyError: + try: + value = await _actual_conversion(ctx, literal_type, argument, param) + except CommandError as exc: + errors.append(exc) + conversions[literal_type] = object() + continue + else: + conversions[literal_type] = value + + if value == literal: + return value + + # if we're here, then we failed to match all the literals + raise BadLiteralArgument(param, literal_args, errors, argument) + + # This must be the last if-clause in the chain of origin checking + # Nearly every type is a generic type within the typing library + # So care must be taken to make sure a more specialised origin handle + # isn't overwritten by the widest if clause + if origin is not None and is_generic_type(converter): + converter = origin + + return await _actual_conversion(ctx, converter, argument, param) diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 5cd305f6fab7..fb68944bde66 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,111 +22,101 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + + +from typing import Any, Callable, Deque, Dict, Optional, Union, Generic, TypeVar, TYPE_CHECKING from discord.enums import Enum +from discord.abc import PrivateChannel import time +import asyncio +from collections import deque + +from .errors import MaxConcurrencyReached +from .context import Context +from discord.app_commands import Cooldown as Cooldown + +if TYPE_CHECKING: + from typing_extensions import Self + + from ...message import Message __all__ = ( 'BucketType', 'Cooldown', 'CooldownMapping', + 'DynamicCooldownMapping', + 'MaxConcurrency', ) -class BucketType(Enum): - default = 0 - user = 1 - guild = 2 - channel = 3 - member = 4 - category = 5 - -class Cooldown: - __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last') - - def __init__(self, rate, per, type): - self.rate = int(rate) - self.per = float(per) - self.type = type - self._window = 0.0 - self._tokens = self.rate - self._last = 0.0 - - if not isinstance(self.type, BucketType): - raise TypeError('Cooldown type must be a BucketType') - - def get_tokens(self, current=None): - if not current: - current = time.time() - - tokens = self._tokens - - if current > self._window + self.per: - tokens = self.rate - return tokens - - def update_rate_limit(self, current=None): - current = current or time.time() - self._last = current - - self._tokens = self.get_tokens(current) - - # first token used means that we start a new rate limit window - if self._tokens == self.rate: - self._window = current +T_contra = TypeVar('T_contra', contravariant=True) - # check if we are rate limited - if self._tokens == 0: - return self.per - (current - self._window) - # we're not so decrement our tokens - self._tokens -= 1 - - # see if we got rate limited due to this token change, and if - # so update the window to point to our current time frame - if self._tokens == 0: - self._window = current - - def reset(self): - self._tokens = self.rate - self._last = 0.0 - - def copy(self): - return Cooldown(self.rate, self.per, self.type) - - def __repr__(self): - return ''.format(self) - -class CooldownMapping: - def __init__(self, original): - self._cache = {} - self._cooldown = original +class BucketType(Enum): + default = 0 + user = 1 + guild = 2 + channel = 3 + member = 4 + category = 5 + role = 6 - def copy(self): - ret = CooldownMapping(self._cooldown) + def get_key(self, msg: Union[Message, Context[Any]]) -> Any: + if self is BucketType.user: + return msg.author.id + elif self is BucketType.guild: + return (msg.guild or msg.author).id + elif self is BucketType.channel: + return msg.channel.id + elif self is BucketType.member: + return ((msg.guild and msg.guild.id), msg.author.id) + elif self is BucketType.category: + return (getattr(msg.channel, 'category', None) or msg.channel).id + elif self is BucketType.role: + # we return the channel id of a private-channel as there are only roles in guilds + # and that yields the same result as for a guild with only the @everyone role + # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are + # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do + return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore + + def __call__(self, msg: Union[Message, Context[Any]]) -> Any: + return self.get_key(msg) + + +class CooldownMapping(Generic[T_contra]): + def __init__( + self, + original: Optional[Cooldown], + type: Callable[[T_contra], Any], + ) -> None: + if not callable(type): + raise TypeError('Cooldown type must be a BucketType or callable') + + self._cache: Dict[Any, Cooldown] = {} + self._cooldown: Optional[Cooldown] = original + self._type: Callable[[T_contra], Any] = type + + def copy(self) -> CooldownMapping[T_contra]: + ret = CooldownMapping(self._cooldown, self._type) ret._cache = self._cache.copy() return ret @property - def valid(self): + def valid(self) -> bool: return self._cooldown is not None + @property + def type(self) -> Callable[[T_contra], Any]: + return self._type + @classmethod - def from_cooldown(cls, rate, per, type): - return cls(Cooldown(rate, per, type)) + def from_cooldown(cls, rate: float, per: float, type: Callable[[T_contra], Any]) -> Self: + return cls(Cooldown(rate, per), type) - def _bucket_key(self, msg): - bucket_type = self._cooldown.type - if bucket_type is BucketType.user: - return msg.author.id - elif bucket_type is BucketType.guild: - return (msg.guild or msg.author).id - elif bucket_type is BucketType.channel: - return msg.channel.id - elif bucket_type is BucketType.member: - return ((msg.guild and msg.guild.id), msg.author.id) - elif bucket_type is BucketType.category: - return (msg.channel.category or msg.channel).id + def _bucket_key(self, msg: T_contra) -> Any: + return self._type(msg) - def _verify_cache_integrity(self, current=None): + def _verify_cache_integrity(self, current: Optional[float] = None) -> None: # we want to delete all cache objects that haven't been used # in a cooldown window. e.g. if we have a command that has a # cooldown of 60s and it has not been used in 60s then that key should be deleted @@ -137,20 +125,161 @@ def _verify_cache_integrity(self, current=None): for k in dead_keys: del self._cache[k] - def get_bucket(self, message, current=None): - if self._cooldown.type is BucketType.default: + def create_bucket(self, message: T_contra) -> Cooldown: + return self._cooldown.copy() # type: ignore + + def get_bucket(self, message: T_contra, current: Optional[float] = None) -> Optional[Cooldown]: + if self._type is BucketType.default: return self._cooldown self._verify_cache_integrity(current) key = self._bucket_key(message) if key not in self._cache: - bucket = self._cooldown.copy() - self._cache[key] = bucket + bucket = self.create_bucket(message) + if bucket is not None: + self._cache[key] = bucket else: bucket = self._cache[key] return bucket - def update_rate_limit(self, message, current=None): + def update_rate_limit(self, message: T_contra, current: Optional[float] = None, tokens: int = 1) -> Optional[float]: bucket = self.get_bucket(message, current) - return bucket.update_rate_limit(current) + if bucket is None: + return None + return bucket.update_rate_limit(current, tokens=tokens) + + +class DynamicCooldownMapping(CooldownMapping[T_contra]): + def __init__( + self, + factory: Callable[[T_contra], Optional[Cooldown]], + type: Callable[[T_contra], Any], + ) -> None: + super().__init__(None, type) + self._factory: Callable[[T_contra], Optional[Cooldown]] = factory + + def copy(self) -> DynamicCooldownMapping[T_contra]: + ret = DynamicCooldownMapping(self._factory, self._type) + ret._cache = self._cache.copy() + return ret + + @property + def valid(self) -> bool: + return True + + def create_bucket(self, message: T_contra) -> Optional[Cooldown]: + return self._factory(message) + + +class _Semaphore: + """This class is a version of a semaphore. + + If you're wondering why asyncio.Semaphore isn't being used, + it's because it doesn't expose the internal value. This internal + value is necessary because I need to support both `wait=True` and + `wait=False`. + + An asyncio.Queue could have been used to do this as well -- but it is + not as inefficient since internally that uses two queues and is a bit + overkill for what is basically a counter. + """ + + __slots__ = ('value', 'loop', '_waiters') + + def __init__(self, number: int) -> None: + self.value: int = number + self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + self._waiters: Deque[asyncio.Future] = deque() + + def __repr__(self) -> str: + return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>' + + def locked(self) -> bool: + return self.value == 0 + + def is_active(self) -> bool: + return len(self._waiters) > 0 + + def wake_up(self) -> None: + while self._waiters: + future = self._waiters.popleft() + if not future.done(): + future.set_result(None) + return + + async def acquire(self, *, wait: bool = False) -> bool: + if not wait and self.value <= 0: + # signal that we're not acquiring + return False + + while self.value <= 0: + future = self.loop.create_future() + self._waiters.append(future) + try: + await future + except: + future.cancel() + if self.value > 0 and not future.cancelled(): + self.wake_up() + raise + + self.value -= 1 + return True + + def release(self) -> None: + self.value += 1 + self.wake_up() + + +class MaxConcurrency: + __slots__ = ('number', 'per', 'wait', '_mapping') + + def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: + self._mapping: Dict[Any, _Semaphore] = {} + self.per: BucketType = per + self.number: int = number + self.wait: bool = wait + + if number <= 0: + raise ValueError("max_concurrency 'number' cannot be less than 1") + + if not isinstance(per, BucketType): + raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}") + + def copy(self) -> Self: + return self.__class__(self.number, per=self.per, wait=self.wait) + + def __repr__(self) -> str: + return f'' + + def get_key(self, message: Union[Message, Context[Any]]) -> Any: + return self.per.get_key(message) + + async def acquire(self, message: Union[Message, Context[Any]]) -> None: + key = self.get_key(message) + + try: + sem = self._mapping[key] + except KeyError: + self._mapping[key] = sem = _Semaphore(self.number) + + acquired = await sem.acquire(wait=self.wait) + if not acquired: + raise MaxConcurrencyReached(self.number, self.per) + + async def release(self, message: Union[Message, Context[Any]]) -> None: + # Technically there's no reason for this function to be async + # But it might be more useful in the future + key = self.get_key(message) + + try: + sem = self._mapping[key] + except KeyError: + # ...? peculiar + return + else: + sem.release() + + if sem.value >= self.number and not sem.is_active(): + del self._mapping[key] diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index ead83c729003..949539b61176 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) -Copyright (c) 2015-2019 Rapptz +Copyright (c) 2015-present Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), @@ -24,19 +22,77 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import asyncio +import datetime import functools import inspect -import typing -import datetime +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Generic, + List, + Literal, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + overload, + TypedDict, +) +import re import discord -from .errors import * -from .cooldowns import Cooldown, BucketType, CooldownMapping -from . import converter as converters -from ._types import _BaseCommand +from ._types import _BaseCommand, CogT from .cog import Cog +from .context import Context +from .converter import Greedy, run_converters +from .cooldowns import BucketType, Cooldown, CooldownMapping, DynamicCooldownMapping, MaxConcurrency +from .errors import * +from .parameters import Parameter, Signature +from discord.app_commands.commands import NUMPY_DOCSTRING_ARG_REGEX + +if TYPE_CHECKING: + from typing_extensions import Concatenate, ParamSpec, Self, Unpack + + from ._types import BotT, Check, ContextT, Coro, CoroFunc, Error, Hook, UserCheck + + from discord.permissions import _PermissionsKwargs + + class _CommandDecoratorKwargs(TypedDict, total=False): + enabled: bool + help: Optional[str] + brief: Optional[str] + usage: Optional[str] + rest_is_raw: bool + aliases: Union[List[str], Tuple[str, ...]] + description: str + hidden: bool + checks: List[UserCheck[Context[Any]]] + cooldown: CooldownMapping[Context[Any]] + max_concurrency: MaxConcurrency + require_var_positional: bool + cooldown_after_parsing: bool + ignore_extra: bool + extras: Dict[Any, Any] + + class _CommandKwargs(_CommandDecoratorKwargs, total=False): + name: str + + class _GroupDecoratorKwargs(_CommandDecoratorKwargs, total=False): + invoke_without_command: bool + case_insensitive: bool + + class _GroupKwargs(_GroupDecoratorKwargs, total=False): + name: str + __all__ = ( 'Command', @@ -48,19 +104,146 @@ 'has_permissions', 'has_any_role', 'check', + 'check_any', + 'before_invoke', + 'after_invoke', 'bot_has_role', 'bot_has_permissions', 'bot_has_any_role', 'cooldown', + 'dynamic_cooldown', + 'max_concurrency', 'dm_only', 'guild_only', 'is_owner', 'is_nsfw', + 'has_guild_permissions', + 'bot_has_guild_permissions', ) -def wrap_callback(coro): +MISSING: Any = discord.utils.MISSING + +T = TypeVar('T') +CommandT = TypeVar('CommandT', bound='Command[Any, ..., Any]') +# CHT = TypeVar('CHT', bound='Check') +GroupT = TypeVar('GroupT', bound='Group[Any, ..., Any]') + +if TYPE_CHECKING: + P = ParamSpec('P') +else: + P = TypeVar('P') + + +def unwrap_function(function: Callable[..., Any], /) -> Callable[..., Any]: + partial = functools.partial + while True: + if hasattr(function, '__wrapped__'): + function = function.__wrapped__ + elif isinstance(function, partial): + function = function.func + else: + return function + + +def get_signature_parameters( + function: Callable[..., Any], + globalns: Dict[str, Any], + /, + *, + skip_parameters: Optional[int] = None, +) -> Dict[str, Parameter]: + signature = Signature.from_callable(function) + params: Dict[str, Parameter] = {} + cache: Dict[str, Any] = {} + eval_annotation = discord.utils.evaluate_annotation + required_params = discord.utils.is_inside_class(function) + 1 if skip_parameters is None else skip_parameters + if len(signature.parameters) < required_params: + raise TypeError(f'Command signature requires at least {required_params - 1} parameter(s)') + + iterator = iter(signature.parameters.items()) + for _ in range(0, required_params): + next(iterator) + + for name, parameter in iterator: + default = parameter.default + if isinstance(default, Parameter): # update from the default + if default.annotation is not Parameter.empty: + # There are a few cases to care about here. + # x: TextChannel = commands.CurrentChannel + # x = commands.CurrentChannel + # In both of these cases, the default parameter has an explicit annotation + # but in the second case it's only used as the fallback. + if default._fallback: + if parameter.annotation is Parameter.empty: + parameter._annotation = default.annotation + else: + parameter._annotation = default.annotation + + parameter._default = default.default + parameter._description = default._description + parameter._displayed_default = default._displayed_default + parameter._displayed_name = default._displayed_name + + annotation = parameter.annotation + + if annotation is None: + params[name] = parameter.replace(annotation=type(None)) + continue + + annotation = eval_annotation(annotation, globalns, globalns, cache) + if annotation is Greedy: + raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') + + params[name] = parameter.replace(annotation=annotation) + + return params + + +PARAMETER_HEADING_REGEX = re.compile(r'Parameters?\n---+\n', re.I) + + +def _fold_text(input: str) -> str: + """Turns a single newline into a space, and multiple newlines into a newline.""" + + def replacer(m: re.Match[str]) -> str: + if len(m.group()) <= 1: + return ' ' + return '\n' + + return re.sub(r'\n+', replacer, inspect.cleandoc(input)) + + +def extract_descriptions_from_docstring(function: Callable[..., Any], params: Dict[str, Parameter], /) -> Optional[str]: + docstring = inspect.getdoc(function) + + if docstring is None: + return None + + divide = PARAMETER_HEADING_REGEX.split(docstring, 1) + if len(divide) == 1: + return docstring + + description, param_docstring = divide + for match in NUMPY_DOCSTRING_ARG_REGEX.finditer(param_docstring): + name = match.group('name') + + if name not in params: + is_display_name = discord.utils.get(params.values(), displayed_name=name) + if is_display_name: + name = is_display_name.name + else: + continue + + param = params[name] + if param.description is None: + param._description = _fold_text(match.group('description')) + + return _fold_text(description.strip()) + + +def wrap_callback(coro: Callable[P, Coro[T]], /) -> Callable[P, Coro[Optional[T]]]: @functools.wraps(coro) - async def wrapped(*args, **kwargs): + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: try: ret = await coro(*args, **kwargs) except CommandError: @@ -70,11 +253,15 @@ async def wrapped(*args, **kwargs): except Exception as exc: raise CommandInvokeError(exc) from exc return ret + return wrapped -def hooked_wrapped_callback(command, ctx, coro): + +def hooked_wrapped_callback( + command: Command[Any, ..., Any], ctx: Context[BotT], coro: Callable[P, Coro[T]], / +) -> Callable[P, Coro[Optional[T]]]: @functools.wraps(coro) - async def wrapped(*args, **kwargs): + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: try: ret = await coro(*args, **kwargs) except CommandError: @@ -87,39 +274,57 @@ async def wrapped(*args, **kwargs): ctx.command_failed = True raise CommandInvokeError(exc) from exc finally: + if command._max_concurrency is not None: + await command._max_concurrency.release(ctx.message) + await command.call_after_hooks(ctx) return ret + return wrapped -def _convert_to_bool(argument): - lowered = argument.lower() - if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): - return True - elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): - return False - else: - raise BadArgument(lowered + ' is not a recognised boolean option') class _CaseInsensitiveDict(dict): def __contains__(self, k): - return super().__contains__(k.lower()) + return super().__contains__(k.casefold()) def __delitem__(self, k): - return super().__delitem__(k.lower()) + return super().__delitem__(k.casefold()) def __getitem__(self, k): - return super().__getitem__(k.lower()) + return super().__getitem__(k.casefold()) def get(self, k, default=None): - return super().get(k.lower(), default) + return super().get(k.casefold(), default) def pop(self, k, default=None): - return super().pop(k.lower(), default) + return super().pop(k.casefold(), default) def __setitem__(self, k, v): - super().__setitem__(k.lower(), v) + super().__setitem__(k.casefold(), v) + + +class _AttachmentIterator: + def __init__(self, data: List[discord.Attachment]): + self.data: List[discord.Attachment] = data + self.index: int = 0 + + def __iter__(self) -> Self: + return self + + def __next__(self) -> discord.Attachment: + try: + value = self.data[self.index] + except IndexError: + raise StopIteration + else: + self.index += 1 + return value + + def is_empty(self) -> bool: + return self.index >= len(self.data) + -class Command(_BaseCommand): +class Command(_BaseCommand, Generic[CogT, P, T]): r"""A class that implements the protocol for a bot text command. These are not created manually, instead they are created via the @@ -131,24 +336,25 @@ class Command(_BaseCommand): The name of the command. callback: :ref:`coroutine ` The coroutine that is executed when the command is called. - help: :class:`str` + help: Optional[:class:`str`] The long help text for the command. - brief: :class:`str` - The short help text for the command. If this is not specified - then the first line of the long help text is used instead. - usage: :class:`str` + brief: Optional[:class:`str`] + The short help text for the command. + usage: Optional[:class:`str`] A replacement for arguments in the default help text. - aliases: :class:`list` + aliases: Union[List[:class:`str`], Tuple[:class:`str`]] The list of aliases the command can be invoked under. enabled: :class:`bool` A boolean that indicates if the command is currently enabled. If the command is invoked while it is disabled, then :exc:`.DisabledCommand` is raised to the :func:`.on_command_error` event. Defaults to ``True``. - parent: Optional[:class:`Command`] - The parent command that this command belongs to. ``None`` if there + parent: Optional[:class:`Group`] + The parent group that this command belongs to. ``None`` if there isn't one. - checks: List[Callable[..., :class:`bool`]] + cog: Optional[:class:`Cog`] + The cog that this command belongs to. ``None`` if there isn't one. + checks: List[Callable[[:class:`.Context`], :class:`bool`]] A list of predicates that verifies if the command could be executed with the given :class:`.Context` as the sole parameter. If an exception is necessary to be thrown to signal failure, then one inherited from @@ -169,6 +375,12 @@ class Command(_BaseCommand): in a completely raw matter. Defaults to ``False``. invoked_subcommand: Optional[:class:`Command`] The subcommand that was invoked, if any. + require_var_positional: :class:`bool` + If ``True`` and a variadic positional argument is specified, requires + the user to specify at least one argument. Defaults to ``False``. + + .. versionadded:: 1.5 + ignore_extra: :class:`bool` If ``True``\, ignores extraneous strings passed to a command if all its requirements are met (e.g. ``?foo a b c`` when only expecting ``a`` @@ -178,9 +390,19 @@ class Command(_BaseCommand): If ``True``\, cooldown processing is done after argument parsing, which calls converters. If ``False`` then cooldown processing is done first and then the converters are called second. Defaults to ``False``. + extras: :class:`dict` + A dict of user provided extras to attach to the Command. + + .. note:: + This object may be copied by the library. + + + .. versionadded:: 2.0 """ - def __new__(cls, *args, **kwargs): + __original_kwargs__: Dict[str, Any] + + def __new__(cls, *args: Any, **kwargs: Any) -> Self: # if you're wondering why this is done, it's because we need to ensure # we have a complete original copy of **kwargs even for classes that # mess with it by popping before delegating to the subclass __init__. @@ -196,114 +418,241 @@ def __new__(cls, *args, **kwargs): self.__original_kwargs__ = kwargs.copy() return self - def __init__(self, func, **kwargs): + def __init__( + self, + func: Union[ + Callable[Concatenate[CogT, Context[Any], P], Coro[T]], + Callable[Concatenate[Context[Any], P], Coro[T]], + ], + /, + **kwargs: Unpack[_CommandKwargs], + ) -> None: if not asyncio.iscoroutinefunction(func): raise TypeError('Callback must be a coroutine.') - self.name = name = kwargs.get('name') or func.__name__ + name = kwargs.get('name') or func.__name__ if not isinstance(name, str): raise TypeError('Name of a command must be a string.') + self.name: str = name self.callback = func - self.enabled = kwargs.get('enabled', True) + self.enabled: bool = kwargs.get('enabled', True) help_doc = kwargs.get('help') if help_doc is not None: help_doc = inspect.cleandoc(help_doc) else: - help_doc = inspect.getdoc(func) - if isinstance(help_doc, bytes): - help_doc = help_doc.decode('utf-8') + help_doc = extract_descriptions_from_docstring(func, self.params) - self.help = help_doc + self.help: Optional[str] = help_doc - self.brief = kwargs.get('brief') - self.usage = kwargs.get('usage') - self.rest_is_raw = kwargs.get('rest_is_raw', False) - self.aliases = kwargs.get('aliases', []) + self.brief: Optional[str] = kwargs.get('brief') + self.usage: Optional[str] = kwargs.get('usage') + self.rest_is_raw: bool = kwargs.get('rest_is_raw', False) + self.aliases: Union[List[str], Tuple[str, ...]] = kwargs.get('aliases', []) + self.extras: Dict[Any, Any] = kwargs.get('extras', {}) if not isinstance(self.aliases, (list, tuple)): - raise TypeError("Aliases of a command must be a list of strings.") + raise TypeError('Aliases of a command must be a list or a tuple of strings.') - self.description = inspect.cleandoc(kwargs.get('description', '')) - self.hidden = kwargs.get('hidden', False) + self.description: str = inspect.cleandoc(kwargs.get('description', '')) + self.hidden: bool = kwargs.get('hidden', False) try: checks = func.__commands_checks__ checks.reverse() except AttributeError: checks = kwargs.get('checks', []) - finally: - self.checks = checks + + self.checks: List[UserCheck[Context[Any]]] = checks try: cooldown = func.__commands_cooldown__ except AttributeError: cooldown = kwargs.get('cooldown') - finally: - self._buckets = CooldownMapping(cooldown) - self.ignore_extra = kwargs.get('ignore_extra', True) - self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False) - self.cog = None + if cooldown is None: + buckets = CooldownMapping(cooldown, BucketType.default) + elif isinstance(cooldown, CooldownMapping): + buckets: CooldownMapping[Context[Any]] = cooldown + else: + raise TypeError('Cooldown must be an instance of CooldownMapping or None.') + self._buckets: CooldownMapping[Context[Any]] = buckets + + try: + max_concurrency = func.__commands_max_concurrency__ + except AttributeError: + max_concurrency = kwargs.get('max_concurrency') + + self._max_concurrency: Optional[MaxConcurrency] = max_concurrency + + self.require_var_positional: bool = kwargs.get('require_var_positional', False) + self.ignore_extra: bool = kwargs.get('ignore_extra', True) + self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False) + self._cog: CogT = None # type: ignore # This breaks every other pyright release # bandaid for the fact that sometimes parent can be the bot instance - parent = kwargs.get('parent') - self.parent = parent if isinstance(parent, _BaseCommand) else None - self._before_invoke = None - self._after_invoke = None + parent: Optional[GroupMixin[Any]] = kwargs.get('parent') + self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None + + self._before_invoke: Optional[Hook] = None + try: + before_invoke = func.__before_invoke__ + except AttributeError: + pass + else: + self.before_invoke(before_invoke) + + self._after_invoke: Optional[Hook] = None + try: + after_invoke = func.__after_invoke__ + except AttributeError: + pass + else: + self.after_invoke(after_invoke) + + @property + def cog(self) -> CogT: + return self._cog + + @cog.setter + def cog(self, value: CogT) -> None: + self._cog = value @property - def callback(self): + def callback( + self, + ) -> Union[ + Callable[Concatenate[CogT, Context[Any], P], Coro[T]], + Callable[Concatenate[Context[Any], P], Coro[T]], + ]: return self._callback @callback.setter - def callback(self, function): + def callback( + self, + function: Union[ + Callable[Concatenate[CogT, Context[Any], P], Coro[T]], + Callable[Concatenate[Context[Any], P], Coro[T]], + ], + ) -> None: self._callback = function - self.module = function.__module__ + unwrap = unwrap_function(function) + self.module: str = unwrap.__module__ + + try: + globalns = unwrap.__globals__ + except AttributeError: + globalns = {} + + self.params: Dict[str, Parameter] = get_signature_parameters(function, globalns) + + def add_check(self, func: UserCheck[Context[Any]], /) -> None: + """Adds a check to the command. + + This is the non-decorator interface to :func:`.check`. + + .. versionadded:: 1.3 + + .. versionchanged:: 2.0 - signature = inspect.signature(function) - self.params = signature.parameters.copy() + ``func`` parameter is now positional-only. - # PEP-563 allows postponing evaluation of annotations with a __future__ - # import. When postponed, Parameter.annotation will be a string and must - # be replaced with the real value for the converters to work later on - for key, value in self.params.items(): - if isinstance(value.annotation, str): - self.params[key] = value = value.replace(annotation=eval(value.annotation, function.__globals__)) + .. seealso:: The :func:`~discord.ext.commands.check` decorator - # fail early for when someone passes an unparameterized Greedy type - if value.annotation is converters.Greedy: - raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') + Parameters + ----------- + func + The function that will be used as a check. + """ + + self.checks.append(func) + + def remove_check(self, func: UserCheck[Context[Any]], /) -> None: + """Removes a check from the command. + + This function is idempotent and will not raise an exception + if the function is not in the command's checks. + + .. versionadded:: 1.3 + + .. versionchanged:: 2.0 + + ``func`` parameter is now positional-only. + + Parameters + ----------- + func + The function to remove from the checks. + """ + + try: + self.checks.remove(func) + except ValueError: + pass - def update(self, **kwargs): + def update(self, **kwargs: Unpack[_CommandKwargs]) -> None: """Updates :class:`Command` instance with updated attribute. - This works similarly to the :func:`.command` decorator in terms + This works similarly to the :func:`~discord.ext.commands.command` decorator in terms of parameters in that they are passed to the :class:`Command` or subclass constructors, sans the name and callback. """ + cog = self.cog self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) + self.cog = cog + + async def __call__(self, context: Context[BotT], /, *args: P.args, **kwargs: P.kwargs) -> T: + """|coro| + + Calls the internal callback that the command holds. + + .. note:: + + This bypasses all mechanisms -- including checks, converters, + invoke hooks, cooldowns, etc. You must take care to pass + the proper arguments and types to this function. + + .. versionadded:: 1.3 + + .. versionchanged:: 2.0 - def _ensure_assignment_on_copy(self, other): + ``context`` parameter is now positional-only. + """ + if self.cog is not None: + return await self.callback(self.cog, context, *args, **kwargs) # type: ignore + else: + return await self.callback(context, *args, **kwargs) # type: ignore + + def _ensure_assignment_on_copy(self, other: Self) -> Self: other._before_invoke = self._before_invoke other._after_invoke = self._after_invoke + other.extras = self.extras if self.checks != other.checks: other.checks = self.checks.copy() if self._buckets.valid and not other._buckets.valid: other._buckets = self._buckets.copy() + if self._max_concurrency and self._max_concurrency != other._max_concurrency: + other._max_concurrency = self._max_concurrency.copy() + try: other.on_error = self.on_error except AttributeError: pass return other - def copy(self): - """Creates a copy of this command.""" + def copy(self) -> Self: + """Creates a copy of this command. + + Returns + -------- + :class:`Command` + A new instance of this command. + """ ret = self.__class__(self.callback, **self.__original_kwargs__) return self._ensure_assignment_on_copy(ret) - def _update_copy(self, kwargs): + def _update_copy(self, kwargs: Dict[str, Any]) -> Self: if kwargs: kw = kwargs.copy() kw.update(self.__original_kwargs__) @@ -312,7 +661,7 @@ def _update_copy(self, kwargs): else: return self.copy() - async def dispatch_error(self, ctx, error): + async def dispatch_error(self, ctx: Context[BotT], error: CommandError, /) -> None: ctx.command_failed = True cog = self.cog try: @@ -320,11 +669,11 @@ async def dispatch_error(self, ctx, error): except AttributeError: pass else: - injected = wrap_callback(coro) + injected = wrap_callback(coro) # type: ignore if cog is not None: await injected(cog, ctx, error) else: - await injected(ctx, error) + await injected(ctx, error) # type: ignore try: if cog is not None: @@ -335,126 +684,72 @@ async def dispatch_error(self, ctx, error): finally: ctx.bot.dispatch('command_error', ctx, error) - async def _actual_conversion(self, ctx, converter, argument, param): - if converter is bool: - return _convert_to_bool(argument) - - try: - module = converter.__module__ - except AttributeError: - pass - else: - if module is not None and (module.startswith('discord.') and not module.endswith('converter')): - converter = getattr(converters, converter.__name__ + 'Converter') - - try: - if inspect.isclass(converter): - if issubclass(converter, converters.Converter): - instance = converter() - ret = await instance.convert(ctx, argument) - return ret - else: - method = getattr(converter, 'convert', None) - if method is not None and inspect.ismethod(method): - ret = await method(ctx, argument) - return ret - elif isinstance(converter, converters.Converter): - ret = await converter.convert(ctx, argument) - return ret - except CommandError: - raise - except Exception as exc: - raise ConversionError(converter, exc) from exc - - try: - return converter(argument) - except CommandError: - raise - except Exception as exc: - try: - name = converter.__name__ - except AttributeError: - name = converter.__class__.__name__ - - raise BadArgument('Converting to "{}" failed for parameter "{}".'.format(name, param.name)) from exc - - async def do_conversion(self, ctx, converter, argument, param): - try: - origin = converter.__origin__ - except AttributeError: - pass - else: - if origin is typing.Union: - errors = [] - _NoneType = type(None) - for conv in converter.__args__: - # if we got to this part in the code, then the previous conversions have failed - # so we should just undo the view, return the default, and allow parsing to continue - # with the other parameters - if conv is _NoneType and param.kind != param.VAR_POSITIONAL: - ctx.view.undo() - return None if param.default is param.empty else param.default - - try: - value = await self._actual_conversion(ctx, conv, argument, param) - except CommandError as exc: - errors.append(exc) - else: - return value - - # if we're here, then we failed all the converters - raise BadUnionArgument(param, converter.__args__, errors) - - return await self._actual_conversion(ctx, converter, argument, param) - - def _get_converter(self, param): - converter = param.annotation - if converter is param.empty: - if param.default is not param.empty: - converter = str if param.default is None else type(param.default) - else: - converter = str - return converter - - async def transform(self, ctx, param): - required = param.default is param.empty - converter = self._get_converter(param) + async def transform(self, ctx: Context[BotT], param: Parameter, attachments: _AttachmentIterator, /) -> Any: + converter = param.converter consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw view = ctx.view view.skip_ws() # The greedy converter is simple -- it keeps going until it fails in which case, # it undos the view ready for the next parameter to use instead - if type(converter) is converters._Greedy: - if param.kind == param.POSITIONAL_OR_KEYWORD: - return await self._transform_greedy_pos(ctx, param, required, converter.converter) + if isinstance(converter, Greedy): + # Special case for Greedy[discord.Attachment] to consume the attachments iterator + if converter.converter is discord.Attachment: + return list(attachments) + + if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): + return await self._transform_greedy_pos(ctx, param, param.required, converter.constructed_converter) elif param.kind == param.VAR_POSITIONAL: - return await self._transform_greedy_var_pos(ctx, param, converter.converter) + return await self._transform_greedy_var_pos(ctx, param, converter.constructed_converter) else: # if we're here, then it's a KEYWORD_ONLY param type # since this is mostly useless, we'll helpfully transform Greedy[X] # into just X and do the parsing that way. - converter = converter.converter + converter = converter.constructed_converter + + # Try to detect Optional[discord.Attachment] or discord.Attachment special converter + if converter is discord.Attachment: + try: + return next(attachments) + except StopIteration: + raise MissingRequiredAttachment(param) + + if self._is_typing_optional(param.annotation) and param.annotation.__args__[0] is discord.Attachment: + if attachments.is_empty(): + # I have no idea who would be doing Optional[discord.Attachment] = 1 + # but for those cases then 1 should be returned instead of None + return None if param.default is param.empty else param.default + return next(attachments) if view.eof: if param.kind == param.VAR_POSITIONAL: - raise RuntimeError() # break the loop - if required: + raise RuntimeError() # break the loop + if param.required: if self._is_typing_optional(param.annotation): return None + if hasattr(converter, '__commands_is_flag__') and converter._can_be_constructible(): + return await converter._construct_default(ctx) raise MissingRequiredArgument(param) - return param.default + return await param.get_default(ctx) previous = view.index if consume_rest_is_special: - argument = view.read_rest().strip() + ctx.current_argument = argument = view.read_rest().strip() else: - argument = view.get_quoted_word() + try: + ctx.current_argument = argument = view.get_quoted_word() + except ArgumentParsingError as exc: + if self._is_typing_optional(param.annotation): + view.index = previous + return None if param.required else await param.get_default(ctx) + else: + raise exc view.previous = previous - return await self.do_conversion(ctx, converter, argument, param) + # type-checker fails to narrow argument + return await run_converters(ctx, converter, argument, param) # type: ignore - async def _transform_greedy_pos(self, ctx, param, required, converter): + async def _transform_greedy_pos(self, ctx: Context[BotT], param: Parameter, required: bool, converter: Any) -> Any: view = ctx.view result = [] while not view.eof: @@ -462,52 +757,51 @@ async def _transform_greedy_pos(self, ctx, param, required, converter): previous = view.index view.skip_ws() - argument = view.get_quoted_word() try: - value = await self.do_conversion(ctx, converter, argument, param) - except CommandError: + ctx.current_argument = argument = view.get_quoted_word() + value = await run_converters(ctx, converter, argument, param) # type: ignore + except (CommandError, ArgumentParsingError): view.index = previous break else: result.append(value) if not result and not required: - return param.default + return await param.get_default(ctx) return result - async def _transform_greedy_var_pos(self, ctx, param, converter): + async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: Parameter, converter: Any) -> Any: view = ctx.view previous = view.index - argument = view.get_quoted_word() try: - value = await self.do_conversion(ctx, converter, argument, param) - except CommandError: + ctx.current_argument = argument = view.get_quoted_word() + value = await run_converters(ctx, converter, argument, param) # type: ignore + except (CommandError, ArgumentParsingError): view.index = previous - raise RuntimeError() from None # break loop + raise RuntimeError() from None # break loop else: return value @property - def clean_params(self): - """Retrieves the parameter OrderedDict without the context or self parameters. + def clean_params(self) -> Dict[str, Parameter]: + """Dict[:class:`str`, :class:`Parameter`]: + Retrieves the parameter dictionary without the context or self parameters. Useful for inspecting signature. """ - result = self.params.copy() - if self.cog is not None: - # first parameter is self - result.popitem(last=False) + return self.params.copy() - try: - # first/second parameter is context - result.popitem(last=False) - except Exception: - raise ValueError('Missing context parameter') from None + @property + def cooldown(self) -> Optional[Cooldown]: + """Optional[:class:`~discord.app_commands.Cooldown`]: The cooldown of a command when invoked + or ``None`` if the command doesn't have a registered cooldown. - return result + .. versionadded:: 2.0 + """ + return self._buckets._cooldown @property - def full_parent_name(self): + def full_parent_name(self) -> str: """:class:`str`: Retrieves the fully qualified parent command name. This the base command name required to execute it. For example, @@ -515,33 +809,34 @@ def full_parent_name(self): """ entries = [] command = self - while command.parent is not None: - command = command.parent - entries.append(command.name) + # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO + while command.parent is not None: # type: ignore + command = command.parent # type: ignore + entries.append(command.name) # type: ignore return ' '.join(reversed(entries)) @property - def parents(self): - """:class:`Command`: Retrieves the parents of this command. + def parents(self) -> List[Group[Any, ..., Any]]: + """List[:class:`Group`]: Retrieves the parents of this command. If the command has no parents then it returns an empty :class:`list`. For example in commands ``?a b c test``, the parents are ``[c, b, a]``. - .. versionadded:: 1.1.0 + .. versionadded:: 1.1 """ entries = [] command = self - while command.parent is not None: - command = command.parent + while command.parent is not None: # type: ignore + command = command.parent # type: ignore entries.append(command) return entries @property - def root_parent(self): - """Retrieves the root parent of this command. + def root_parent(self) -> Optional[Group[Any, ..., Any]]: + """Optional[:class:`Group`]: Retrieves the root parent of this command. If the command has no parents then it returns ``None``. @@ -552,7 +847,7 @@ def root_parent(self): return self.parents[-1] @property - def qualified_name(self): + def qualified_name(self) -> str: """:class:`str`: Retrieves the fully qualified command name. This is the full parent name with the command name as well. @@ -566,75 +861,58 @@ def qualified_name(self): else: return self.name - def __str__(self): + def __str__(self) -> str: return self.qualified_name - async def _parse_arguments(self, ctx): + async def _parse_arguments(self, ctx: Context[BotT]) -> None: ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.kwargs = {} args = ctx.args kwargs = ctx.kwargs + attachments = _AttachmentIterator(ctx.message.attachments) view = ctx.view iterator = iter(self.params.items()) - if self.cog is not None: - # we have 'self' as the first parameter so just advance - # the iterator and resume parsing - try: - next(iterator) - except StopIteration: - fmt = 'Callback for {0.name} command is missing "self" parameter.' - raise discord.ClientException(fmt.format(self)) - - # next we have the 'ctx' as the next parameter - try: - next(iterator) - except StopIteration: - fmt = 'Callback for {0.name} command is missing "ctx" parameter.' - raise discord.ClientException(fmt.format(self)) - for name, param in iterator: - if param.kind == param.POSITIONAL_OR_KEYWORD: - transformed = await self.transform(ctx, param) + ctx.current_parameter = param + if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): + transformed = await self.transform(ctx, param, attachments) args.append(transformed) elif param.kind == param.KEYWORD_ONLY: # kwarg only param denotes "consume rest" semantics if self.rest_is_raw: - converter = self._get_converter(param) - argument = view.read_rest() - kwargs[name] = await self.do_conversion(ctx, converter, argument, param) + ctx.current_argument = argument = view.read_rest() + kwargs[name] = await run_converters(ctx, param.converter, argument, param) else: - kwargs[name] = await self.transform(ctx, param) + kwargs[name] = await self.transform(ctx, param, attachments) break elif param.kind == param.VAR_POSITIONAL: + if view.eof and self.require_var_positional: + raise MissingRequiredArgument(param) while not view.eof: try: - transformed = await self.transform(ctx, param) + transformed = await self.transform(ctx, param, attachments) args.append(transformed) except RuntimeError: break - if not self.ignore_extra: - if not view.eof: - raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) + if not self.ignore_extra and not view.eof: + raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) - async def _verify_checks(self, ctx): - if not self.enabled: - raise DisabledCommand('{0.name} command is disabled'.format(self)) - - if not await self.can_run(ctx): - raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self)) - - async def call_before_hooks(self, ctx): + async def call_before_hooks(self, ctx: Context[BotT], /) -> None: # now that we're done preparing we can call the pre-command hooks # first, call the command local hook: cog = self.cog if self._before_invoke is not None: - if cog is None: - await self._before_invoke(ctx) + # should be cog if @commands.before_invoke is used + instance = getattr(self._before_invoke, '__self__', cog) + # __self__ only exists for methods, not functions + # however, if @command.before_invoke is used, it will be a function + if instance: + await self._before_invoke(instance, ctx) # type: ignore else: - await self._before_invoke(cog, ctx) + await self._before_invoke(ctx) # type: ignore # call the cog local hook if applicable: if cog is not None: @@ -647,13 +925,14 @@ async def call_before_hooks(self, ctx): if hook is not None: await hook(ctx) - async def call_after_hooks(self, ctx): + async def call_after_hooks(self, ctx: Context[BotT], /) -> None: cog = self.cog if self._after_invoke is not None: - if cog is None: - await self._after_invoke(ctx) + instance = getattr(self._after_invoke, '__self__', cog) + if instance: + await self._after_invoke(instance, ctx) # type: ignore else: - await self._after_invoke(cog, ctx) + await self._after_invoke(ctx) # type: ignore # call the cog local hook if applicable: if cog is not None: @@ -665,30 +944,47 @@ async def call_after_hooks(self, ctx): if hook is not None: await hook(ctx) - def _prepare_cooldowns(self, ctx): + def _prepare_cooldowns(self, ctx: Context[BotT]) -> None: if self._buckets.valid: - current = ctx.message.created_at.replace(tzinfo=datetime.timezone.utc).timestamp() - bucket = self._buckets.get_bucket(ctx.message, current) - retry_after = bucket.update_rate_limit(current) - if retry_after: - raise CommandOnCooldown(bucket, retry_after) - - async def prepare(self, ctx): + dt = ctx.message.edited_at or ctx.message.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + bucket = self._buckets.get_bucket(ctx, current) + if bucket is not None: + retry_after = bucket.update_rate_limit(current) + if retry_after: + raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore + + async def prepare(self, ctx: Context[BotT], /) -> None: ctx.command = self - await self._verify_checks(ctx) - if self.cooldown_after_parsing: - await self._parse_arguments(ctx) - self._prepare_cooldowns(ctx) - else: - self._prepare_cooldowns(ctx) - await self._parse_arguments(ctx) + if not await self.can_run(ctx): + raise CheckFailure(f'The check functions for command {self.qualified_name} failed.') + + if self._max_concurrency is not None: + # For this application, context can be duck-typed as a Message + await self._max_concurrency.acquire(ctx) + + try: + if self.cooldown_after_parsing: + await self._parse_arguments(ctx) + self._prepare_cooldowns(ctx) + else: + self._prepare_cooldowns(ctx) + await self._parse_arguments(ctx) - await self.call_before_hooks(ctx) + await self.call_before_hooks(ctx) + except: + if self._max_concurrency is not None: + await self._max_concurrency.release(ctx) + raise - def is_on_cooldown(self, ctx): + def is_on_cooldown(self, ctx: Context[BotT], /) -> bool: """Checks whether the command is currently on cooldown. + .. versionchanged:: 2.0 + + ``ctx`` parameter is now positional-only. + Parameters ----------- ctx: :class:`.Context` @@ -702,32 +998,72 @@ def is_on_cooldown(self, ctx): if not self._buckets.valid: return False - bucket = self._buckets.get_bucket(ctx.message) - return bucket.get_tokens() == 0 + bucket = self._buckets.get_bucket(ctx) + if bucket is None: + return False + dt = ctx.message.edited_at or ctx.message.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + return bucket.get_tokens(current) == 0 - def reset_cooldown(self, ctx): + def reset_cooldown(self, ctx: Context[BotT], /) -> None: """Resets the cooldown on this command. + .. versionchanged:: 2.0 + + ``ctx`` parameter is now positional-only. + Parameters ----------- ctx: :class:`.Context` The invocation context to reset the cooldown under. """ if self._buckets.valid: - bucket = self._buckets.get_bucket(ctx.message) - bucket.reset() + bucket = self._buckets.get_bucket(ctx) + if bucket is not None: + bucket.reset() + + def get_cooldown_retry_after(self, ctx: Context[BotT], /) -> float: + """Retrieves the amount of seconds before this command can be tried again. + + .. versionadded:: 1.4 + + .. versionchanged:: 2.0 + + ``ctx`` parameter is now positional-only. + + Parameters + ----------- + ctx: :class:`.Context` + The invocation context to retrieve the cooldown from. + + Returns + -------- + :class:`float` + The amount of time left on this command's cooldown in seconds. + If this is ``0.0`` then the command isn't on cooldown. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx) + if bucket is None: + return 0.0 + dt = ctx.message.edited_at or ctx.message.created_at + current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() + return bucket.get_retry_after(current) - async def invoke(self, ctx): + return 0.0 + + async def invoke(self, ctx: Context[BotT], /) -> None: await self.prepare(ctx) # terminate the invoked_subcommand chain. # since we're in a regular command (and not a group) then # the invoked subcommand is None. ctx.invoked_subcommand = None - injected = hooked_wrapped_callback(self, ctx, self.callback) - await injected(*ctx.args, **ctx.kwargs) + ctx.subcommand_passed = None + injected = hooked_wrapped_callback(self, ctx, self.callback) # type: ignore + await injected(*ctx.args, **ctx.kwargs) # type: ignore - async def reinvoke(self, ctx, *, call_hooks=False): + async def reinvoke(self, ctx: Context[BotT], /, *, call_hooks: bool = False) -> None: ctx.command = self await self._parse_arguments(ctx) @@ -736,7 +1072,7 @@ async def reinvoke(self, ctx, *, call_hooks=False): ctx.invoked_subcommand = None try: - await self.callback(*ctx.args, **ctx.kwargs) + await self.callback(*ctx.args, **ctx.kwargs) # type: ignore except: ctx.command_failed = True raise @@ -744,13 +1080,17 @@ async def reinvoke(self, ctx, *, call_hooks=False): if call_hooks: await self.call_after_hooks(ctx) - def error(self, coro): + def error(self, coro: Error[CogT, ContextT], /) -> Error[CogT, ContextT]: """A decorator that registers a coroutine as a local error handler. A local error handler is an :func:`.on_command_error` event limited to a single command. However, the :func:`.on_command_error` is still invoked afterwards as the catch-all. + .. versionchanged:: 2.0 + + ``coro`` parameter is now positional-only. + Parameters ----------- coro: :ref:`coroutine ` @@ -765,10 +1105,17 @@ def error(self, coro): if not asyncio.iscoroutinefunction(coro): raise TypeError('The error handler must be a coroutine.') - self.on_error = coro + self.on_error: Error[CogT, Any] = coro return coro - def before_invoke(self, coro): + def has_error_handler(self) -> bool: + """:class:`bool`: Checks whether the command has an error handler registered. + + .. versionadded:: 1.7 + """ + return hasattr(self, 'on_error') + + def before_invoke(self, coro: Hook[CogT, ContextT], /) -> Hook[CogT, ContextT]: """A decorator that registers a coroutine as a pre-invoke hook. A pre-invoke hook is called directly before the command is @@ -779,6 +1126,10 @@ def before_invoke(self, coro): See :meth:`.Bot.before_invoke` for more info. + .. versionchanged:: 2.0 + + ``coro`` parameter is now positional-only. + Parameters ----------- coro: :ref:`coroutine ` @@ -795,7 +1146,7 @@ def before_invoke(self, coro): self._before_invoke = coro return coro - def after_invoke(self, coro): + def after_invoke(self, coro: Hook[CogT, ContextT], /) -> Hook[CogT, ContextT]: """A decorator that registers a coroutine as a post-invoke hook. A post-invoke hook is called directly after the command is @@ -806,6 +1157,10 @@ def after_invoke(self, coro): See :meth:`.Bot.after_invoke` for more info. + .. versionchanged:: 2.0 + + ``coro`` parameter is now positional-only. + Parameters ----------- coro: :ref:`coroutine ` @@ -823,17 +1178,17 @@ def after_invoke(self, coro): return coro @property - def cog_name(self): - """:class:`str`: The name of the cog this command belongs to. None otherwise.""" + def cog_name(self) -> Optional[str]: + """Optional[:class:`str`]: The name of the cog this command belongs to, if any.""" return type(self.cog).__cog_name__ if self.cog is not None else None @property - def short_doc(self): + def short_doc(self) -> str: """:class:`str`: Gets the "short" documentation of a command. - By default, this is the :attr:`brief` attribute. + By default, this is the :attr:`.brief` attribute. If that lookup leads to an empty string then the first line of the - :attr:`help` attribute is used instead. + :attr:`.help` attribute is used instead. """ if self.brief is not None: return self.brief @@ -841,59 +1196,89 @@ def short_doc(self): return self.help.split('\n', 1)[0] return '' - def _is_typing_optional(self, annotation): - try: - origin = annotation.__origin__ - except AttributeError: - return False - - if origin is not typing.Union: - return False - - return annotation.__args__[-1] is type(None) + def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> bool: + return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore @property - def signature(self): + def signature(self) -> str: """:class:`str`: Returns a POSIX-like signature useful for help command output.""" if self.usage is not None: return self.usage - params = self.clean_params if not params: return '' result = [] - for name, param in params.items(): - greedy = isinstance(param.annotation, converters._Greedy) - - if param.default is not param.empty: + for param in params.values(): + name = param.displayed_name or param.name + + greedy = isinstance(param.converter, Greedy) + optional = False # postpone evaluation of if it's an optional argument + + annotation: Any = param.converter.converter if greedy else param.converter + origin = getattr(annotation, '__origin__', None) + if not greedy and origin is Union: + none_cls = type(None) + union_args = annotation.__args__ + optional = union_args[-1] is none_cls + if len(union_args) == 2 and optional: + annotation = union_args[0] + origin = getattr(annotation, '__origin__', None) + + if annotation is discord.Attachment: + # For discord.Attachment we need to signal to the user that it's an attachment + # It's not exactly pretty but it's enough to differentiate + if optional: + result.append(f'[{name} (upload a file)]') + elif greedy: + result.append(f'[{name} (upload files)]...') + else: + result.append(f'<{name} (upload a file)>') + continue + + # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the + # parameter signature is a literal list of it's values + if origin is Literal: + name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) + if not param.required: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = param.default if isinstance(param.default, str) else param.default is not None - if should_print: - result.append('[%s=%s]' % (name, param.default) if not greedy else - '[%s=%s]...' % (name, param.default)) + if param.displayed_default: + result.append( + f'[{name}={param.displayed_default}]' if not greedy else f'[{name}={param.displayed_default}]...' + ) continue else: - result.append('[%s]' % name) + result.append(f'[{name}]') elif param.kind == param.VAR_POSITIONAL: - result.append('[%s...]' % name) + if self.require_var_positional: + result.append(f'<{name}...>') + else: + result.append(f'[{name}...]') elif greedy: - result.append('[%s]...' % name) - elif self._is_typing_optional(param.annotation): - result.append('[%s]' % name) + result.append(f'[{name}]...') + elif optional: + result.append(f'[{name}]') else: - result.append('<%s>' % name) + result.append(f'<{name}>') return ' '.join(result) - async def can_run(self, ctx): + async def can_run(self, ctx: Context[BotT], /) -> bool: """|coro| Checks if the command can be executed by checking all the predicates - inside the :attr:`.checks` attribute. + inside the :attr:`~Command.checks` attribute. This also checks whether the + command is disabled. + + .. versionchanged:: 1.3 + Checks whether the command is disabled or not + + .. versionchanged:: 2.0 + + ``ctx`` parameter is now positional-only. Parameters ----------- @@ -912,12 +1297,15 @@ async def can_run(self, ctx): A boolean indicating if the command can be invoked. """ + if not self.enabled: + raise DisabledCommand(f'{self.name} command is disabled') + original = ctx.command ctx.command = self try: if not await ctx.bot.can_run(ctx): - raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self)) + raise CheckFailure(f'The global check functions for command {self.qualified_name} failed.') cog = self.cog if cog is not None: @@ -932,46 +1320,54 @@ async def can_run(self, ctx): # since we have no checks, then we just return True. return True - return await discord.utils.async_all(predicate(ctx) for predicate in predicates) + return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore finally: ctx.command = original -class GroupMixin: + +class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave similar to :class:`.Group` and are allowed to register commands. Attributes ----------- all_commands: :class:`dict` - A mapping of command name to :class:`.Command` or subclass + A mapping of command name to :class:`.Command` objects. case_insensitive: :class:`bool` Whether the commands should be case insensitive. Defaults to ``False``. """ - def __init__(self, *args, **kwargs): + + def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get('case_insensitive', False) - self.all_commands = _CaseInsensitiveDict() if case_insensitive else {} - self.case_insensitive = case_insensitive + self.all_commands: Dict[str, Command[CogT, ..., Any]] = _CaseInsensitiveDict() if case_insensitive else {} + self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @property - def commands(self): + def commands(self) -> Set[Command[CogT, ..., Any]]: """Set[:class:`.Command`]: A unique set of commands without aliases that are registered.""" return set(self.all_commands.values()) - def recursively_remove_all_commands(self): + def recursively_remove_all_commands(self) -> None: for command in self.all_commands.copy().values(): if isinstance(command, GroupMixin): command.recursively_remove_all_commands() self.remove_command(command.name) - def add_command(self, command): - """Adds a :class:`.Command` or its subclasses into the internal list - of commands. + def add_command(self, command: Command[CogT, ..., Any], /) -> None: + """Adds a :class:`.Command` into the internal list of commands. This is usually not called, instead the :meth:`~.GroupMixin.command` or :meth:`~.GroupMixin.group` shortcut decorators are used instead. + .. versionchanged:: 1.4 + Raise :exc:`.CommandRegistrationError` instead of generic :exc:`.ClientException` + + .. versionchanged:: 2.0 + + ``command`` parameter is now positional-only. + Parameters ----------- command: :class:`Command` @@ -979,8 +1375,8 @@ def add_command(self, command): Raises ------- - :exc:`.ClientException` - If the command is already registered. + CommandRegistrationError + If the command or its alias is already registered by different command. TypeError If the command passed is not a subclass of :class:`.Command`. """ @@ -992,20 +1388,25 @@ def add_command(self, command): command.parent = self if command.name in self.all_commands: - raise discord.ClientException('Command {0.name} is already registered.'.format(command)) + raise CommandRegistrationError(command.name) self.all_commands[command.name] = command for alias in command.aliases: if alias in self.all_commands: - raise discord.ClientException('The alias {} is already an existing command or alias.'.format(alias)) + self.remove_command(command.name) + raise CommandRegistrationError(alias, alias_conflict=True) self.all_commands[alias] = command - def remove_command(self, name): - """Remove a :class:`.Command` or subclasses from the internal list + def remove_command(self, name: str, /) -> Optional[Command[CogT, ..., Any]]: + """Remove a :class:`.Command` from the internal list of commands. This could also be used as a way to remove aliases. + .. versionchanged:: 2.0 + + ``name`` parameter is now positional-only. + Parameters ----------- name: :class:`str` @@ -1013,9 +1414,9 @@ def remove_command(self, name): Returns -------- - :class:`.Command` or subclass + Optional[:class:`.Command`] The command that was removed. If the name is not valid then - `None` is returned instead. + ``None`` is returned instead. """ command = self.all_commands.pop(name, None) @@ -1029,18 +1430,32 @@ def remove_command(self, name): # we're not removing the alias so let's delete the rest of them. for alias in command.aliases: - self.all_commands.pop(alias, None) + cmd = self.all_commands.pop(alias, None) + # in the case of a CommandRegistrationError, an alias might conflict + # with an already existing command. If this is the case, we want to + # make sure the pre-existing command is not removed. + if cmd is not None and cmd != command: + self.all_commands[alias] = cmd return command - def walk_commands(self): - """An iterator that recursively walks through all commands and subcommands.""" - for command in tuple(self.all_commands.values()): + def walk_commands(self) -> Generator[Command[CogT, ..., Any], None, None]: + """An iterator that recursively walks through all commands and subcommands. + + .. versionchanged:: 1.4 + Duplicates due to aliases are no longer returned + + Yields + ------ + Union[:class:`.Command`, :class:`.Group`] + A command or group from the internal list of commands. + """ + for command in self.commands: yield command if isinstance(command, GroupMixin): yield from command.walk_commands() - def get_command(self, name): - """Get a :class:`.Command` or subclasses from the internal list + def get_command(self, name: str, /) -> Optional[Command[CogT, ..., Any]]: + """Get a :class:`.Command` from the internal list of commands. This could also be used as a way to get aliases. @@ -1049,6 +1464,10 @@ def get_command(self, name): the subcommand ``bar`` of the group command ``foo``. If a subcommand is not found then ``None`` is returned just as usual. + .. versionchanged:: 2.0 + + ``name`` parameter is now positional-only. + Parameters ----------- name: :class:`str` @@ -1056,7 +1475,7 @@ def get_command(self, name): Returns -------- - :class:`Command` or subclass + Optional[:class:`Command`] The command that was requested. If not found, returns ``None``. """ @@ -1065,43 +1484,136 @@ def get_command(self, name): return self.all_commands.get(name) names = name.split() + if not names: + return None obj = self.all_commands.get(names[0]) if not isinstance(obj, GroupMixin): return obj for name in names[1:]: try: - obj = obj.all_commands[name] + obj = obj.all_commands[name] # type: ignore except (AttributeError, KeyError): return None return obj - def command(self, *args, **kwargs): - """A shortcut decorator that invokes :func:`.command` and adds it to + @overload + def command( + self: GroupMixin[CogT], + name: str = ..., + *args: Any, + **kwargs: Unpack[_CommandDecoratorKwargs], + ) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], + Command[CogT, P, T], + ]: ... + + @overload + def command( + self: GroupMixin[CogT], + name: str = ..., + cls: Type[CommandT] = ..., # type: ignore # previous overload handles case where cls is not set + *args: Any, + **kwargs: Unpack[_CommandDecoratorKwargs], + ) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], + CommandT, + ]: ... + + def command( + self, + name: str = MISSING, + cls: Type[Command[Any, ..., Any]] = MISSING, + *args: Any, + **kwargs: Unpack[_CommandDecoratorKwargs], + ) -> Any: + """A shortcut decorator that invokes :func:`~discord.ext.commands.command` and adds it to the internal command list via :meth:`~.GroupMixin.add_command`. + + Returns + -------- + Callable[..., :class:`Command`] + A decorator that converts the provided method into a Command, adds it to the bot, then returns it. """ + def decorator(func): - kwargs.setdefault('parent', self) - result = command(*args, **kwargs)(func) + kwargs.setdefault('parent', self) # type: ignore # the parent kwarg is not for users to set. + result = command(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) return result return decorator - def group(self, *args, **kwargs): + @overload + def group( + self: GroupMixin[CogT], + name: str = ..., + *args: Any, + **kwargs: Unpack[_GroupDecoratorKwargs], + ) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], + Group[CogT, P, T], + ]: ... + + @overload + def group( + self: GroupMixin[CogT], + name: str = ..., + cls: Type[GroupT] = ..., # type: ignore # previous overload handles case where cls is not set + *args: Any, + **kwargs: Unpack[_GroupDecoratorKwargs], + ) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], + GroupT, + ]: ... + + def group( + self, + name: str = MISSING, + cls: Type[Group[Any, ..., Any]] = MISSING, + *args: Any, + **kwargs: Unpack[_GroupDecoratorKwargs], + ) -> Any: """A shortcut decorator that invokes :func:`.group` and adds it to the internal command list via :meth:`~.GroupMixin.add_command`. + + Returns + -------- + Callable[..., :class:`Group`] + A decorator that converts the provided method into a Group, adds it to the bot, then returns it. """ + def decorator(func): - kwargs.setdefault('parent', self) - result = group(*args, **kwargs)(func) + kwargs.setdefault('parent', self) # type: ignore # the parent kwarg is not for users to set. + result = group(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) return result return decorator -class Group(GroupMixin, Command): + +class Group(GroupMixin[CogT], Command[CogT, P, T]): """A class that implements a grouping protocol for commands to be executed as subcommands. @@ -1110,7 +1622,7 @@ class Group(GroupMixin, Command): Attributes ----------- - invoke_without_command: Optional[:class:`bool`] + invoke_without_command: :class:`bool` Indicates if the group callback should begin parsing and invocation only if no subcommand was found. Useful for making it an error handling function to tell the user that @@ -1119,23 +1631,31 @@ class Group(GroupMixin, Command): the group callback will always be invoked first. This means that the checks and the parsing dictated by its parameters will be executed. Defaults to ``False``. - case_insensitive: Optional[:class:`bool`] + case_insensitive: :class:`bool` Indicates if the group's commands should be case insensitive. Defaults to ``False``. """ - def __init__(self, *args, **attrs): - self.invoke_without_command = attrs.pop('invoke_without_command', False) + + def __init__(self, *args: Any, **attrs: Unpack[_GroupKwargs]) -> None: + self.invoke_without_command: bool = attrs.pop('invoke_without_command', False) super().__init__(*args, **attrs) - def copy(self): - """Creates a copy of this :class:`Group`.""" + def copy(self) -> Self: + """Creates a copy of this :class:`Group`. + + Returns + -------- + :class:`Group` + A new instance of this group. + """ ret = super().copy() for cmd in self.commands: ret.add_command(cmd.copy()) return ret - async def invoke(self, ctx): + async def invoke(self, ctx: Context[BotT], /) -> None: ctx.invoked_subcommand = None + ctx.subcommand_passed = None early_invoke = not self.invoke_without_command if early_invoke: await self.prepare(ctx) @@ -1150,8 +1670,10 @@ async def invoke(self, ctx): ctx.invoked_subcommand = self.all_commands.get(trigger, None) if early_invoke: - injected = hooked_wrapped_callback(self, ctx, self.callback) - await injected(*ctx.args, **ctx.kwargs) + injected = hooked_wrapped_callback(self, ctx, self.callback) # type: ignore + await injected(*ctx.args, **ctx.kwargs) # type: ignore + + ctx.invoked_parents.append(ctx.invoked_with) # type: ignore if trigger and ctx.invoked_subcommand: ctx.invoked_with = trigger @@ -1162,7 +1684,7 @@ async def invoke(self, ctx): view.previous = previous await super().invoke(ctx) - async def reinvoke(self, ctx, *, call_hooks=False): + async def reinvoke(self, ctx: Context[BotT], /, *, call_hooks: bool = False) -> None: ctx.invoked_subcommand = None early_invoke = not self.invoke_without_command if early_invoke: @@ -1183,7 +1705,7 @@ async def reinvoke(self, ctx, *, call_hooks=False): if early_invoke: try: - await self.callback(*ctx.args, **ctx.kwargs) + await self.callback(*ctx.args, **ctx.kwargs) # type: ignore except: ctx.command_failed = True raise @@ -1191,6 +1713,8 @@ async def reinvoke(self, ctx, *, call_hooks=False): if call_hooks: await self.call_after_hooks(ctx) + ctx.invoked_parents.append(ctx.invoked_with) # type: ignore + if trigger and ctx.invoked_subcommand: ctx.invoked_with = trigger await ctx.invoked_subcommand.reinvoke(ctx, call_hooks=call_hooks) @@ -1200,9 +1724,59 @@ async def reinvoke(self, ctx, *, call_hooks=False): view.previous = previous await super().reinvoke(ctx, call_hooks=call_hooks) + # Decorators -def command(name=None, cls=None, **attrs): +if TYPE_CHECKING: + # Using a class to emulate a function allows for overloading the inner function in the decorator. + + class _CommandDecorator: + @overload + def __call__(self, func: Callable[Concatenate[CogT, ContextT, P], Coro[T]], /) -> Command[CogT, P, T]: ... + + @overload + def __call__(self, func: Callable[Concatenate[ContextT, P], Coro[T]], /) -> Command[None, P, T]: ... + + def __call__(self, func: Callable[..., Coro[T]], /) -> Any: ... + + class _GroupDecorator: + @overload + def __call__(self, func: Callable[Concatenate[CogT, ContextT, P], Coro[T]], /) -> Group[CogT, P, T]: ... + + @overload + def __call__(self, func: Callable[Concatenate[ContextT, P], Coro[T]], /) -> Group[None, P, T]: ... + + def __call__(self, func: Callable[..., Coro[T]], /) -> Any: ... + + +@overload +def command( + name: str = ..., + **attrs: Unpack[_CommandDecoratorKwargs], +) -> _CommandDecorator: ... + + +@overload +def command( + name: str = ..., + cls: Type[CommandT] = ..., # type: ignore # previous overload handles case where cls is not set + **attrs: Unpack[_CommandDecoratorKwargs], +) -> Callable[ + [ + Union[ + Callable[Concatenate[ContextT, P], Coro[Any]], + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore # CogT is used here to allow covariance + ] + ], + CommandT, +]: ... + + +def command( + name: str = MISSING, + cls: Type[Command[Any, ..., Any]] = MISSING, + **attrs: Unpack[_CommandDecoratorKwargs], +) -> Any: """A decorator that transforms a function into a :class:`.Command` or if called with :func:`.group`, :class:`.Group`. @@ -1232,7 +1806,7 @@ def command(name=None, cls=None, **attrs): TypeError If the function is not a coroutine or is already a command. """ - if cls is None: + if cls is MISSING: cls = Command def decorator(func): @@ -1242,20 +1816,50 @@ def decorator(func): return decorator -def group(name=None, **attrs): + +@overload +def group( + name: str = ..., + **attrs: Unpack[_GroupDecoratorKwargs], +) -> _GroupDecorator: ... + + +@overload +def group( + name: str = ..., + cls: Type[GroupT] = ..., # type: ignore # previous overload handles case where cls is not set + **attrs: Unpack[_GroupDecoratorKwargs], +) -> Callable[ + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore # CogT is used here to allow covariance + Callable[Concatenate[ContextT, P], Coro[Any]], + ] + ], + GroupT, +]: ... + + +def group( + name: str = MISSING, + cls: Type[Group[Any, ..., Any]] = MISSING, + **attrs: Unpack[_GroupDecoratorKwargs], +) -> Any: """A decorator that transforms a function into a :class:`.Group`. - This is similar to the :func:`.command` decorator but the ``cls`` + This is similar to the :func:`~discord.ext.commands.command` decorator but the ``cls`` parameter is set to :class:`Group` by default. - .. versionchanged:: 1.1.0 + .. versionchanged:: 1.1 The ``cls`` parameter can now be passed. """ + if cls is MISSING: + cls = Group + + return command(name=name, cls=cls, **attrs) - attrs.setdefault('cls', Group) - return command(name=name, **attrs) -def check(predicate): +def check(predicate: UserCheck[ContextT], /) -> Check[ContextT]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1269,9 +1873,27 @@ def check(predicate): will be propagated while those subclassed will be sent to :func:`.on_command_error`. + A special attribute named ``predicate`` is bound to the value + returned by this decorator to retrieve the predicate passed to the + decorator. This allows the following introspection and chaining to be done: + + .. code-block:: python3 + + def owner_or_permissions(**perms): + original = commands.has_permissions(**perms).predicate + async def extended_check(ctx): + if ctx.guild is None: + return False + return ctx.guild.owner_id == ctx.author.id or await original(ctx) + return commands.check(extended_check) + .. note:: - These functions can either be regular functions or coroutines. + The function returned by ``predicate`` is **always** a coroutine, + even if the original function was not a coroutine. + + .. versionchanged:: 1.3 + The ``predicate`` attribute was added. Examples --------- @@ -1302,15 +1924,19 @@ def predicate(ctx): async def only_me(ctx): await ctx.send('Only you!') + .. versionchanged:: 2.0 + + ``predicate`` parameter is now positional-only. + Parameters ----------- predicate: Callable[[:class:`Context`], :class:`bool`] The predicate to check if the command should be invoked. """ - def decorator(func): + def decorator(func: Union[Command[Any, ..., Any], CoroFunc]) -> Union[Command[Any, ..., Any], CoroFunc]: if isinstance(func, Command): - func.checks.append(predicate) + func.checks.append(predicate) # type: ignore else: if not hasattr(func, '__commands_checks__'): func.__commands_checks__ = [] @@ -1318,9 +1944,90 @@ def decorator(func): func.__commands_checks__.append(predicate) return func - return decorator -def has_role(item): + if inspect.iscoroutinefunction(predicate): + decorator.predicate = predicate + else: + + @functools.wraps(predicate) + async def wrapper(ctx: ContextT): + return predicate(ctx) + + decorator.predicate = wrapper + + return decorator # type: ignore + + +def check_any(*checks: Check[ContextT]) -> Check[ContextT]: + r"""A :func:`check` that is added that checks if any of the checks passed + will pass, i.e. using logical OR. + + If all checks fail then :exc:`.CheckAnyFailure` is raised to signal the failure. + It inherits from :exc:`.CheckFailure`. + + .. note:: + + The ``predicate`` attribute for this function **is** a coroutine. + + .. versionadded:: 1.3 + + Parameters + ------------ + \*checks: Callable[[:class:`Context`], :class:`bool`] + An argument list of checks that have been decorated with + the :func:`check` decorator. + + Raises + ------- + TypeError + A check passed has not been decorated with the :func:`check` + decorator. + + Examples + --------- + + Creating a basic check to see if it's the bot owner or + the server owner: + + .. code-block:: python3 + + def is_guild_owner(): + def predicate(ctx): + return ctx.guild is not None and ctx.guild.owner_id == ctx.author.id + return commands.check(predicate) + + @bot.command() + @commands.check_any(commands.is_owner(), is_guild_owner()) + async def only_for_owners(ctx): + await ctx.send('Hello mister owner!') + """ + + unwrapped = [] + for wrapped in checks: + try: + pred = wrapped.predicate + except AttributeError: + raise TypeError(f'{wrapped!r} must be wrapped by commands.check decorator') from None + else: + unwrapped.append(pred) + + async def predicate(ctx: Context[BotT]) -> bool: + errors = [] + for func in unwrapped: + try: + value = await func(ctx) + except CheckFailure as e: + errors.append(e) + else: + if value: + return True + # if we're here, all checks failed + raise CheckAnyFailure(unwrapped, errors) + + return check(predicate) + + +def has_role(item: Union[int, str], /) -> Check[Any]: """A :func:`.check` that is added that checks if the member invoking the command has the role specified via the name or ID specified. @@ -1336,35 +2043,41 @@ def has_role(item): is missing a role, or :exc:`.NoPrivateMessage` if it is used in a private message. Both inherit from :exc:`.CheckFailure`. - .. versionchanged:: 1.1.0 + .. versionchanged:: 1.1 Raise :exc:`.MissingRole` or :exc:`.NoPrivateMessage` instead of generic :exc:`.CheckFailure` + .. versionchanged:: 2.0 + + ``item`` parameter is now positional-only. + Parameters ----------- item: Union[:class:`int`, :class:`str`] The name or ID of the role to check. """ - def predicate(ctx): - if not isinstance(ctx.channel, discord.abc.GuildChannel): + def predicate(ctx: Context[BotT]) -> bool: + if ctx.guild is None: raise NoPrivateMessage() + # ctx.guild is None doesn't narrow ctx.author to Member if isinstance(item, int): - role = discord.utils.get(ctx.author.roles, id=item) + role = ctx.author.get_role(item) # type: ignore else: - role = discord.utils.get(ctx.author.roles, name=item) + role = discord.utils.get(ctx.author.roles, name=item) # type: ignore if role is None: raise MissingRole(item) return True return check(predicate) -def has_any_role(*items): + +def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: r"""A :func:`.check` that is added that checks if the member invoking the command has **any** of the roles specified. This means that if they have - one out of the three roles specified, then this check will return `True`. + one out of the three roles specified, then this check will return ``True``. Similar to :func:`.has_role`\, the names or IDs passed in must be exact. @@ -1372,7 +2085,7 @@ def has_any_role(*items): is missing all roles, or :exc:`.NoPrivateMessage` if it is used in a private message. Both inherit from :exc:`.CheckFailure`. - .. versionchanged:: 1.1.0 + .. versionchanged:: 1.1 Raise :exc:`.MissingAnyRole` or :exc:`.NoPrivateMessage` instead of generic :exc:`.CheckFailure` @@ -1392,18 +2105,25 @@ def has_any_role(*items): async def cool(ctx): await ctx.send('You are cool indeed') """ + def predicate(ctx): - if not isinstance(ctx.channel, discord.abc.GuildChannel): + if ctx.guild is None: raise NoPrivateMessage() - getter = functools.partial(discord.utils.get, ctx.author.roles) - if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): + # ctx.guild is None doesn't narrow ctx.author to Member + if any( + ctx.author.get_role(item) is not None + if isinstance(item, int) + else discord.utils.get(ctx.author.roles, name=item) is not None + for item in items + ): return True - raise MissingAnyRole(items) + raise MissingAnyRole(list(items)) return check(predicate) -def bot_has_role(item): + +def bot_has_role(item: int, /) -> Callable[[T], T]: """Similar to :func:`.has_role` except checks if the bot itself has the role. @@ -1411,28 +2131,32 @@ def bot_has_role(item): is missing the role, or :exc:`.NoPrivateMessage` if it is used in a private message. Both inherit from :exc:`.CheckFailure`. - .. versionchanged:: 1.1.0 + .. versionchanged:: 1.1 Raise :exc:`.BotMissingRole` or :exc:`.NoPrivateMessage` instead of generic :exc:`.CheckFailure` + + .. versionchanged:: 2.0 + + ``item`` parameter is now positional-only. """ def predicate(ctx): - ch = ctx.channel - if not isinstance(ch, discord.abc.GuildChannel): + if ctx.guild is None: raise NoPrivateMessage() - me = ch.guild.me if isinstance(item, int): - role = discord.utils.get(me.roles, id=item) + role = ctx.me.get_role(item) else: - role = discord.utils.get(me.roles, name=item) + role = discord.utils.get(ctx.me.roles, name=item) if role is None: raise BotMissingRole(item) return True + return check(predicate) -def bot_has_any_role(*items): + +def bot_has_any_role(*items: int) -> Callable[[T], T]: """Similar to :func:`.has_any_role` except checks if the bot itself has any of the roles listed. @@ -1440,27 +2164,34 @@ def bot_has_any_role(*items): is missing all roles, or :exc:`.NoPrivateMessage` if it is used in a private message. Both inherit from :exc:`.CheckFailure`. - .. versionchanged:: 1.1.0 + .. versionchanged:: 1.1 Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage` instead of generic checkfailure """ + def predicate(ctx): - ch = ctx.channel - if not isinstance(ch, discord.abc.GuildChannel): + if ctx.guild is None: raise NoPrivateMessage() - me = ch.guild.me - getter = functools.partial(discord.utils.get, me.roles) - if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): + me = ctx.me + if any( + me.get_role(item) is not None if isinstance(item, int) else discord.utils.get(me.roles, name=item) is not None + for item in items + ): return True - raise BotMissingAnyRole(items) + raise BotMissingAnyRole(list(items)) + return check(predicate) -def has_permissions(**perms): + +def has_permissions(**perms: Unpack[_PermissionsKwargs]) -> Check[Any]: """A :func:`.check` that is added that checks if the member has all of the permissions necessary. + Note that this check operates on the current channel permissions, not the + guild wide permissions. + The permissions passed in must be exactly like the properties shown under :class:`.discord.Permissions`. @@ -1483,11 +2214,15 @@ async def test(ctx): await ctx.send('You can manage messages.') """ - def predicate(ctx): - ch = ctx.channel - permissions = ch.permissions_for(ctx.author) - missing = [perm for perm, value in perms.items() if getattr(permissions, perm, None) != value] + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f'Invalid permission(s): {", ".join(invalid)}') + + def predicate(ctx: Context[BotT]) -> bool: + permissions = ctx.permissions + + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -1496,19 +2231,78 @@ def predicate(ctx): return check(predicate) -def bot_has_permissions(**perms): + +def bot_has_permissions(**perms: Unpack[_PermissionsKwargs]) -> Check[Any]: """Similar to :func:`.has_permissions` except checks if the bot itself has the permissions listed. This check raises a special exception, :exc:`.BotMissingPermissions` that is inherited from :exc:`.CheckFailure`. """ - def predicate(ctx): - guild = ctx.guild - me = guild.me if guild is not None else ctx.bot.user - permissions = ctx.channel.permissions_for(me) - missing = [perm for perm, value in perms.items() if getattr(permissions, perm, None) != value] + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f'Invalid permission(s): {", ".join(invalid)}') + + def predicate(ctx: Context[BotT]) -> bool: + permissions = ctx.bot_permissions + + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + + if not missing: + return True + + raise BotMissingPermissions(missing) + + return check(predicate) + + +def has_guild_permissions(**perms: Unpack[_PermissionsKwargs]) -> Check[Any]: + """Similar to :func:`.has_permissions`, but operates on guild wide + permissions instead of the current channel permissions. + + If this check is called in a DM context, it will raise an + exception, :exc:`.NoPrivateMessage`. + + .. versionadded:: 1.3 + """ + + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f'Invalid permission(s): {", ".join(invalid)}') + + def predicate(ctx: Context[BotT]) -> bool: + if not ctx.guild: + raise NoPrivateMessage + + permissions = ctx.author.guild_permissions # type: ignore + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + + if not missing: + return True + + raise MissingPermissions(missing) + + return check(predicate) + + +def bot_has_guild_permissions(**perms: Unpack[_PermissionsKwargs]) -> Check[Any]: + """Similar to :func:`.has_guild_permissions`, but checks the bot + members guild permissions. + + .. versionadded:: 1.3 + """ + + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f'Invalid permission(s): {", ".join(invalid)}') + + def predicate(ctx: Context[BotT]) -> bool: + if not ctx.guild: + raise NoPrivateMessage + + permissions = ctx.me.guild_permissions # type: ignore + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -1517,7 +2311,8 @@ def predicate(ctx): return check(predicate) -def dm_only(): + +def dm_only() -> Check[Any]: """A :func:`.check` that indicates this command must only be used in a DM context. Only private messages are allowed when using the command. @@ -1525,33 +2320,69 @@ def dm_only(): This check raises a special exception, :exc:`.PrivateMessageOnly` that is inherited from :exc:`.CheckFailure`. - .. versionadded:: 1.1.0 + .. versionadded:: 1.1 """ - def predicate(ctx): + def predicate(ctx: Context[BotT]) -> bool: if ctx.guild is not None: raise PrivateMessageOnly() return True return check(predicate) -def guild_only(): + +def guild_only() -> Check[Any]: """A :func:`.check` that indicates this command must only be used in a guild context only. Basically, no private messages are allowed when using the command. This check raises a special exception, :exc:`.NoPrivateMessage` that is inherited from :exc:`.CheckFailure`. + + If used on hybrid commands, this will be equivalent to the + :func:`discord.app_commands.guild_only` decorator. In an unsupported + context, such as a subcommand, this will still fallback to applying the + check. """ - def predicate(ctx): + # Due to implementation quirks, this check has to be re-implemented completely + # to work with both app_commands and the command framework. + + def predicate(ctx: Context[BotT]) -> bool: if ctx.guild is None: raise NoPrivateMessage() return True - return check(predicate) + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: + if isinstance(func, Command): + func.checks.append(predicate) + if hasattr(func, '__commands_is_hybrid__'): + app_command = getattr(func, 'app_command', None) + if app_command: + app_command.guild_only = True + else: + if not hasattr(func, '__commands_checks__'): + func.__commands_checks__ = [] + + func.__commands_checks__.append(predicate) + func.__discord_app_commands_guild_only__ = True + + return func + + if inspect.iscoroutinefunction(predicate): + decorator.predicate = predicate + else: + + @functools.wraps(predicate) + async def wrapper(ctx: Context[BotT]): + return predicate(ctx) + + decorator.predicate = wrapper + + return decorator # type: ignore -def is_owner(): + +def is_owner() -> Check[Any]: """A :func:`.check` that checks if the person invoking this command is the owner of the bot. @@ -1561,47 +2392,83 @@ def is_owner(): from :exc:`.CheckFailure`. """ - async def predicate(ctx): + async def predicate(ctx: Context[BotT]) -> bool: if not await ctx.bot.is_owner(ctx.author): raise NotOwner('You do not own this bot.') return True return check(predicate) -def is_nsfw(): + +def is_nsfw() -> Check[Any]: """A :func:`.check` that checks if the channel is a NSFW channel. This check raises a special exception, :exc:`.NSFWChannelRequired` that is derived from :exc:`.CheckFailure`. - .. versionchanged:: 1.1.0 + If used on hybrid commands, this will be equivalent to setting the + application command's ``nsfw`` attribute to ``True``. In an unsupported + context, such as a subcommand, this will still fallback to applying the + check. + + .. versionchanged:: 1.1 Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`. DM channels will also now pass this check. """ - def pred(ctx): + + # Due to implementation quirks, this check has to be re-implemented completely + # to work with both app_commands and the command framework. + + def predicate(ctx: Context[BotT]) -> bool: ch = ctx.channel - if ctx.guild is None or (isinstance(ch, discord.TextChannel) and ch.is_nsfw()): + if ctx.guild is None or ( + isinstance(ch, (discord.TextChannel, discord.Thread, discord.VoiceChannel)) and ch.is_nsfw() + ): return True - raise NSFWChannelRequired(ch) - return check(pred) + raise NSFWChannelRequired(ch) # type: ignore + + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: + if isinstance(func, Command): + func.checks.append(predicate) + if hasattr(func, '__commands_is_hybrid__'): + app_command = getattr(func, 'app_command', None) + if app_command: + app_command.nsfw = True + else: + if not hasattr(func, '__commands_checks__'): + func.__commands_checks__ = [] + + func.__commands_checks__.append(predicate) + func.__discord_app_commands_is_nsfw__ = True + + return func + + if inspect.iscoroutinefunction(predicate): + decorator.predicate = predicate + else: + + @functools.wraps(predicate) + async def wrapper(ctx: Context[BotT]): + return predicate(ctx) + + decorator.predicate = wrapper + + return decorator # type: ignore -def cooldown(rate, per, type=BucketType.default): + +def cooldown( + rate: int, + per: float, + type: Union[BucketType, Callable[[Context[Any]], Any]] = BucketType.default, +) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` - or its subclasses. A cooldown allows a command to only be used a specific amount of times in a specific time frame. These cooldowns can be based - either on a per-guild, per-channel, per-user, or global basis. + either on a per-guild, per-channel, per-user, per-role or global basis. Denoted by the third argument of ``type`` which must be of enum - type ``BucketType`` which could be either: - - - ``BucketType.default`` for a global basis. - - ``BucketType.user`` for a per-user basis. - - ``BucketType.guild`` for a per-guild basis. - - ``BucketType.channel`` for a per-channel basis. - - ``BucketType.member`` for a per-member basis. - - ``BucketType.category`` for a per-category basis. + type :class:`.BucketType`. If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in :func:`.on_command_error` and the local error handler. @@ -1614,14 +2481,180 @@ def cooldown(rate, per, type=BucketType.default): The number of times a command can be used before triggering a cooldown. per: :class:`float` The amount of seconds to wait for a cooldown when it's been triggered. - type: ``BucketType`` + type: Union[:class:`.BucketType`, Callable[[:class:`.Context`], Any]] + The type of cooldown to have. If callable, should return a key for the mapping. + + .. versionchanged:: 1.7 + Callables are now supported for custom bucket types. + + .. versionchanged:: 2.0 + When passing a callable, it now needs to accept :class:`.Context` + rather than :class:`~discord.Message` as its only argument. + """ + + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: + if isinstance(func, Command): + func._buckets = CooldownMapping(Cooldown(rate, per), type) + else: + func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) + return func + + return decorator # type: ignore + + +def dynamic_cooldown( + cooldown: Callable[[Context[Any]], Optional[Cooldown]], + type: Union[BucketType, Callable[[Context[Any]], Any]], +) -> Callable[[T], T]: + """A decorator that adds a dynamic cooldown to a :class:`.Command` + + This differs from :func:`.cooldown` in that it takes a function that + accepts a single parameter of type :class:`.Context` and must + return a :class:`~discord.app_commands.Cooldown` or ``None``. + If ``None`` is returned then that cooldown is effectively bypassed. + + A cooldown allows a command to only be used a specific amount + of times in a specific time frame. These cooldowns can be based + either on a per-guild, per-channel, per-user, per-role or global basis. + Denoted by the third argument of ``type`` which must be of enum + type :class:`.BucketType`. + + If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in + :func:`.on_command_error` and the local error handler. + + A command can only have a single cooldown. + + .. versionadded:: 2.0 + + Parameters + ------------ + cooldown: Callable[[:class:`.Context`], Optional[:class:`~discord.app_commands.Cooldown`]] + A function that takes a message and returns a cooldown that will + apply to this invocation or ``None`` if the cooldown should be bypassed. + type: :class:`.BucketType` The type of cooldown to have. """ + if not callable(cooldown): + raise TypeError('A callable must be provided') - def decorator(func): + if type is BucketType.default: + raise ValueError('BucketType.default cannot be used in dynamic cooldowns') + + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): - func._buckets = CooldownMapping(Cooldown(rate, per, type)) + func._buckets = DynamicCooldownMapping(cooldown, type) else: - func.__commands_cooldown__ = Cooldown(rate, per, type) + func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) return func - return decorator + + return decorator # type: ignore + + +def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: + """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. + + This enables you to only allow a certain number of command invocations at the same time, + for example if a command takes too long or if only one user can use it at a time. This + differs from a cooldown in that there is no set waiting period or token bucket -- only + a set number of people can run the command. + + .. versionadded:: 1.3 + + Parameters + ------------- + number: :class:`int` + The maximum number of invocations of this command that can be running at the same time. + per: :class:`.BucketType` + The bucket that this concurrency is based on, e.g. ``BucketType.guild`` would allow + it to be used up to ``number`` times per guild. + wait: :class:`bool` + Whether the command should wait for the queue to be over. If this is set to ``False`` + then instead of waiting until the command can run again, the command raises + :exc:`.MaxConcurrencyReached` to its error handler. If this is set to ``True`` + then the command waits until it can be executed. + """ + + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: + value = MaxConcurrency(number, per=per, wait=wait) + if isinstance(func, Command): + func._max_concurrency = value + else: + func.__commands_max_concurrency__ = value + return func + + return decorator # type: ignore + + +def before_invoke(coro: Hook[CogT, ContextT], /) -> Callable[[T], T]: + """A decorator that registers a coroutine as a pre-invoke hook. + + This allows you to refer to one before invoke hook for several commands that + do not have to be within the same cog. + + .. versionadded:: 1.4 + + .. versionchanged:: 2.0 + + ``coro`` parameter is now positional-only. + + Example + --------- + + .. code-block:: python3 + + async def record_usage(ctx): + print(ctx.author, 'used', ctx.command, 'at', ctx.message.created_at) + + @bot.command() + @commands.before_invoke(record_usage) + async def who(ctx): # Output: used who at