diff --git a/.github/workflows/docs-json-export.yml b/.github/workflows/docs-json-export.yml index eb46d1b1b3..618df46fc7 100644 --- a/.github/workflows/docs-json-export.yml +++ b/.github/workflows/docs-json-export.yml @@ -26,7 +26,7 @@ jobs: id: install-deps run: | python -m pip install -U pip - pip install ".[docs]" + pip install ".[docs,voice]" pip install beautifulsoup4 - name: Build Sphinx HTML docs id: build-sphinx diff --git a/discord/__init__.py b/discord/__init__.py index dbd87452b2..92f3e7a215 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -25,6 +25,7 @@ from . import abc, opus, sinks, ui, utils +from ._voice_aliases import * from .activity import * from .appinfo import * from .application_role_connection import * @@ -72,7 +73,6 @@ from .template import * from .threads import * from .user import * -from .voice_client import * from .webhook import * from .welcome_screen import * from .widget import * diff --git a/discord/_voice_aliases.py b/discord/_voice_aliases.py new file mode 100644 index 0000000000..9054e4a04a --- /dev/null +++ b/discord/_voice_aliases.py @@ -0,0 +1,74 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .utils import deprecated + +""" +since discord.voice raises an error when importing it without having the +required package (ie davey) installed, we can't import it in __init__ because +that would break the whole library, that is why this file is here. + +the error would still be raised, but at least here we have more freedom on how we are typing it +""" + +__all__ = ("VoiceProtocol", "VoiceClient") + + +if TYPE_CHECKING: + from typing_extensions import deprecated + + from discord.voice import VoiceClientC, VoiceProtocolC + + @deprecated( + "discord.VoiceClient is deprecated in favour " + "of discord.voice.VoiceClient since 2.7 and " + "will be removed in 3.0", + ) + def VoiceClient(client, channel) -> VoiceClientC: ... + + @deprecated( + "discord.VoiceProtocol is deprecated in favour " + "of discord.voice.VoiceProtocol since 2.7 and " + "will be removed in 3.0", + ) + def VoiceProtocol(client, channel) -> VoiceProtocolC: ... + +else: + + @deprecated("discord.VoiceClient", "discord.voice.VoiceClient", "2.7", "3.0") + def VoiceClient(client, channel): + from discord.voice import VoiceClient + + return VoiceClient(client, channel) + + @deprecated("discord.VoiceProtocol", "discord.voice.VoiceProtocol", "2.7", "3.0") + def VoiceProtocol(client, channel): + from discord.voice import VoiceProtocol + + return VoiceProtocol(client, channel) diff --git a/discord/abc.py b/discord/abc.py index 1fd36948e1..2552b23a7d 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -55,7 +55,7 @@ from .role import Role from .scheduled_events import ScheduledEvent from .sticker import GuildSticker, StickerItem -from .voice_client import VoiceClient, VoiceProtocol +from .voice import VoiceClient, VoiceProtocol __all__ = ( "Snowflake", @@ -1966,6 +1966,7 @@ class Connectable(Protocol): __slots__ = () _state: ConnectionState + id: int def _get_voice_client_key(self) -> tuple[int, str]: raise NotImplementedError diff --git a/discord/channel.py b/discord/channel.py index e5e55f360c..1bf5297176 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -1624,7 +1624,7 @@ def _update( self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload ) -> None: # This data will always exist - self.guild = guild + self.guild: Guild = guild self.name: str = data["name"] self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") diff --git a/discord/client.py b/discord/client.py index ed6cdb8991..97c5f6d557 100644 --- a/discord/client.py +++ b/discord/client.py @@ -27,7 +27,6 @@ import asyncio import logging -import signal import sys import traceback from types import TracebackType @@ -51,7 +50,7 @@ from .invite import Invite from .iterators import EntitlementIterator, GuildIterator from .mentions import AllowedMentions -from .monetization import SKU, Entitlement +from .monetization import SKU from .object import Object from .soundboard import SoundboardSound from .stage_instance import StageInstance @@ -62,7 +61,7 @@ from .ui.view import View from .user import ClientUser, User from .utils import MISSING -from .voice_client import VoiceClient +from .voice import VoiceClient from .webhook import Webhook from .widget import Widget @@ -75,7 +74,7 @@ from .poll import Poll from .soundboard import SoundboardSound from .ui.item import Item - from .voice_client import VoiceProtocol + from .voice import VoiceProtocol __all__ = ("Client",) @@ -807,7 +806,12 @@ async def start(self, token: str, *, reconnect: bool = True) -> None: await self.login(token) await self.connect(reconnect=reconnect) - def run(self, *args: Any, **kwargs: Any) -> None: + def run( + self, + token: str, + *, + reconnect: bool = True, + ) -> None: """A blocking call that abstracts away the event loop initialisation from you. @@ -831,41 +835,18 @@ def run(self, *args: Any, **kwargs: Any) -> None: is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ - loop = self.loop - - try: - loop.add_signal_handler(signal.SIGINT, loop.stop) - loop.add_signal_handler(signal.SIGTERM, loop.stop) - except (NotImplementedError, RuntimeError): - pass async def runner(): - try: - await self.start(*args, **kwargs) - finally: - if not self.is_closed(): - await self.close() + async with self: + await self.start(token, reconnect=reconnect) - def stop_loop_on_completion(f): - loop.stop() - - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) try: - loop.run_forever() + asyncio.run(runner()) except KeyboardInterrupt: - _log.info("Received signal to terminate bot and event loop.") - finally: - future.remove_done_callback(stop_loop_on_completion) - _log.info("Cleaning up tasks.") - _cleanup_loop(loop) - - if not future.cancelled(): - try: - return future.result() - except KeyboardInterrupt: - # I am unsure why this gets raised here but suppress it anyway - return None + # nothing to do here + # `asyncio.run` handles the loop cleanup + # and `self.start` closes all sockets and the HTTPClient instance. + return # properties diff --git a/discord/commands/context.py b/discord/commands/context.py index 73a6b39a45..d8748b9098 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -48,7 +48,7 @@ from ..permissions import Permissions from ..state import ConnectionState from ..user import User - from ..voice_client import VoiceClient + from ..voice import VoiceClient from ..webhook import WebhookMessage from .core import ApplicationCommand, Option diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index afd023d351..11e2216fce 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -41,7 +41,7 @@ from discord.member import Member from discord.state import ConnectionState from discord.user import ClientUser, User - from discord.voice_client import VoiceProtocol + from discord.voice import VoiceProtocol from .bot import AutoShardedBot, Bot from .cog import Cog diff --git a/discord/gateway.py b/discord/gateway.py index a24581ed78..b9b28f887d 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -28,29 +28,31 @@ import asyncio import concurrent.futures import logging -import struct import sys import threading import time import traceback import zlib -from collections import deque, namedtuple -from typing import TYPE_CHECKING +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, NamedTuple import aiohttp from . import utils from .activity import BaseActivity -from .enums import SpeakingState from .errors import ConnectionClosed, InvalidArgument +if TYPE_CHECKING: + from typing_extensions import Self + + from .client import Client + from .state import ConnectionState + _log = logging.getLogger(__name__) __all__ = ( "DiscordWebSocket", "KeepAliveHandler", - "VoiceKeepAliveHandler", - "DiscordVoiceWebSocket", "ReconnectWebSocket", ) @@ -68,26 +70,30 @@ class WebSocketClosure(Exception): """An exception to make up for the fact that aiohttp doesn't signal closure.""" -EventListener = namedtuple("EventListener", "predicate event result future") +class EventListener(NamedTuple): + predicate: Callable[[dict[str, Any]], bool] + event: str + result: Callable[[dict[str, Any]], Any] | None + future: asyncio.Future[Any] class GatewayRatelimiter: - def __init__(self, count=110, per=60.0): + def __init__(self, count: int = 110, per: float = 60.0): # The default is 110 to give room for at least 10 heartbeats per minute - self.max = count - self.remaining = count - self.window = 0.0 - self.per = per - self.lock = asyncio.Lock() - self.shard_id = None - - def is_ratelimited(self): + self.max: int = count + self.remaining: int = count + self.window: float = 0.0 + self.per: float = per + self.lock: asyncio.Lock = asyncio.Lock() + self.shard_id: int | None = None + + def is_ratelimited(self) -> bool: current = time.time() if current > self.window + self.per: return False return self.remaining == 0 - def get_delay(self): + def get_delay(self) -> float: current = time.time() if current > self.window + self.per: @@ -105,7 +111,7 @@ def get_delay(self): return 0.0 - async def block(self): + async def block(self) -> None: async with self.lock: delta = self.get_delay() if delta: @@ -118,15 +124,25 @@ async def block(self): class KeepAliveHandler(threading.Thread): - def __init__(self, *args, **kwargs): - ws = kwargs.pop("ws", None) - interval = kwargs.pop("interval", None) - shard_id = kwargs.pop("shard_id", None) - threading.Thread.__init__(self, *args, **kwargs) - self.ws = ws + def __init__( + self, + *args: Any, + ws: DiscordWebSocket, + shard_id: int | None = None, + interval: float | None = None, + **kwargs: Any, + ) -> None: + daemon: bool = kwargs.pop("daemon", True) + name: str = kwargs.pop("name", f"keep-alive-handler:shard-{shard_id}") + super().__init__( + *args, + **kwargs, + daemon=daemon, + name=name, + ) + self.ws: DiscordWebSocket = ws self._main_thread_id = ws.thread_id self.interval = interval - self.daemon = True self.shard_id = shard_id self.msg = "Keeping shard ID %s websocket alive with sequence %s." self.block_msg = "Shard ID %s heartbeat blocked for more than %s seconds." @@ -138,7 +154,7 @@ def __init__(self, *args, **kwargs): self.latency = float("inf") self.heartbeat_timeout = ws._max_heartbeat_timeout - def run(self): + def run(self) -> None: while not self._stop_ev.wait(self.interval): if self._last_recv + self.heartbeat_timeout < time.perf_counter(): _log.warning( @@ -191,16 +207,16 @@ def run(self): else: self._last_send = time.perf_counter() - def get_payload(self): + def get_payload(self) -> dict[str, Any]: return {"op": self.ws.HEARTBEAT, "d": self.ws.sequence} - def stop(self): + def stop(self) -> None: self._stop_ev.set() - def tick(self): + def tick(self) -> None: self._last_recv = time.perf_counter() - def ack(self): + def ack(self) -> None: ack_time = time.perf_counter() self._last_ack = ack_time self.latency = ack_time - self._last_send @@ -208,31 +224,6 @@ def ack(self): _log.warning(self.behind_msg, self.shard_id, self.latency) -class VoiceKeepAliveHandler(KeepAliveHandler): - if TYPE_CHECKING: - ws: DiscordVoiceWebSocket - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.recent_ack_latencies = deque(maxlen=20) - self.msg = "Keeping shard ID %s voice websocket alive with timestamp %s." - self.block_msg = "Shard ID %s voice heartbeat blocked for more than %s seconds" - self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind" - - def get_payload(self): - return { - "op": self.ws.HEARTBEAT, - "d": {"t": int(time.time() * 1000), "seq_ack": self.ws.seq_ack}, - } - - def ack(self): - ack_time = time.perf_counter() - self._last_ack = ack_time - self._last_recv = ack_time - self.latency = ack_time - self._last_send - self.recent_ack_latencies.append(self.latency) - - class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): async def close(self, *, code: int = 4000, message: bytes = b"") -> bool: return await super().close(code=code, message=message) @@ -293,52 +284,68 @@ class DiscordWebSocket: GUILD_SYNC = 12 REQUEST_SOUNDBOARD_SOUNDS = 31 - def __init__(self, socket, *, loop): - self.socket = socket - self.loop = loop + if TYPE_CHECKING: + token: str | None + _connection: ConnectionState + _discord_parsers: dict[str, Callable[..., Any]] + call_hooks: Callable[..., Any] + gateway: str + _initial_identify: bool + shard_id: int | None + shard_count: int | None + _max_heartbeat_timeout: float + + def __init__( + self, + socket: aiohttp.ClientWebSocketResponse, + *, + loop: asyncio.AbstractEventLoop, + ) -> None: + self.socket: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop # an empty dispatcher to prevent crashes - self._dispatch = lambda *args: None + self._dispatch: Callable[..., Any] = lambda *args: None # generic event listeners - self._dispatch_listeners = [] + self._dispatch_listeners: list[EventListener] = [] # the keep alive - self._keep_alive = None - self.thread_id = threading.get_ident() + self._keep_alive: KeepAliveHandler | None = None + self.thread_id: int = threading.get_ident() # ws related stuff - self.session_id = None - self.sequence = None - self.resume_gateway_url = None - self._zlib = zlib.decompressobj() - self._buffer = bytearray() - self._close_code = None - self._rate_limiter = GatewayRatelimiter() + self.session_id: str | None = None + self.sequence: int | None = None + self.resume_gateway_url: str | None = None + self._zlib: zlib._Decompress = zlib.decompressobj() + self._buffer: bytearray = bytearray() + self._close_code: int | None = None + self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() @property - def open(self): + def open(self) -> bool: return not self.socket.closed - def is_ratelimited(self): + def is_ratelimited(self) -> bool: return self._rate_limiter.is_ratelimited() - def debug_log_receive(self, data, /): + def debug_log_receive(self, data: dict[str, Any], /) -> None: self._dispatch("socket_raw_receive", data) - def log_receive(self, _, /): + def log_receive(self, _: dict[str, Any], /) -> None: pass @classmethod async def from_client( cls, - client, + client: Client, *, - initial=False, - gateway=None, - shard_id=None, - session=None, - sequence=None, - resume=False, - ): + initial: bool = False, + gateway: str | None = None, + shard_id: int | None = None, + session: str | None = None, + sequence: int | None = None, + resume: bool = False, + ) -> Self: """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -380,7 +387,12 @@ async def from_client( await ws.resume() return ws - def wait_for(self, event, predicate, result=None): + def wait_for( + self, + event: str, + predicate: Callable[[dict[str, Any]], bool], + result: Callable[[dict[str, Any]], Any] | None = None, + ) -> asyncio.Future[Any]: """Waits for a DISPATCH'd event that meets the predicate. Parameters @@ -407,7 +419,7 @@ def wait_for(self, event, predicate, result=None): self._dispatch_listeners.append(entry) return future - async def identify(self): + async def identify(self) -> None: """Sends the IDENTIFY packet.""" payload = { "op": self.IDENTIFY, @@ -420,7 +432,6 @@ async def identify(self): }, "compress": True, "large_threshold": 250, - "v": 3, }, } @@ -445,7 +456,7 @@ async def identify(self): await self.send_as_json(payload) _log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id) - async def resume(self): + async def resume(self) -> None: """Sends the RESUME packet.""" payload = { "op": self.RESUME, @@ -459,7 +470,7 @@ async def resume(self): await self.send_as_json(payload) _log.info("Shard ID %s has sent the RESUME payload.", self.shard_id) - async def received_message(self, msg, /): + async def received_message(self, msg: Any, /): if type(msg) is bytes: self._buffer.extend(msg) @@ -595,7 +606,7 @@ def latency(self) -> float: heartbeat = self._keep_alive return float("inf") if heartbeat is None else heartbeat.latency - def _can_handle_close(self): + def _can_handle_close(self) -> bool: code = self._close_code or self.socket.close_code is_improper_close = self._close_code is None and self.socket.close_code == 1000 return is_improper_close or code not in ( @@ -608,7 +619,7 @@ def _can_handle_close(self): 4014, ) - async def poll_event(self): + async def poll_event(self) -> None: """Polls for a DISPATCH event and handles the general gateway loop. Raises @@ -622,11 +633,12 @@ async def poll_event(self): await self.received_message(msg.data) elif msg.type is aiohttp.WSMsgType.BINARY: await self.received_message(msg.data) + elif msg.type is aiohttp.WSMsgType.ERROR: + _log.debug("Received an error %s", msg) elif msg.type in ( aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.ERROR, ): _log.debug("Received %s", msg) raise WebSocketClosure @@ -650,23 +662,23 @@ async def poll_event(self): self.socket, shard_id=self.shard_id, code=code ) from None - async def debug_send(self, data, /): + async def debug_send(self, data: str, /) -> None: await self._rate_limiter.block() self._dispatch("socket_raw_send", data) await self.socket.send_str(data) - async def send(self, data, /): + async def send(self, data: str, /) -> None: await self._rate_limiter.block() await self.socket.send_str(data) - async def send_as_json(self, data): + async def send_as_json(self, data: Any) -> None: try: await self.send(utils._to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def send_heartbeat(self, data): + async def send_heartbeat(self, data: Any) -> None: # This bypasses the rate limit handling code since it has a higher priority try: await self.socket.send_str(utils._to_json(data)) @@ -674,13 +686,19 @@ async def send_heartbeat(self, data): if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def change_presence(self, *, activity=None, status=None, since=0.0): + async def change_presence( + self, + *, + activity: BaseActivity | None = None, + status: str | None = None, + since: float = 0.0, + ) -> None: if activity is not None: if not isinstance(activity, BaseActivity): raise InvalidArgument("activity must derive from BaseActivity.") - activity = [activity.to_dict()] + activities = [activity.to_dict()] else: - activity = [] + activities = [] if status == "idle": since = int(time.time() * 1000) @@ -688,7 +706,7 @@ async def change_presence(self, *, activity=None, status=None, since=0.0): payload = { "op": self.PRESENCE, "d": { - "activities": activity, + "activities": activities, "afk": False, "since": since, "status": status, @@ -700,8 +718,15 @@ async def change_presence(self, *, activity=None, status=None, since=0.0): await self.send(sent) async def request_chunks( - self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None - ): + self, + guild_id: int, + query: str | None = None, + *, + limit: int, + user_ids: list[int] | None = None, + presences: bool = False, + nonce: str | None = None, + ) -> None: payload = { "op": self.REQUEST_MEMBERS, "d": {"guild_id": guild_id, "presences": presences, "limit": limit}, @@ -718,7 +743,13 @@ async def request_chunks( await self.send_as_json(payload) - async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): + async def voice_state( + self, + guild_id: int, + channel_id: int, + self_mute: bool = False, + self_deaf: bool = False, + ) -> None: payload = { "op": self.VOICE_STATE, "d": { @@ -741,261 +772,10 @@ async def request_soundboard_sounds(self, guild_ids): _log.debug("Requesting soundboard sounds for guilds %s.", guild_ids) await self.send_as_json(payload) - async def close(self, code=4000): + async def close(self, code: int = 4000) -> None: if self._keep_alive: self._keep_alive.stop() self._keep_alive = None self._close_code = code await self.socket.close(code=code) - - -class DiscordVoiceWebSocket: - """Implements the WebSocket protocol for handling voice connections. - - Attributes - ---------- - IDENTIFY - Send only. Starts a new voice session. - SELECT_PROTOCOL - Send only. Tells discord what encryption mode and how to connect for voice. - READY - Receive only. Tells the websocket that the initial connection has completed. - HEARTBEAT - Send only. Keeps your websocket connection alive. - SESSION_DESCRIPTION - Receive only. Gives you the secret key required for voice. - SPEAKING - Send only. Notifies the client if you are currently speaking. - HEARTBEAT_ACK - Receive only. Tells you your heartbeat has been acknowledged. - RESUME - Sent only. Tells the client to resume its session. - HELLO - Receive only. Tells you that your websocket connection was acknowledged. - RESUMED - Sent only. Tells you that your RESUME request has succeeded. - CLIENT_CONNECT - Indicates a user has connected to voice. - CLIENT_DISCONNECT - Receive only. Indicates a user has disconnected from voice. - """ - - IDENTIFY = 0 - SELECT_PROTOCOL = 1 - READY = 2 - HEARTBEAT = 3 - SESSION_DESCRIPTION = 4 - SPEAKING = 5 - HEARTBEAT_ACK = 6 - RESUME = 7 - HELLO = 8 - RESUMED = 9 - CLIENT_CONNECT = 12 - CLIENT_DISCONNECT = 13 - - def __init__(self, socket, loop, *, hook=None): - self.ws = socket - self.loop = loop - self._keep_alive = None - self._close_code = None - self.secret_key = None - self.ssrc_map = {} - self.seq_ack: int = -1 - if hook: - self._hook = hook - - async def _hook(self, *args): - pass - - async def send_as_json(self, data): - _log.debug("Sending voice websocket frame: %s.", data) - await self.ws.send_str(utils._to_json(data)) - - send_heartbeat = send_as_json - - async def resume(self): - state = self._connection - payload = { - "op": self.RESUME, - "d": { - "token": state.token, - "server_id": str(state.server_id), - "session_id": state.session_id, - # this seq_ack will allow for us to do buffered resume, which is, receive the - # lost voice packets while trying to resume the reconnection - "seq_ack": self.seq_ack, - }, - } - await self.send_as_json(payload) - - async def identify(self): - state = self._connection - payload = { - "op": self.IDENTIFY, - "d": { - "server_id": str(state.server_id), - "user_id": str(state.user.id), - "session_id": state.session_id, - "token": state.token, - }, - } - await self.send_as_json(payload) - - @classmethod - async def from_client(cls, client, *, resume=False, hook=None): - """Creates a voice websocket for the :class:`VoiceClient`.""" - gateway = f"wss://{client.endpoint}/?v=8" - http = client._state.http - socket = await http.ws_connect(gateway, compress=15) - ws = cls(socket, loop=client.loop, hook=hook) - ws.gateway = gateway - ws._connection = client - ws._max_heartbeat_timeout = 60.0 - ws.thread_id = threading.get_ident() - - if resume: - await ws.resume() - else: - await ws.identify() - - return ws - - async def select_protocol(self, ip, port, mode): - payload = { - "op": self.SELECT_PROTOCOL, - "d": { - "protocol": "udp", - "data": {"address": ip, "port": port, "mode": mode}, - }, - } - - await self.send_as_json(payload) - - async def client_connect(self): - payload = { - "op": self.CLIENT_CONNECT, - "d": {"audio_ssrc": self._connection.ssrc}, - } - - await self.send_as_json(payload) - - async def speak(self, state=SpeakingState.voice): - payload = { - "op": self.SPEAKING, - "d": { - "speaking": int(state), - "delay": 0, - }, - } - - await self.send_as_json(payload) - - async def received_message(self, msg): - _log.debug("Voice websocket frame received: %s", msg) - op = msg["op"] - data = msg.get("d") - self.seq_ack = data.get("seq", self.seq_ack) - - if op == self.READY: - await self.initial_connection(data) - elif op == self.HEARTBEAT_ACK: - self._keep_alive.ack() - elif op == self.RESUMED: - _log.info("Voice RESUME succeeded.") - elif op == self.SESSION_DESCRIPTION: - self._connection.mode = data["mode"] - await self.load_secret_key(data) - elif op == self.HELLO: - interval = data["heartbeat_interval"] / 1000.0 - self._keep_alive = VoiceKeepAliveHandler( - ws=self, interval=min(interval, 5.0) - ) - self._keep_alive.start() - - elif op == self.SPEAKING: - ssrc = data["ssrc"] - user = int(data["user_id"]) - speaking = data["speaking"] - if ssrc in self.ssrc_map: - self.ssrc_map[ssrc]["speaking"] = speaking - else: - self.ssrc_map.update({ssrc: {"user_id": user, "speaking": speaking}}) - - await self._hook(self, msg) - - async def initial_connection(self, data): - state = self._connection - state.ssrc = data["ssrc"] - state.voice_port = data["port"] - state.endpoint_ip = data["ip"] - - packet = bytearray(74) - struct.pack_into(">H", packet, 0, 1) # 1 = Send - struct.pack_into(">H", packet, 2, 70) # 70 = Length - struct.pack_into(">I", packet, 4, state.ssrc) - state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) - recv = await self.loop.sock_recv(state.socket, 74) - _log.debug("received packet in initial_connection: %s", recv) - - # the ip is ascii starting at the 8th byte and ending at the first null - ip_start = 8 - ip_end = recv.index(0, ip_start) - state.ip = recv[ip_start:ip_end].decode("ascii") - - state.port = struct.unpack_from(">H", recv, len(recv) - 2)[0] - _log.debug("detected ip: %s port: %s", state.ip, state.port) - - # there *should* always be at least one supported mode (xsalsa20_poly1305) - modes = [ - mode for mode in data["modes"] if mode in self._connection.supported_modes - ] - _log.debug("received supported encryption modes: %s", ", ".join(modes)) - - mode = modes[0] - await self.select_protocol(state.ip, state.port, mode) - _log.info("selected the voice protocol for use (%s)", mode) - - @property - def latency(self) -> float: - """Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" - heartbeat = self._keep_alive - return float("inf") if heartbeat is None else heartbeat.latency - - @property - def average_latency(self) -> list[float] | float: - """Average of last 20 HEARTBEAT latencies.""" - heartbeat = self._keep_alive - if heartbeat is None or not heartbeat.recent_ack_latencies: - return float("inf") - - return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) - - async def load_secret_key(self, data): - _log.info("received secret key for voice connection") - self.secret_key = self._connection.secret_key = data.get("secret_key") - await self.speak() - await self.speak(False) - - async def poll_event(self): - # This exception is handled up the chain - msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) - if msg.type is aiohttp.WSMsgType.TEXT: - await self.received_message(utils._from_json(msg.data)) - elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug("Received %s", msg) - raise ConnectionClosed(self.ws, shard_id=None) from msg.data - elif msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - ): - _log.debug("Received %s", msg) - raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) - - async def close(self, code=1000): - if self._keep_alive is not None: - self._keep_alive.stop() - - self._close_code = code - await self.ws.close(code=code) diff --git a/discord/guild.py b/discord/guild.py index ff4c9e04b2..d5211d89ec 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -115,8 +115,8 @@ from .types.guild import GuildFeature, MFALevel from .types.member import Member as MemberPayload from .types.threads import Thread as ThreadPayload - from .types.voice import GuildVoiceState - from .voice_client import VoiceClient + from .types.voice import VoiceState as GuildVoiceState + from .voice import VoiceClient from .webhook import Webhook VocalGuildChannel = Union[VoiceChannel, StageChannel] diff --git a/discord/member.py b/discord/member.py index 0ff90cce04..498f3322fd 100644 --- a/discord/member.py +++ b/discord/member.py @@ -30,7 +30,7 @@ import itertools import sys from operator import attrgetter -from typing import TYPE_CHECKING, Any, TypeVar, Union +from typing import TYPE_CHECKING, Any, TypeVar import discord.abc @@ -52,7 +52,8 @@ if TYPE_CHECKING: from .abc import Snowflake - from .channel import DMChannel, StageChannel, VoiceChannel + from .channel import DMChannel, VocalGuildChannel + from .client import Client from .flags import PublicUserFlags from .guild import Guild from .message import Message @@ -63,11 +64,9 @@ from .types.member import MemberWithUser as MemberWithUserPayload from .types.member import UserWithMember as UserWithMemberPayload from .types.user import User as UserPayload - from .types.voice import GuildVoiceState as GuildVoiceStatePayload + from .types.voice import VoiceState as GuildVoiceStatePayload from .types.voice import VoiceState as VoiceStatePayload - VocalGuildChannel = Union[VoiceChannel, StageChannel] - class VoiceState: """Represents a Discord user's voice state. @@ -165,6 +164,20 @@ def __repr__(self) -> str: inner = " ".join("%s=%r" % t for t in attrs) return f"<{self.__class__.__name__} {inner}>" + @classmethod + def _create_default(cls, channel: VocalGuildChannel, client: Client) -> VoiceState: + self = cls( + data={ + "channel_id": channel.id, + "guild_id": channel.guild.id, + "self_deaf": False, + "self_mute": False, + "user_id": client._connection.self_id, # type: ignore + }, + channel=channel, + ) + return self + def flatten_user(cls): for attr, value in itertools.chain( diff --git a/discord/opus.py b/discord/opus.py index 6ea6f84308..45a09d0573 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -28,21 +28,32 @@ import array import ctypes import ctypes.util -import gc import logging import math import os.path import struct import sys -import threading -import time from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar +import davey + +from discord.voice.packets.rtp import FakePacket +from discord.voice.utils.buffer import JitterBuffer +from discord.voice.utils.wrapped import add_wrapped + from .errors import DiscordException -from .sinks import RawData if TYPE_CHECKING: + from discord.member import Member + from discord.sinks.core import Sink + from discord.user import User + from discord.voice.client import VoiceClient + from discord.voice.packets import VoiceData + from discord.voice.packets.core import Packet + from discord.voice.receive.router import PacketRouter + T = TypeVar("T") + APPLICATION_CTL = Literal["audio", "voip", "lowdelay"] BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] SIGNAL_CTL = Literal["auto", "voice", "music"] @@ -61,6 +72,12 @@ class SignalCtl(TypedDict): music: int +class ApplicationCtl(TypedDict): + audio: int + voip: int + lowdelay: int + + __all__ = ( "Encoder", "Decoder", @@ -74,6 +91,7 @@ class SignalCtl(TypedDict): c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_float_ptr = ctypes.POINTER(ctypes.c_float) +OPUS_SILENCE = b"\xf8\xff\xfe" _lib = None @@ -94,11 +112,19 @@ class DecoderStruct(ctypes.Structure): # Error codes OK = 0 BAD_ARG = -1 +BUFF_TOO_SMALL = -2 +INTERNAL_ERROR = -3 +INVALID_PACKET = -4 +UNIMPLEMENTED = -5 +INVALID_STATE = -6 +ALLOC_FAIL = -7 # Encoder CTLs -APPLICATION_AUDIO = 2049 -APPLICATION_VOIP = 2048 -APPLICATION_LOWDELAY = 2051 +application_ctl: ApplicationCtl = { + "audio": 2049, + "lowdelay": 2051, + "voip": 2048, +} CTL_SET_BITRATE = 4002 CTL_SET_BANDWIDTH = 4008 @@ -336,9 +362,9 @@ class OpusError(DiscordException): The error code returned. """ - def __init__(self, code: int): + def __init__(self, code: int = 0, message: str | None = None): self.code: int = code - msg = _lib.opus_strerror(self.code).decode("utf-8") + msg = message or _lib.opus_strerror(self.code).decode("utf-8") _log.info('"%s" has happened', msg) super().__init__(msg) @@ -365,16 +391,35 @@ def get_opus_version() -> str: class Encoder(_OpusStruct): - def __init__(self, application: int = APPLICATION_AUDIO): + def __init__( + self, + *, + application: APPLICATION_CTL = "audio", + bitrate: int = 128, + fec: bool = True, + expected_packet_loss: float = 0.15, + bandwidth: BAND_CTL = "full", + signal_type: SIGNAL_CTL = "auto", + ) -> None: + if application not in application_ctl: + raise ValueError("invalid application ctl type provided") + if not 16 <= bitrate <= 512: + raise ValueError("bitrate must be between 16 and 512, both included") + if not 0 < expected_packet_loss <= 1: + raise ValueError( + "expected_packet_loss must be between 0 and 1, including 1", + ) + _OpusStruct.get_opus_version() - self.application: int = application + self.application: int = application_ctl[application] self._state: EncoderStruct = self._create_state() - self.set_bitrate(128) - self.set_fec(True) - self.set_expected_packet_loss_percent(0.15) - self.set_bandwidth("full") - self.set_signal_type("auto") + + self.set_bitrate(bitrate) + self.set_fec(fec) + self.set_expected_packet_loss_percent(expected_packet_loss) + self.set_bandwidth(bandwidth) + self.set_signal_type(signal_type) def __del__(self) -> None: if hasattr(self, "_state"): @@ -420,12 +465,15 @@ def set_fec(self, enabled: bool = True) -> None: def set_expected_packet_loss_percent(self, percentage: float) -> None: _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore - def encode(self, pcm: bytes, frame_size: int) -> bytes: + def encode(self, pcm: bytes, frame_size: int | None = None) -> bytes: max_data_bytes = len(pcm) # bytes can be used to reference pointer pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore data = (ctypes.c_char * max_data_bytes)() + if frame_size is None: + frame_size = self.FRAME_SIZE + ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) # array can be initialized with bytes but mypy doesn't know @@ -433,7 +481,7 @@ def encode(self, pcm: bytes, frame_size: int) -> bytes: class Decoder(_OpusStruct): - def __init__(self): + def __init__(self) -> None: _OpusStruct.get_opus_version() self._state = self._create_state() @@ -492,24 +540,31 @@ def _get_last_packet_duration(self): _lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret)) return ret.value - def decode(self, data, *, fec=False): + def decode(self, data: bytes | None, *, fec: bool = True): if data is None and fec: - raise OpusError("Invalid arguments: FEC cannot be used with null data") + raise OpusError( + message="Invalid arguments: FEC cannot be used with null data" + ) + + channel_count = self.CHANNELS if data is None: frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME - channel_count = self.CHANNELS else: frames = self.packet_get_nb_frames(data) - channel_count = self.CHANNELS samples_per_frame = self.packet_get_samples_per_frame(data) frame_size = frames * samples_per_frame - pcm = ( - ctypes.c_int16 - * (frame_size * channel_count * ctypes.sizeof(ctypes.c_int16)) - )() + # pcm = ( + # ctypes.c_int16 + # * (frame_size * channel_count * ctypes.sizeof(ctypes.c_int16)) + # )() + pcm = (ctypes.c_int16 * (frame_size * channel_count))() pcm_ptr = ctypes.cast(pcm, c_int16_ptr) + pcm_ptr = ctypes.cast( + pcm, + c_int16_ptr, + ) ret = _lib.opus_decode( self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec @@ -518,58 +573,144 @@ def decode(self, data, *, fec=False): return array.array("h", pcm[: ret * channel_count]).tobytes() -class DecodeManager(threading.Thread, _OpusStruct): - def __init__(self, client): - super().__init__(daemon=True, name="DecodeManager") +class PacketDecoder: + def __init__(self, router: PacketRouter, ssrc: int) -> None: + self.router: PacketRouter = router + self.ssrc: int = ssrc - self.client = client - self.decode_queue = [] + self._decoder: Decoder | None = None if self.sink.is_opus() else Decoder() + self._buffer: JitterBuffer = JitterBuffer() + self._cached_id: int | None = None - self.decoder = {} + self._last_seq: int = -1 + self._last_ts: int = -1 - self._end_thread = threading.Event() + @property + def sink(self) -> Sink: + return self.router.sink - def decode(self, opus_frame): - if not isinstance(opus_frame, RawData): - raise TypeError("opus_frame should be a RawData object.") - self.decode_queue.append(opus_frame) + def _get_user(self, user_id: int) -> User | Member | None: + vc: VoiceClient = self.sink.client # type: ignore + return vc.guild.get_member(user_id) or vc.client.get_user(user_id) - def run(self): - while not self._end_thread.is_set(): - try: - data = self.decode_queue.pop(0) - except IndexError: - time.sleep(0.001) - continue + def _get_cached_member(self) -> User | Member | None: + return self._get_user(self._cached_id) if self._cached_id else None - try: - if data.decrypted_data is None: - continue - else: - data.decoded_data = self.get_decoder(data.ssrc).decode( - data.decrypted_data + def _flag_ready_state(self) -> None: + if self._buffer.peek(): + self.router.waiter.register(self) + else: + self.router.waiter.unregister(self) + + def push_packet(self, packet: Packet) -> None: + self._buffer.push(packet) + self._flag_ready_state() + + def pop_data(self, *, timeout: float = 0) -> VoiceData | None: + packet = self._get_next_packet(timeout) + self._flag_ready_state() + + if packet is None: + return None + return self._process_packet(packet) + + def set_user_id(self, user_id: int) -> None: + self._cached_id = user_id + + def reset(self) -> None: + self._buffer.reset() + self._decoder = None if self.sink.is_opus() else Decoder() + self._last_seq = self._last_ts = -1 + self._flag_ready_state() + + def destroy(self) -> None: + self._buffer.reset() + self._decoder = None + self._flag_ready_state() + + def _get_next_packet(self, timeout: float) -> Packet | None: + packet = self._buffer.pop(timeout=timeout) + + if packet is None: + if self._buffer: + packets = self._buffer.flush() + if any(packets[1:]): + _log.warning( + "%s packets were lost being flushed in decoder-%s", + len(packets) - 1, + self.ssrc, ) - except OpusError: - print("Error occurred while decoding opus frame.") - continue - - self.client.recv_decoded_audio(data) - - def stop(self): - while self.decoding: - time.sleep(0.1) - self.decoder = {} - gc.collect() - print("Decoder Process Killed") - self._end_thread.set() - - def get_decoder(self, ssrc): - d = self.decoder.get(ssrc) - if d is not None: - return d - self.decoder[ssrc] = Decoder() - return self.decoder[ssrc] + return packets[0] + return + elif not packet: + packet = self._make_fakepacket() + return packet - @property - def decoding(self): - return bool(self.decode_queue) + def _make_fakepacket(self) -> FakePacket: + seq = add_wrapped(self._last_seq, 1) + ts = add_wrapped(self._last_ts, Decoder.SAMPLES_PER_FRAME, wrap=2**32) + return FakePacket(self.ssrc, seq, ts) + + def _process_packet(self, packet: Packet) -> VoiceData: + from discord.object import Object + from discord.voice import VoiceData + + assert self.sink.client + + pcm = None + + member = self._get_cached_member() + + if member is None: + self._cached_id = self.sink.client._ssrc_to_id.get(self.ssrc) + member = self._get_cached_member() + else: + self._cached_id = member.id + + # yet still none, use Object + if member is None and self._cached_id: + member = Object(id=self._cached_id) + + if not self.sink.is_opus(): + _log.debug("Decoding packet %s (type %s)", packet, type(packet)) + packet, pcm = self._decode_packet(packet) + + data = VoiceData(packet, member, pcm=pcm) # type: ignore + self._last_seq = packet.sequence + self._last_ts = packet.timestamp + return data + + def _decode_packet(self, packet: Packet) -> tuple[Packet, bytes]: + assert self._decoder is not None + assert self.sink.client + + user_id: int | None = self._cached_id + dave: davey.DaveSession | None = self.sink.client._connection.dave_session + in_dave = dave is not None + + # personally, the best variable + other_code = True + + if packet: + other_code = False + pcm = self._decoder.decode(packet.decrypted_data, fec=False) + + if other_code: + next_packet = self._buffer.peek_next() + + if next_packet is not None: + nextdata: bytes = next_packet.decrypted_data # type: ignore + + _log.debug( + "Generating fec packet: fake=%s, fec=%s", + packet.sequence, + next_packet.sequence, + ) + pcm = self._decoder.decode(nextdata, fec=True) + else: + pcm = self._decoder.decode(None, fec=False) + + if user_id is not None and in_dave and dave.can_passthrough(user_id): + pcm = dave.decrypt(user_id, davey.MediaType.audio, pcm) + + return packet, pcm diff --git a/discord/player.py b/discord/player.py index 65b23ed42a..8876f8df55 100644 --- a/discord/player.py +++ b/discord/player.py @@ -36,17 +36,21 @@ import sys import threading import time -import traceback +import warnings from math import floor from typing import IO, TYPE_CHECKING, Any, Callable, Generic, TypeVar +from .enums import SpeakingState from .errors import ClientException from .oggparse import OggStream +from .opus import OPUS_SILENCE from .opus import Encoder as OpusEncoder from .utils import MISSING if TYPE_CHECKING: - from .voice_client import VoiceClient + from typing_extensions import Self + + from .voice import VoiceClient AT = TypeVar("AT", bound="AudioSource") @@ -145,6 +149,8 @@ class FFmpegAudio(AudioSource): .. versionadded:: 1.3 """ + BLOCKSIZE: int = io.DEFAULT_BUFFER_SIZE + def __init__( self, source: str | io.BufferedIOBase, @@ -153,38 +159,69 @@ def __init__( args: Any, **subprocess_kwargs: Any, ): - piping = subprocess_kwargs.get("stdin") == subprocess.PIPE - if piping and isinstance(source, str): + piping_stdin = subprocess_kwargs.get("stdin") == subprocess.PIPE + if piping_stdin and isinstance(source, str): raise TypeError( - "parameter conflict: 'source' parameter cannot be a string when piping" - " to stdin" + "parameter conflict: 'source' parameter cannot be a string when piping to stdin" + ) + + stderr: IO[bytes] | None = subprocess_kwargs.pop("stderr", None) + + if stderr == subprocess.PIPE: + warnings.warn( + "Passing subprocess.PIPE does nothing", DeprecationWarning, stacklevel=3 ) + stderr = None + + piping_stderr = False + if stderr is not None: + try: + stderr.fileno() + except Exception: + piping_stderr = True args = [executable, *args] - kwargs = {"stdout": subprocess.PIPE} + kwargs = { + "stdout": subprocess.PIPE, + "stderr": subprocess.PIPE if piping_stderr else stderr, + } kwargs.update(subprocess_kwargs) - self._process: subprocess.Popen = self._spawn_process(args, **kwargs) - self._stdout: IO[bytes] = self._process.stdout # type: ignore + # Ensure attribute is assigned even in the case of errors + self._process: subprocess.Popen = MISSING + self._process = self._spawn_process(args, **kwargs) + self._stdout: IO[bytes] = self._process.stdout # type: ignore # process stdout is explicitly set self._stdin: IO[bytes] | None = None - self._pipe_thread: threading.Thread | None = None + self._stderr: IO[bytes] | None = None + self._pipe_writer_thread: threading.Thread | None = None + self._pipe_reader_thread: threading.Thread | None = None - if piping: - n = f"popen-stdin-writer:{id(self):#x}" + if piping_stdin: + n = f"popen-stdin-writer:pid-{self._process.pid}" self._stdin = self._process.stdin - self._pipe_thread = threading.Thread( + self._pipe_writer_thread = threading.Thread( target=self._pipe_writer, args=(source,), daemon=True, name=n ) - self._pipe_thread.start() + self._pipe_writer_thread.start() + + if piping_stderr: + n = f"popen-stderr-reader:pid-{self._process.pid}" + self._stderr = self._process.stderr + self._pipe_reader_thread = threading.Thread( + target=self._pipe_reader, args=(stderr,), daemon=True, name=n + ) + self._pipe_reader_thread.start() def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: + _log.debug("Spawning ffmpeg process with command: %s", args) + process = None try: process = subprocess.Popen( args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs ) except FileNotFoundError: executable = args.partition(" ")[0] if isinstance(args, str) else args[0] - raise ClientException(f"{executable} was not found.") from None + raise ClientException(executable + " was not found.") from None except subprocess.SubprocessError as exc: raise ClientException( f"Popen failed: {exc.__class__.__name__}: {exc}" @@ -193,11 +230,12 @@ def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Pope return process def _kill_process(self) -> None: - proc = self._process + # this function gets called in __del__ so instance attributes might not even exist + proc = getattr(self, "_process", MISSING) if proc is MISSING: return - _log.info("Preparing to terminate ffmpeg process %s.", proc.pid) + _log.debug("Preparing to terminate ffmpeg process %s.", proc.pid) try: proc.kill() @@ -226,13 +264,14 @@ def _kill_process(self) -> None: def _pipe_writer(self, source: io.BufferedIOBase) -> None: while self._process: - # arbitrarily large read size - data = source.read(8192) + data = source.read(self.BLOCKSIZE) if not data: - self._stdin.close() + if self._stdin is not None: + self._stdin.close() return try: - self._stdin.write(data) + if self._stdin is not None: + self._stdin.write(data) except Exception: _log.debug( "Write error for %s, this is probably not a problem", @@ -243,9 +282,31 @@ def _pipe_writer(self, source: io.BufferedIOBase) -> None: self._process.terminate() return + def _pipe_reader(self, dest: IO[bytes]) -> None: + while self._process: + if self._stderr is None: + return + try: + data: bytes = self._stderr.read(self.BLOCKSIZE) + except Exception: + _log.debug( + "Read error for %s, this is probably not a problem", + self, + exc_info=True, + ) + return + if data is None: + return + try: + dest.write(data) + except Exception: + _log.exception("Write error for %s", self) + self._stderr.close() + return + def cleanup(self) -> None: self._kill_process() - self._process = self._stdout = self._stdin = MISSING + self._process = self._stdout = self._stdin = self._stderr = MISSING class FFmpegPCMAudio(FFmpegAudio): @@ -266,12 +327,17 @@ class FFmpegPCMAudio(FFmpegAudio): passed to the stdin of ffmpeg. executable: :class:`str` The executable name (and path) to use. Defaults to ``ffmpeg``. + + .. warning:: + + Since this class spawns a subprocess, care should be taken to not + pass in an arbitrary executable name when using this parameter. + pipe: :class:`bool` If ``True``, denotes that ``source`` parameter will be passed to the stdin of ffmpeg. Defaults to ``False``. stderr: Optional[:term:`py:file object`] A file-like object to pass to the Popen constructor. - Could also be an instance of ``subprocess.PIPE``. before_options: Optional[:class:`str`] Extra command line arguments to pass to ffmpeg before the ``-i`` flag. options: Optional[:class:`str`] @@ -289,7 +355,7 @@ def __init__( *, executable: str = "ffmpeg", pipe: bool = False, - stderr: IO[str] | None = None, + stderr: IO[bytes] | None = None, before_options: str | None = None, options: str | None = None, ) -> None: @@ -304,7 +370,21 @@ def __init__( args.append("-i") args.append("-" if pipe else source) - args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning")) + + args.extend( + ( + "-f", + "s16le", + "-ar", + "48000", + "-ac", + "2", + "-loglevel", + "warning", + "-blocksize", + str(self.BLOCKSIZE), + ) + ) if isinstance(options, str): args.extend(shlex.split(options)) @@ -356,9 +436,8 @@ class FFmpegOpusAudio(FFmpegAudio): The codec to use to encode the audio data. Normally this would be just ``libopus``, but is used by :meth:`FFmpegOpusAudio.from_probe` to opportunistically skip pointlessly re-encoding Opus audio data by passing - ``copy`` as the codec value. Any values other than ``copy``, or - ``libopus`` will be considered ``libopus``. ``opus`` will also be considered - ``libopus`` since the ``opus`` encoder is still in development. Defaults to ``libopus``. + ``copy`` as the codec value. Any values other than ``copy``, ``opus``, or + ``libopus`` will be considered ``libopus``. Defaults to ``libopus``. .. warning:: @@ -373,7 +452,6 @@ class FFmpegOpusAudio(FFmpegAudio): to the stdin of ffmpeg. Defaults to ``False``. stderr: Optional[:term:`py:file object`] A file-like object to pass to the Popen constructor. - Could also be an instance of ``subprocess.PIPE``. before_options: Optional[:class:`str`] Extra command line arguments to pass to ffmpeg before the ``-i`` flag. options: Optional[:class:`str`] @@ -389,13 +467,13 @@ def __init__( self, source: str | io.BufferedIOBase, *, - bitrate: int = 128, + bitrate: int | None = None, codec: str | None = None, executable: str = "ffmpeg", - pipe=False, - stderr=None, - before_options=None, - options=None, + pipe: bool = False, + stderr: IO[bytes] | None = None, + before_options: str | None = None, + options: str | None = None, ) -> None: args = [] subprocess_kwargs = { @@ -409,9 +487,8 @@ def __init__( args.append("-i") args.append("-" if pipe else source) - # use "libopus" when "opus" is specified since the "opus" encoder is incomplete - # link to ffmpeg docs: https://www.ffmpeg.org/ffmpeg-codecs.html#opus - codec = "copy" if codec == "copy" else "libopus" + codec = "copy" if codec in ("opus", "libopus", "copy") else "libopus" + bitrate = bitrate if bitrate is not None else 128 args.extend( ( @@ -421,24 +498,23 @@ def __init__( "opus", "-c:a", codec, + "-ar", + "48000", + "-ac", + "2", + "-b:a", + f"{bitrate}k", "-loglevel", "warning", + "-fec", + "true", + "-packet_loss", + "15", + "-blocksize", + str(self.BLOCKSIZE), ) ) - # only pass in bitrate, sample rate, channels arguments when actually encoding to avoid ffmpeg warnings - if codec != "copy": - args.extend( - ( - "-ar", - "48000", - "-ac", - "2", - "-b:a", - f"{bitrate}k", - ) - ) - if isinstance(options, str): args.extend(shlex.split(options)) @@ -449,45 +525,19 @@ def __init__( @classmethod async def from_probe( - cls: type[FT], + cls, source: str, *, method: str | Callable[[str, str], tuple[str | None, int | None]] | None = None, **kwargs: Any, - ) -> FT: - """|coro| + ) -> Self: + r"""|coro| A factory method that creates a :class:`FFmpegOpusAudio` after probing the input source for audio codec and bitrate information. - Parameters - ---------- - source - Identical to the ``source`` parameter for the constructor. - method: Optional[Union[:class:`str`, Callable[:class:`str`, :class:`str`]]] - The probing method used to determine bitrate and codec information. As a string, valid - values are ``native`` to use ffprobe (or avprobe) and ``fallback`` to use ffmpeg - (or avconv). As a callable, it must take two string arguments, ``source`` and - ``executable``. Both parameters are the same values passed to this factory function. - ``executable`` will default to ``ffmpeg`` if not provided as a keyword argument. - kwargs - The remaining parameters to be passed to the :class:`FFmpegOpusAudio` constructor, - excluding ``bitrate`` and ``codec``. - - Returns - ------- - :class:`FFmpegOpusAudio` - An instance of this class. - - Raises - ------ - AttributeError - Invalid probe method, must be ``'native'`` or ``'fallback'``. - TypeError - Invalid value for ``probe`` parameter, must be :class:`str` or a callable. - Examples - -------- + ---------- Use this function to create an :class:`FFmpegOpusAudio` instance instead of the constructor: :: @@ -508,13 +558,37 @@ def custom_probe(source, executable): source = await discord.FFmpegOpusAudio.from_probe("song.webm", method=custom_probe) voice_client.play(source) + + Parameters + ------------ + source + Identical to the ``source`` parameter for the constructor. + method: Optional[Union[:class:`str`, Callable[:class:`str`, :class:`str`]]] + The probing method used to determine bitrate and codec information. As a string, valid + values are ``native`` to use ffprobe (or avprobe) and ``fallback`` to use ffmpeg + (or avconv). As a callable, it must take two string arguments, ``source`` and + ``executable``. Both parameters are the same values passed to this factory function. + ``executable`` will default to ``ffmpeg`` if not provided as a keyword argument. + \*\*kwargs + The remaining parameters to be passed to the :class:`FFmpegOpusAudio` constructor, + excluding ``bitrate`` and ``codec``. + + Raises + -------- + AttributeError + Invalid probe method, must be ``'native'`` or ``'fallback'``. + TypeError + Invalid value for ``probe`` parameter, must be :class:`str` or a callable. + + Returns + -------- + :class:`FFmpegOpusAudio` + An instance of this class. """ executable = kwargs.get("executable") codec, bitrate = await cls.probe(source, method=method, executable=executable) - # only re-encode if the source isn't already opus, else directly copy the source audio stream - codec = "copy" if codec in ("opus", "libopus") else "libopus" - return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore + return cls(source, bitrate=bitrate, codec=codec, **kwargs) @classmethod async def probe( @@ -539,7 +613,7 @@ async def probe( Returns ------- - Optional[Tuple[Optional[:class:`str`], Optional[:class:`int`]]] + Optional[Tuple[Optional[:class:`str`], :class:`int`]] A 2-tuple with the codec and bitrate of the input source. Raises @@ -572,38 +646,45 @@ async def probe( ) codec = bitrate = None - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: - codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore - except Exception: + codec, bitrate = await loop.run_in_executor( + None, lambda: probefunc(source, executable) + ) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException: if not fallback: _log.exception("Probe '%s' using '%s' failed", method, executable) - return # type: ignore + return None, None _log.exception( "Probe '%s' using '%s' failed, trying fallback", method, executable ) try: - codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore - except Exception: + codec, bitrate = await loop.run_in_executor( + None, lambda: fallback(source, executable) + ) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException: _log.exception("Fallback probe using '%s' failed", executable) else: - _log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate) + _log.debug("Fallback probe found codec=%s, bitrate=%s", codec, bitrate) else: - _log.info("Probe found codec=%s, bitrate=%s", codec, bitrate) - finally: - return codec, bitrate + _log.debug("Probe found codec=%s, bitrate=%s", codec, bitrate) + + return codec, bitrate @staticmethod def _probe_codec_native( source, executable: str = "ffmpeg" ) -> tuple[str | None, int | None]: exe = ( - f"{executable[:2]}probe" - if executable in {"ffmpeg", "avconv"} + executable[:2] + "probe" + if executable in ("ffmpeg", "avconv") else executable ) - args = [ exe, "-v", @@ -631,7 +712,7 @@ def _probe_codec_native( @staticmethod def _probe_codec_fallback( source, executable: str = "ffmpeg" - ) -> tuple[str | None, int | None]: + ) -> Tuple[Optional[str], Optional[int]]: args = [executable, "-hide_banner", "-i", source] proc = subprocess.Popen( args, @@ -721,9 +802,14 @@ def read(self) -> bytes: class AudioPlayer(threading.Thread): DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 - def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): - threading.Thread.__init__(self) - self.daemon: bool = True + def __init__( + self, + source: AudioSource, + client: VoiceClient, + *, + after: Callable[[Exception | None], Any] | None = None, + ) -> None: + super().__init__(daemon=True, name=f"audio-player:{id(self):#x}") self.source: AudioSource = source self.client: VoiceClient = client self.after: Callable[[Exception | None], Any] | None = after @@ -732,7 +818,6 @@ def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): self._resumed: threading.Event = threading.Event() self._resumed.set() # we are not paused self._current_error: Exception | None = None - self._connected: threading.Event = client._connected self._lock: threading.Lock = threading.Lock() self._played_frames_offset: int = 0 @@ -742,49 +827,52 @@ def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): def _do_run(self) -> None: # attempt to read first audio segment from source before starting # some sources can take a few seconds and may cause problems - first_data = self.source.read() self.loops = 0 self._start = time.perf_counter() # getattr lookup speed ups - play_audio = self.client.send_audio_packet - self._speak(True) + client = self.client + play_audio = client.send_audio_packet + self._speak(SpeakingState.voice) while not self._end.is_set(): # are we paused? if not self._resumed.is_set(): + self.send_silence() # wait until we aren't self._resumed.wait() continue + data = self.source.read() + + if not data: + self.stop() + break + # are we disconnected from voice? - if not self._connected.is_set(): - # wait until we are connected - self._connected.wait() + if not client.is_connected(): + _log.debug("Not connected, waiting for %ss...", client.timeout) + # wait until we are connected, but not forever + connected = client.wait_until_connected(client.timeout) + if self._end.is_set() or not connected: + _log.debug("Aborting playback") + return + _log.debug("Reconnected, resuming playback") + self._speak(SpeakingState.voice) # reset our internal data self._played_frames_offset += self.loops self.loops = 0 self._start = time.perf_counter() - self.loops += 1 - - # Send the data read from the start of the function if it is not None - if first_data is not None: - data = first_data - first_data = None - # Else read the next bit from the source - else: - data = self.source.read() - - if not data: - self.stop() - break - play_audio(data, encode=not self.source.is_opus()) + self.loops += 1 next_time = self._start + self.DELAY * self.loops delay = max(0, self.DELAY + (next_time - time.perf_counter())) time.sleep(delay) + if client.is_connected(): + self.send_silence() + def run(self) -> None: try: self._do_run() @@ -792,8 +880,8 @@ def run(self) -> None: self._current_error = exc self.stop() finally: - self.source.cleanup() self._call_after() + self.source.cleanup() def _call_after(self) -> None: error = self._current_error @@ -802,32 +890,27 @@ def _call_after(self) -> None: try: self.after(error) except Exception as exc: - _log.exception("Calling the after function failed.") exc.__context__ = error - traceback.print_exception(type(exc), exc, exc.__traceback__) + _log.exception("Calling the after function failed.", exc_info=exc) elif error: - msg = f"Exception in voice thread {self.name}" - _log.exception(msg, exc_info=error) - print(msg, file=sys.stderr) - traceback.print_exception(type(error), error, error.__traceback__) + _log.exception("Exception in voice thread %s", self.name, exc_info=error) def stop(self) -> None: self._end.set() self._resumed.set() - self._speak(False) + self._speak(SpeakingState.none) def pause(self, *, update_speaking: bool = True) -> None: self._resumed.clear() if update_speaking: - self._speak(False) + self._speak(SpeakingState.none) def resume(self, *, update_speaking: bool = True) -> None: - self._played_frames_offset += self.loops self.loops = 0 self._start = time.perf_counter() self._resumed.set() if update_speaking: - self._speak(True) + self._speak(SpeakingState.voice) def is_playing(self) -> bool: return self._resumed.is_set() and not self._end.is_set() @@ -835,20 +918,26 @@ def is_playing(self) -> bool: def is_paused(self) -> bool: return not self._end.is_set() and not self._resumed.is_set() - def _set_source(self, source: AudioSource) -> None: + def set_source(self, source: AudioSource) -> None: with self._lock: self.pause(update_speaking=False) self.source = source self.resume(update_speaking=False) - def _speak(self, speaking: bool) -> None: + def _speak(self, speaking: SpeakingState) -> None: try: asyncio.run_coroutine_threadsafe( - self.client.ws.speak(speaking), self.client.loop + self.client.ws.speak(speaking), self.client.client.loop ) - except Exception as e: - _log.info("Speaking call in player failed: %s", e) + except Exception as exc: + _log.exception("Speaking call in player failed", exc_info=exc) + + def send_silence(self, count: int = 5) -> None: + try: + for n in range(count): + self.client.send_audio_packet(OPUS_SILENCE, encode=False) + except Exception: + pass def played_frames(self) -> int: - """Gets the number of 20ms frames played since the start of the audio file.""" return self._played_frames_offset + self.loops diff --git a/discord/raw_models.py b/discord/raw_models.py index 86635b90e8..f6f5d7962a 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -26,27 +26,27 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING +from collections.abc import ItemsView, KeysView, ValuesView +from typing import TYPE_CHECKING, Any +from . import utils from .automod import AutoModAction, AutoModTriggerType from .enums import ( AuditLogAction, ChannelType, ReactionType, - VoiceChannelEffectAnimationType, try_enum, ) if TYPE_CHECKING: - from .abc import MessageableChannel + from .abc import GuildChannel, MessageableChannel from .guild import Guild from .member import Member from .message import Message from .partial_emoji import PartialEmoji - from .soundboard import PartialSoundboardSound, SoundboardSound + from .soundboard import PartialSoundboardSound from .state import ConnectionState from .threads import Thread - from .types.channel import VoiceChannelEffectSendEvent as VoiceChannelEffectSend from .types.raw_models import ( AuditLogEntryEvent, ) @@ -67,6 +67,8 @@ ThreadUpdateEvent, TypingEvent, VoiceChannelStatusUpdateEvent, + VoiceServerUpdateEvent, + VoiceStateEvent, ) from .user import User @@ -90,12 +92,20 @@ "RawVoiceChannelStatusUpdateEvent", "RawMessagePollVoteEvent", "RawSoundboardSoundDeleteEvent", + "RawVoiceServerUpdateEvent", + "RawVoiceStateUpdateEvent", ) class _RawReprMixin: + __slots__: tuple[str, ...] + def __repr__(self) -> str: - value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__) + value = " ".join( + f"{attr}={getattr(self, attr)!r}" + for attr in self.__slots__ + if not attr.startswith("_") + ) return f"<{self.__class__.__name__} {value}>" @@ -124,10 +134,7 @@ def __init__(self, data: MessageDeleteEvent) -> None: self.message_id: int = int(data["id"]) self.channel_id: int = int(data["channel_id"]) self.cached_message: Message | None = None - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: MessageDeleteEvent = data @@ -156,11 +163,7 @@ def __init__(self, data: BulkMessageDeleteEvent) -> None: self.message_ids: set[int] = {int(x) for x in data.get("ids", [])} self.channel_id: int = int(data["channel_id"]) self.cached_messages: list[Message] = [] - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: BulkMessageDeleteEvent = data @@ -200,11 +203,7 @@ def __init__(self, data: MessageUpdateEvent, new_message: Message) -> None: self.data: MessageUpdateEvent = data self.cached_message: Message | None = None self.new_message: Message = new_message - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") class RawReactionActionEvent(_RawReprMixin): @@ -277,11 +276,7 @@ def __init__( self.burst_colours: list = data.get("burst_colors", []) self.burst_colors: list = self.burst_colours self.type: ReactionType = try_enum(ReactionType, data.get("type", 0)) - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: ReactionActionEvent = data @@ -307,11 +302,7 @@ class RawReactionClearEvent(_RawReprMixin): def __init__(self, data: ReactionClearEvent) -> None: self.message_id: int = int(data["message_id"]) self.channel_id: int = int(data["channel_id"]) - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: ReactionClearEvent = data @@ -364,11 +355,7 @@ def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None: self.burst_colours: list = data.get("burst_colors", []) self.burst_colors: list = self.burst_colours self.type: ReactionType = try_enum(ReactionType, data.get("type", 0)) - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: ReactionClearEmojiEvent = data @@ -396,11 +383,9 @@ class RawIntegrationDeleteEvent(_RawReprMixin): def __init__(self, data: IntegrationDeleteEvent) -> None: self.integration_id: int = int(data["id"]) self.guild_id: int = int(data["guild_id"]) - - try: - self.application_id: int | None = int(data["application_id"]) - except KeyError: - self.application_id: int | None = None + self.application_id: int | None = utils._get_as_snowflake( + data, "application_id" + ) self.data: IntegrationDeleteEvent = data @@ -463,10 +448,10 @@ class RawThreadDeleteEvent(_RawReprMixin): __slots__ = ("thread_id", "thread_type", "guild_id", "parent_id", "thread", "data") def __init__(self, data: ThreadDeleteEvent) -> None: - self.thread_id: int = int(data["id"]) - self.thread_type: ChannelType = try_enum(ChannelType, int(data["type"])) - self.guild_id: int = int(data["guild_id"]) - self.parent_id: int = int(data["parent_id"]) + self.thread_id: int = int(data["id"]) # type: ignore + self.thread_type: ChannelType = try_enum(ChannelType, int(data["type"])) # type: ignore + self.guild_id: int = int(data["guild_id"]) # type: ignore + self.parent_id: int = int(data["parent_id"]) # type: ignore self.thread: Thread | None = None self.data: ThreadDeleteEvent = data @@ -493,11 +478,7 @@ class RawVoiceChannelStatusUpdateEvent(_RawReprMixin): def __init__(self, data: VoiceChannelStatusUpdateEvent) -> None: self.id: int = int(data["id"]) self.guild_id: int = int(data["guild_id"]) - - try: - self.status: str | None = data["status"] - except KeyError: - self.status: str | None = None + self.status: str | None = data.get("status") self.data: VoiceChannelStatusUpdateEvent = data @@ -533,11 +514,7 @@ def __init__(self, data: TypingEvent) -> None: data.get("timestamp"), tz=datetime.timezone.utc ) self.member: Member | None = None - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: TypingEvent = data @@ -592,8 +569,8 @@ class RawScheduledEventSubscription(_RawReprMixin): __slots__ = ("event_id", "guild", "user_id", "event_type", "data") def __init__(self, data: ScheduledEventSubscription, event_type: str): - self.event_id: int = int(data["guild_scheduled_event_id"]) - self.user_id: int = int(data["user_id"]) + self.event_id: int = int(data["guild_scheduled_event_id"]) # type: ignore + self.user_id: int = int(data["user_id"]) # type: ignore self.guild: Guild | None = None self.event_type: str = event_type self.data: ScheduledEventSubscription = data @@ -679,42 +656,23 @@ def __init__(self, state: ConnectionState, data: AutoModActionExecution) -> None self.guild: Guild | None = state._get_guild(self.guild_id) self.user_id: int = int(data["user_id"]) self.content: str | None = data.get("content", None) - self.matched_keyword: str = data["matched_keyword"] + self.matched_keyword: str = data["matched_keyword"] # type: ignore self.matched_content: str | None = data.get("matched_content", None) - - if self.guild: - self.member: Member | None = self.guild.get_member(self.user_id) - else: - self.member: Member | None = None - - try: - # I don't see why this would be optional, but it's documented - # as such, so we should treat it that way - self.channel_id: int | None = int(data["channel_id"]) - self.channel: MessageableChannel | None = self.guild.get_channel_or_thread( - self.channel_id - ) - except KeyError: - self.channel_id: int | None = None - self.channel: MessageableChannel | None = None - - try: - self.message_id: int | None = int(data["message_id"]) - self.message: Message | None = state._get_message(self.message_id) - except KeyError: - self.message_id: int | None = None - self.message: Message | None = None - - try: - self.alert_system_message_id: int | None = int( - data["alert_system_message_id"] - ) - self.alert_system_message: Message | None = state._get_message( - self.alert_system_message_id - ) - except KeyError: - self.alert_system_message_id: int | None = None - self.alert_system_message: Message | None = None + self.channel_id: int | None = utils._get_as_snowflake(data, "channel_id") + self.channel: MessageableChannel | None = ( + self.channel_id + and self.guild + and self.guild.get_channel_or_thread(self.channel_id) + ) # type: ignore + self.member: Member | None = self.guild and self.guild.get_member(self.user_id) + self.message_id: int | None = utils._get_as_snowflake(data, "message_id") + self.message: Message | None = state._get_message(self.message_id) + self.alert_system_message_id: int | None = utils._get_as_snowflake( + data, "alert_system_message_id" + ) + self.alert_system_message: Message | None = state._get_message( + self.alert_system_message_id + ) self.data: AutoModActionExecution = data def __repr__(self) -> str: @@ -851,11 +809,192 @@ def __init__(self, data: MessagePollVoteEvent, added: bool) -> None: self.answer_id: int = int(data["answer_id"]) self.data: MessagePollVoteEvent = data self.added: bool = added + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") + + +# this is for backwards compatibility because VoiceProtocol.on_voice_..._update +# passed the raw payload instead of a raw object. Emit deprecation warning. +class _PayloadLike(_RawReprMixin): + _raw_data: dict[str, Any] + + @utils.deprecated( + "the attributes", + "2.7", + "3.0", + ) + def __getitem__(self, key: str) -> Any: + return self._raw_data[key] + + @utils.deprecated( + "the attributes", + "2.7", + "3.0", + ) + def get(self, key: str, default: Any = None) -> Any: + """Gets an item from this raw event, and returns its value or ``default``. + + .. deprecated:: 2.7 + Use the attributes instead. + """ + return self._raw_data.get(key, default) + + @utils.deprecated( + "the attributes", + "2.7", + "3.0", + ) + def items(self) -> ItemsView: + """Returns the (key, value) pairs of this raw event. + + .. deprecated:: 2.7 + Use the attributes instead. + """ + return self._raw_data.items() + + @utils.deprecated( + "the attributes", + "2.7", + "3.0", + ) + def values(self) -> ValuesView: + """Returns the values of this raw event. + + .. deprecated:: 2.7 + Use the attributes instead. + """ + return self._raw_data.values() + + @utils.deprecated( + "the attributes", + "2.7", + "3.0", + ) + def keys(self) -> KeysView: + """Returns the keys of this raw event. + + .. deprecated:: 2.7 + Use the attributes instead. + """ + return self._raw_data.keys() + + +class RawVoiceStateUpdateEvent(_PayloadLike): + """Represents the payload for a :meth:`VoiceProtocol.on_voice_state_update` event. + + .. versionadded:: 2.7 + + Attributes + ---------- + deaf: :class:`bool` + Whether the user is guild deafened. + mute: :class:`bool` + Whether the user is guild muted. + self_mute: :class:`bool` + Whether the user has muted themselves by their own accord. + self_deaf: :class:`bool` + Whether the user has deafened themselves by their own accord. + self_stream: :class:`bool` + Whether the user is currently streaming via the 'Go Live' feature. + self_video: :class:`bool` + Whether the user is currently broadcasting video. + suppress: :class:`bool` + Whether the user is suppressed from speaking in a stage channel. + requested_to_speak_at: Optional[:class:`datetime.datetime`] + An aware datetime object that specifies when a member has requested to speak + in a stage channel. It will be ``None`` if they are not requesting to speak + anymore or have been accepted to. + afk: :class:`bool` + Whether the user is connected on the guild's AFK channel. + channel: Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]] + The voice channel that the user is currently connected to. ``None`` if the user + is not currently in a voice channel. + + There are certain scenarios in which this is impossible to be ``None``. + session_id: :class:`str` + The voice connection session ID. + guild_id: Optional[:class:`int`] + The guild ID the user channel is from. + channel_id: Optional[:class:`int`] + The channel ID the user is connected to. Or ``None`` if not connected to any. + """ + + __slots__ = ( + "session_id", + "mute", + "deaf", + "self_mute", + "self_deaf", + "self_stream", + "self_video", + "suppress", + "requested_to_speak_at", + "afk", + "guild_id", + "channel_id", + "_state", + "_raw_data", + ) + + def __init__(self, *, data: VoiceStateEvent, state: ConnectionState) -> None: + self.session_id: str = data["session_id"] + self._state: ConnectionState = state + + self.self_mute: bool = data.get("self_mute", False) + self.self_deaf: bool = data.get("self_deaf", False) + self.mute: bool = data.get("mute", False) + self.deaf: bool = data.get("deaf", False) + self.suppress: bool = data.get("suppress", False) + self.requested_to_speak_at: datetime.datetime | None = utils.parse_time( + data.get("request_to_speak_timestamp") + ) + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") + self.channel_id: int | None = utils._get_as_snowflake(data, "channel_id") + self._raw_data: VoiceStateEvent = data + + @property + def guild(self) -> Guild | None: + """Returns the guild channel the user is connected to, or ``None``.""" + return self._state._get_guild(self.guild_id) + + @property + def channel(self) -> GuildChannel | None: + """Returns the channel the user is connected to, or ``None``.""" + return self._state.get_channel(self.channel_id) # type: ignore + + +class RawVoiceServerUpdateEvent(_PayloadLike): + """Represents the payload for a :meth:`VoiceProtocol.on_voice_server_update` event. + + .. versionadded:: 2.7 + + Attributes + ---------- + token: :class:`str` + The voice connection token. This should not be shared. + guild_id: :class:`int` + The guild ID this token is part from. + endpoint: Optional[:class:`str`] + The voice server host to connect to. + """ + + __slots__ = ( + "token", + "guild_id", + "endpoint", + "_raw_data", + "_state", + ) + + def __init__(self, *, data: VoiceServerUpdateEvent, state: ConnectionState) -> None: + self._state: ConnectionState = state + self.guild_id: int = int(data["guild_id"]) + self.token: str = data["token"] + self.endpoint: str | None = data["endpoint"] - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + @property + def guild(self) -> Guild | None: + """Returns the guild this server update is from.""" + return self._state._get_guild(self.guild_id) class RawSoundboardSoundDeleteEvent(_RawReprMixin): diff --git a/discord/sinks/core.py b/discord/sinks/core.py index 90fdbf7d19..ccd73c4d59 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -25,25 +25,39 @@ from __future__ import annotations -import io -import os -import struct +import inspect +import logging +import shlex +import subprocess import sys import threading -import time -from typing import TYPE_CHECKING +from collections.abc import Callable, Generator, Sequence +from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, overload -from ..types import snowflake -from .errors import SinkException +from discord.file import File +from discord.player import FFmpegAudio +from discord.utils import MISSING, SequenceProxy + +from .errors import FFmpegNotFound if TYPE_CHECKING: - from ..voice_client import VoiceClient + from typing_extensions import ParamSpec, Self + + from discord.member import Member + from discord.user import User + from discord.voice.packets import VoiceData + + from ..voice.client import VoiceClient + + R = TypeVar("R") + P = ParamSpec("P") __all__ = ( - "Filters", "Sink", - "AudioData", "RawData", + "FFmpegSink", + "FilterSink", + "MultiSink", ) @@ -53,194 +67,589 @@ CREATE_NO_WINDOW = 0x08000000 -default_filters = { - "time": 0, - "users": [], - "max_size": 0, -} +S = TypeVar("S", bound="Sink") +_log = logging.getLogger(__name__) -class Filters: - """Filters for :class:`~.Sink` +class SinkMeta(type): + __sink_listeners__: list[tuple[str, str]] - .. versionadded:: 2.0 + def __new__(cls, name, bases, attr, **kwargs): + listeners = {} - Parameters - ---------- - container - Container of all Filters. - """ + inst = super().__new__(cls, name, bases, attr, **kwargs) - def __init__(self, **kwargs): - self.filtered_users = kwargs.get("users", default_filters["users"]) - self.seconds = kwargs.get("time", default_filters["time"]) - self.max_size = kwargs.get("max_size", default_filters["max_size"]) - self.finished = False + for base in reversed(inst.__mro__): + for elem, value in base.__dict__.items(): + if elem in listeners: + del listeners[elem] - @staticmethod - def container(func): # Contains all filters - def _filter(self, data, user): - if not self.filtered_users or user in self.filtered_users: - return func(self, data, user) + is_static = isinstance(value, staticmethod) + if is_static: + value = value.__func__ - return _filter + if not hasattr(value, "__sink_listener__"): + continue - def init(self): - if self.seconds != 0: - thread = threading.Thread(target=self.wait_and_stop) - thread.start() + listeners[elem] = value - def wait_and_stop(self): - time.sleep(self.seconds) - if self.finished: - return - self.vc.stop_recording() + listeners_list = [] + for listener in listeners.values(): + for listener_name in listener.__sink_listener_names__: + listeners_list.append((listener_name, listener.__name__)) + + inst.__sink_listeners__ = listeners_list + return inst + + +class SinkBase(metaclass=SinkMeta): + """Represents an audio sink in which user's audios are stored.""" + + __sink_listeners__: list[tuple[str, str]] + _client: VoiceClient | None + + @property + def root(self) -> Sink: + """Returns the root parent of this sink.""" + return self # type: ignore + + @property + def parent(self) -> Sink | None: + """Returns the parent of this sink.""" + raise NotImplementedError + + @property + def child(self) -> Sink | None: + """Returns this sink's child.""" + raise NotImplementedError + + @property + def children(self) -> Sequence[Sink]: + """Returns the full list of children of this sink.""" + raise NotImplementedError + + @property + def client(self) -> VoiceClient | None: + """Returns the voice client this sink is connected to.""" + return self._client + + def is_opus(self) -> bool: + """Returns whether this sink is opus.""" + return False + + def write(self, user: User | Member | None, data: VoiceData) -> None: + """Writes the provided ``data`` into the ``user`` map.""" + raise NotImplementedError + + def cleanup(self) -> None: + """Cleans this sink.""" + raise NotImplementedError + + def _register_child(self, child: Sink) -> None: + """Registers a child to this sink.""" + raise NotImplementedError + + def walk_children(self, *, with_self: bool = False) -> Generator[Sink]: + """Iterates through all the children of this sink, including nested.""" + if with_self: + yield self # type: ignore + + for child in self.children: + yield child + yield from child.walk_children() + def __del__(self) -> None: + self.cleanup() -class RawData: - """Handles raw data from Discord so that it can be decrypted and decoded to be used. + +class Sink(SinkBase): + """Object that stores the recordings of the audio data. + + Can be subclassed for extra customizability. .. versionadded:: 2.0 """ - def __init__(self, data, client): - self.data = bytearray(data) - self.client = client - - unpacker = struct.Struct(">xxHII") - self.sequence, self.timestamp, self.ssrc = unpacker.unpack_from(self.data[:12]) + _parent: Sink | None = None + _child: Sink | None = None + _client = None - # RFC3550 5.1: RTP Fixed Header Fields - if self.client.mode.endswith("_rtpsize"): - # If It Has CSRC Chunks - cutoff = 12 + (data[0] & 0b00_0_0_1111) * 4 - # If It Has A Extension - if data[0] & 0b00_0_1_0000: - cutoff += 4 + def __init__(self, *, dest: Sink | None = None) -> None: + if dest is not None: + self._register_child(dest) else: - cutoff = 12 + self._child = dest + + def _register_child(self, child: Sink) -> None: + if child in self.root.walk_children(): + raise RuntimeError("Sink is already registered") + self._child = child + child._parent = self + + @property + def root(self) -> Sink: + if self.parent is None: + return self + return self.parent + + @property + def parent(self) -> Sink | None: + return self._parent + + @property + def child(self) -> Sink | None: + return self._child + + @property + def children(self) -> Sequence[Sink]: + return [self._child] if self._child else [] + + @property + def client(self) -> VoiceClient | None: + if self.parent is not None: + return self.parent.client + else: + return self._client - self.header = data[:cutoff] - self.data = self.data[cutoff:] + @classmethod + def listener(cls, name: str = MISSING): + """Registers a sink method as a listener. - self.decrypted_data = getattr(self.client, f"_decrypt_{self.client.mode}")( - self.header, self.data - ) - self.decoded_data = None + You can stack this decorator and pass the ``name`` parameter to mark the same function + to listen to various events. + + Parameters + ---------- + name: :class:`str` + The name of the event, must not be prefixed with ``on_``. Defaults to the function name. + """ - self.user_id = None - self.receive_time = time.perf_counter() + if name is not MISSING and not isinstance(name, str): + raise TypeError( + f"expected a str for listener name, got {name.__class__.__name__} instead" + ) + def decorator(func): + actual = func -class AudioData: - """Handles data that's been completely decrypted and decoded and is ready to be saved to file. + if isinstance(actual, staticmethod): + actual = actual.__func__ - .. versionadded:: 2.0 + if inspect.iscoroutinefunction(actual): + raise TypeError("listener functions must not be coroutines") + + actual.__sink_listener__ = True + to_assign = name or actual.__name__.removeprefix("on_") + + try: + actual.__sink_listener_names__.append(to_assign) + except AttributeError: + actual.__sink_listener_names__ = [to_assign] + + return func + + return decorator + + +class MultiSink(Sink): + """A sink that can handle multiple sinks concurrently. + + .. versionadded:: 2.7 """ - def __init__(self, file): - self.file = file - self.finished = False + def __init__(self, *destinations: Sink) -> None: + for dest in destinations: + self._register_child(dest) + self._children: list[Sink] = list(destinations) + + def _register_child(self, child: Sink) -> None: + if child in self.root.walk_children(): + raise RuntimeError("Sink is already registered") + child._parent = self + + @property + def child(self) -> Sink | None: + return self._children[0] if self._children else None + + @property + def children(self) -> Sequence[Sink]: + return SequenceProxy(self._children) - def write(self, data): - """Writes audio data. + def add_destination(self, dest: Sink, /) -> None: + """Adds a sink to be dispatched in this sink. + + Parameters + ---------- + dest: :class:`Sink` + The sink to register as this one's child. Raises ------ - ClientException - The AudioData is already finished writing. + RuntimeError + The sink is already registered. + """ + self._register_child(dest) + + def remove_destination(self, dest: Sink, /) -> None: + """Removes a sink from this sink dispatch. + + Parameters + ---------- + dest: :class:`Sink` + The sink to remove. """ - if self.finished: - raise SinkException("The AudioData is already finished writing.") + try: - self.file.write(data) + self._children.remove(dest) except ValueError: pass + else: + dest._parent = None - def cleanup(self): - """Finishes and cleans up the audio data. - Raises - ------ - ClientException - The AudioData is already finished writing. - """ - if self.finished: - raise SinkException("The AudioData is already finished writing.") - self.file.seek(0) - self.finished = True +if TYPE_CHECKING: + from typing_extensions import deprecated - def on_format(self, encoding): - """Called when audio data is formatted. + @deprecated( + "RawData has been deprecated and will be removed in 3.0 in favour of VoiceData", + category=DeprecationWarning, + ) + def RawData(**kwargs: Any) -> Any: + """Deprecated since version 2.7, use :class:`VoiceData` instead.""" + +else: + + class RawData: + def __init__(self, **kwargs: Any) -> None: + raise DeprecationWarning( + "RawData has been deprecated in favour of VoiceData" + ) - Raises - ------ - ClientException - The AudioData is still writing. - """ - if not self.finished: - raise SinkException("The AudioData is still writing.") +class FFmpegSink(Sink): + """A :class:`Sink` built to use ffmpeg executables. -class Sink(Filters): - """A sink "stores" recorded audio data. + You can find default implementations of this sink in: - Can be subclassed for extra customizablilty. + - :class:`M4ASink` + - :class:`MKASink` - .. warning:: - It is recommended you use - the officially provided sink classes, - such as :class:`~discord.sinks.WaveSink`. + .. versionadded:: 2.7 - just replace the following like so: :: + Parameters + ---------- + filename: :class:`str` + The file in which the ffmpeg buffer should be saved to. + Can not be mixed with ``buffer``. + buffer: IO[:class:`bytes`] + The buffer in which the ffmpeg result would be written to. + Can not be mixed with ``filename``. + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in whcih will be written. Defaults to ``None``. + before_options: :class:`str` | :data:`None` + The options to append **before** the default ones. + options: :class:`str` | :data:`None` + The options to append **after** the default ones. You can override the + default ones with this. + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error ocurrs with this sink. + """ - vc.start_recording( - MySubClassedSink(), - finished_callback, - ctx.channel, + @overload + def __init__( + self, + *, + filename: str, + executable: str = ..., + stderr: IO[bytes] = ..., + before_options: str | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + @overload + def __init__( + self, + *, + buffer: IO[bytes], + executable: str = ..., + stderr: IO[bytes] = ..., + before_options: str | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + def __init__( + self, + *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, + executable: str = "ffmpeg", + stderr: IO[bytes] | None = None, + before_options: str | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__(dest=None) + + if filename is not MISSING and buffer is not MISSING: + raise TypeError("can't mix filename and buffer parameters") + + self.filename: str = filename or "pipe:1" + self.buffer: IO[bytes] = buffer + + self.on_error = error_hook or self._on_error + + args = [executable, "-hide_banner"] + subprocess_kwargs: dict[str, Any] = {"stdin": subprocess.PIPE} + if self.buffer is not MISSING: + subprocess_kwargs["stdout"] = subprocess.PIPE + + piping_stderr = False + if stderr is not None: + try: + stderr.fileno() + except Exception: + piping_stderr = True + subprocess_kwargs["stderr"] = subprocess.PIPE + + if isinstance(before_options, str): + args.extend(shlex.split(before_options)) + + args.extend( + { + "-f": "s16le", + "-ar": "48000", + "-ac": "2", + "-i": "pipe:0", + "-loglevel": "warning", + "-blocksize": str(FFmpegAudio.BLOCKSIZE), + } ) - .. versionadded:: 2.0 + if isinstance(options, str): + args.extend(shlex.split(options)) + + args.append(self.filename) + + self._process: subprocess.Popen = MISSING + self._process = self._spawn_process(args, **subprocess_kwargs) + + self._stdin: IO[bytes] = self._process.stdin # type: ignore + self._stdout: IO[bytes] | None = None + self._stderr: IO[bytes] | None = None + self._stdout_reader_thread: threading.Thread | None = None + self._stderr_reader_thread: threading.Thread | None = None + + if self.buffer: + n = f"popen-stdout-reader:pid-{self._process.pid}" + self._stdout = self._process.stdout + _args = (self._stdout, self.buffer) + self._stdout_reader_thread = threading.Thread( + target=self._pipe_reader, args=_args, daemon=True, name=n + ) + self._stdout_reader_thread.start() + + if piping_stderr: + n = f"popen-stderr-reader:pid-{self._process.pid}" + self._stderr = self._process.stderr + _args = (self._stderr, stderr) + self._stderr_reader_thread = threading.Thread( + target=self._pipe_reader, args=_args, daemon=True, name=n + ) + self._stderr_reader_thread.start() + + @staticmethod + def _on_error(_self: FFmpegSink, error: Exception, data: VoiceData | None) -> None: + _self.client.stop_recording() # type: ignore + + def is_opus(self) -> bool: + return False + + def cleanup(self) -> None: + self._kill_processes() + self._process = self._stdout = self._stdin = self._stderr = MISSING + + def write(self, user: User | Member | None, data: VoiceData) -> None: + if self._process and not self._stdin.closed: + audio = data.opus if self.is_opus() else data.pcm + assert audio is not None - Raises - ------ - ClientException - An invalid encoding type was specified. - ClientException - Audio may only be formatted after recording is finished. + try: + self._stdin.write(audio) + except Exception as exc: + _log.exception("Error while writing audio data to stdin ffmpeg") + self._kill_processes() + self.on_error(self, exc, data) + + def to_file( + self, filename: str, /, *, description: str | None = None, spoiler: bool = False + ) -> File | None: + """Returns the :class:`discord.File` of this sink. + + This is only applicable if this sink uses a ``buffer`` instead of a ``filename``. + + .. warning:: + + This should be used only after the sink has stopped recording. + """ + if self.buffer is not MISSING: + fp = File( + self.buffer.read(), + filename=filename, + description=description, + spoiler=spoiler, + ) + return fp + return None + + def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: + _log.debug( + "Spawning ffmpeg process with command %s and kwargs %s", + args, + subprocess_kwargs, + ) + process = None + + try: + process = subprocess.Popen( + args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs + ) + except FileNotFoundError: + executable = args.partition(" ")[0] if isinstance(args, str) else args[0] + raise FFmpegNotFound(f"{executable!r} executable was not found") from None + except subprocess.SubprocessError as exc: + raise Exception(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc + else: + return process + + def _kill_processes(self) -> None: + proc: subprocess.Popen = getattr(self, "_process", MISSING) + + if proc is MISSING: + return + + _log.debug("Terminating ffmpeg process %s", proc.pid) + + try: + self._stdin.close() + except Exception: + pass + + _log.debug("Waiting for ffmpeg process %s", proc.pid) + + try: + proc.wait(5) + except Exception: + pass + + try: + proc.kill() + except Exception as exc: + _log.exception( + "Ignoring exception while killing Popen process %s", + proc.pid, + exc_info=exc, + ) + + if proc.poll() is None: + _log.info( + "ffmpeg process %s has not terminated. Waiting to terminate...", + proc.pid, + ) + proc.communicate() + _log.info( + "ffmpeg process %s should have terminated with a return code of %s", + proc.pid, + proc.returncode, + ) + else: + _log.info( + "ffmpeg process %s successfully terminated with return code of %s", + proc.pid, + proc.returncode, + ) + + self._process = MISSING + + def _pipe_reader(self, source: IO[bytes], dest: IO[bytes]) -> None: + while self._process: + if source.closed: + return + + try: + data = source.read(FFmpegAudio.BLOCKSIZE) + except (OSError, ValueError) as exc: + _log.debug("FFmpeg stdin pipe closed with exception %s", exc) + return + except Exception: + _log.debug( + "An error ocurred in %s, this can be ignored", self, exc_info=True + ) + return + + if data is None: + return + + try: + dest.write(data) + except Exception as exc: + _log.exception( + "Error while writing to destination pipe %s", self, exc_info=exc + ) + self._kill_processes() + self.on_error(self, exc, None) + return + + +class FilterSink(Sink): + r"""A sink that calls filtering callbacks before writing. + + .. versionadded:: 2.7 + + Parameters + ---------- + destination: :class:`Sink` + The sink that is being filtered. + filters: Sequence[Callable[[:class:`User` | :class`Member` | :data:`None`, :class:`VoiceData`], :class:`bool`]] + The filters of this sink. + filtering_mode: Literal["all", "any"] + How the filters should work, if ``all`, all filters must be successful in order for + a voice data packet to be written. Using ``any`` will make it so only one filter is + required to be successful in order for a voice data packet to be written. """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - self.vc: VoiceClient | None = None - self.audio_data = {} - - def init(self, vc): # called under listen - self.vc: VoiceClient = vc - super().init() - - @Filters.container - def write(self, data, user): - if user not in self.audio_data: - file = io.BytesIO() - self.audio_data.update({user: AudioData(file)}) - - file = self.audio_data[user] - file.write(data) - - def cleanup(self): - self.finished = True - for file in self.audio_data.values(): - file.cleanup() - self.format_audio(file) - - def get_all_audio(self): - """Gets all audio files.""" - return [x.file for x in self.audio_data.values()] - - def get_user_audio(self, user: snowflake.Snowflake): - """Gets the audio file(s) of one specific user.""" - return os.path.realpath(self.audio_data.pop(user)) + def __init__( + self, + destination: Sink, + filters: Sequence[Callable[[User | Member | None, VoiceData], bool]], + *, + filtering_mode: Literal["all", "any"] = "all", + ) -> None: + if not filters: + raise ValueError("filters must have at least one callback") + + if not isinstance(destination, SinkBase): + raise TypeError( + f"expected a Sink object, got {destination.__class__.__name__}" + ) + + self._filter_strat = all if filtering_mode == "all" else any + self.filters: Sequence[Callable[[User | Member | None, VoiceData], bool]] = ( + filters + ) + self.destination: Sink = destination + super().__init__(dest=destination) + + def is_opus(self) -> bool: + return self.destination.is_opus() + + def write(self, user: User | Member | None, data: VoiceData) -> None: + if self._filter_strat(f(user, data) for f in self.filters): + self.destination.write(user, data) + + def cleanup(self) -> None: + self.filters = [] + self.destination.cleanup() diff --git a/discord/sinks/errors.py b/discord/sinks/errors.py index 5f036efff5..9f3d081bc4 100644 --- a/discord/sinks/errors.py +++ b/discord/sinks/errors.py @@ -87,3 +87,25 @@ class MKASinkError(SinkException): .. versionadded:: 2.0 """ + + +class MaxProcessesCountReached(SinkException): + """Exception thrown when you try to create an audio converter process and the maximum + process count threshold is exceeded. + + .. versionadded:: 2.7 + """ + + +class FFmpegNotFound(SinkException): + """Exception thrown when the provided FFmpeg executable path was not found. + + .. versionadded:: 2.7 + """ + + +class NoUserAudio(SinkException): + """Exception thrown when you try to format the audio of a user not saved in a sink. + + .. versionadded:: 2.7 + """ diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 1cff9da538..e1a1fd5ce1 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -22,82 +22,97 @@ DEALINGS IN THE SOFTWARE. """ -import io -import os -import subprocess -import time +from __future__ import annotations -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import M4ASinkError +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload +from discord.utils import MISSING -class M4ASink(Sink): +from .core import FFmpegSink + +if TYPE_CHECKING: + from typing_extensions import Self + + from ..voice.packets import VoiceData + +__all__ = ("M4ASink",) + + +class M4ASink(FFmpegSink): """A special sink for .m4a files. .. versionadded:: 2.0 + + Parameters + ---------- + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. verionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error ocurrs with this sink. + + .. versionadded:: 2.7 """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - - self.encoding = "m4a" - self.vc = None - self.audio_data = {} - - def format_audio(self, audio): - """Formats the recorded audio. - - Raises - ------ - M4ASinkError - Audio may only be formatted after recording is finished. - M4ASinkError - Formatting the audio failed. - """ - if self.vc.recording: - raise M4ASinkError( - "Audio may only be formatted after recording is finished." - ) - m4a_file = f"{time.time()}.tmp" - args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "ipod", - m4a_file, - ] - if os.path.exists(m4a_file): - os.remove( - m4a_file - ) # process will get stuck asking whether to overwrite, if file already exists. - try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE - ) - except FileNotFoundError: - raise M4ASinkError("ffmpeg was not found.") from None - except subprocess.SubprocessError as exc: - raise M4ASinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - process.communicate(audio.file.read()) - - with open(m4a_file, "rb") as f: - audio.file = io.BytesIO(f.read()) - audio.file.seek(0) - os.remove(m4a_file) - - audio.on_format(self.encoding) + @overload + def __init__( + self, + *, + filename: str, + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + @overload + def __init__( + self, + *, + buffer: IO[bytes], + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + def __init__( + self, + *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, + executable: str = "ffmpeg", + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f ipod -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index c2bbefb923..4c9878ebeb 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -22,75 +22,97 @@ DEALINGS IN THE SOFTWARE. """ -import io -import subprocess +from __future__ import annotations -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import MKASinkError +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload +from discord.utils import MISSING -class MKASink(Sink): +from .core import FFmpegSink + +if TYPE_CHECKING: + from typing_extensions import Self + + from discord.voice import VoiceData + +__all__ = ("MKASink",) + + +class MKASink(FFmpegSink): """A special sink for .mka files. .. versionadded:: 2.0 + + Parameters + ---------- + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. verionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error ocurrs with this sink. + + .. versionadded:: 2.7 """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - - self.encoding = "mka" - self.vc = None - self.audio_data = {} - - def format_audio(self, audio): - """Formats the recorded audio. - - Raises - ------ - MKASinkError - Audio may only be formatted after recording is finished. - MKASinkError - Formatting the audio failed. - """ - if self.vc.recording: - raise MKASinkError( - "Audio may only be formatted after recording is finished." - ) - args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "matroska", - "pipe:1", - ] - try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - ) - except FileNotFoundError: - raise MKASinkError("ffmpeg was not found.") from None - except subprocess.SubprocessError as exc: - raise MKASinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - out = process.communicate(audio.file.read())[0] - out = io.BytesIO(out) - out.seek(0) - audio.file = out - audio.on_format(self.encoding) + @overload + def __init__( + self, + *, + filename: str, + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + @overload + def __init__( + self, + *, + buffer: IO[bytes], + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + def __init__( + self, + *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, + executable: str = "ffmpeg", + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f matroska -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index 93f4cc7444..c4b180bff9 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -22,74 +22,97 @@ DEALINGS IN THE SOFTWARE. """ -import io -import subprocess +from __future__ import annotations -from .core import Filters, Sink, default_filters -from .errors import MKVSinkError +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload +from discord.utils import MISSING -class MKVSink(Sink): +from .core import FFmpegSink + +if TYPE_CHECKING: + from typing_extensions import Self + + from discord.voice import VoiceData + +__all__ = ("MKVSink",) + + +class MKVSink(FFmpegSink): """A special sink for .mkv files. .. versionadded:: 2.0 + + Parameters + ---------- + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. verionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error ocurrs with this sink. + + .. versionadded:: 2.7 """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - - self.encoding = "mkv" - self.vc = None - self.audio_data = {} - - def format_audio(self, audio): - """Formats the recorded audio. - - Raises - ------ - MKVSinkError - Audio may only be formatted after recording is finished. - MKVSinkError - Formatting the audio failed. - """ - if self.vc.recording: - raise MKVSinkError( - "Audio may only be formatted after recording is finished." - ) - args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "matroska", - "pipe:1", - ] - try: - process = subprocess.Popen( - args, # creationflags=CREATE_NO_WINDOW, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - ) - except FileNotFoundError: - raise MKVSinkError("ffmpeg was not found.") from None - except subprocess.SubprocessError as exc: - raise MKVSinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - out = process.communicate(audio.file.read())[0] - out = io.BytesIO(out) - out.seek(0) - audio.file = out - audio.on_format(self.encoding) + @overload + def __init__( + self, + *, + filename: str, + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + @overload + def __init__( + self, + *, + buffer: IO[bytes], + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + def __init__( + self, + *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, + executable: str = "ffmpeg", + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f matroska -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 74386a2738..6c62e8e8ef 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -22,75 +22,97 @@ DEALINGS IN THE SOFTWARE. """ -import io -import subprocess +from __future__ import annotations -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import MP3SinkError +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload +from discord.utils import MISSING -class MP3Sink(Sink): +from .core import FFmpegSink + +if TYPE_CHECKING: + from typing_extensions import Self + + from discord.voice import VoiceData + +__all__ = ("MP3Sink",) + + +class MP3Sink(FFmpegSink): """A special sink for .mp3 files. .. versionadded:: 2.0 + + Parameters + ---------- + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. verionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error ocurrs with this sink. + + .. versionadded:: 2.7 """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - - self.encoding = "mp3" - self.vc = None - self.audio_data = {} - - def format_audio(self, audio): - """Formats the recorded audio. - - Raises - ------ - MP3SinkError - Audio may only be formatted after recording is finished. - MP3SinkError - Formatting the audio failed. - """ - if self.vc.recording: - raise MP3SinkError( - "Audio may only be formatted after recording is finished." - ) - args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "mp3", - "pipe:1", - ] - try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - ) - except FileNotFoundError: - raise MP3SinkError("ffmpeg was not found.") from None - except subprocess.SubprocessError as exc: - raise MP3SinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - out = process.communicate(audio.file.read())[0] - out = io.BytesIO(out) - out.seek(0) - audio.file = out - audio.on_format(self.encoding) + @overload + def __init__( + self, + *, + filename: str, + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + @overload + def __init__( + self, + *, + buffer: IO[bytes], + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + def __init__( + self, + *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, + executable: str = "ffmpeg", + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f mp3 -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index c4d0ed2b63..fe050ec696 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -22,82 +22,97 @@ DEALINGS IN THE SOFTWARE. """ -import io -import os -import subprocess -import time +from __future__ import annotations -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import MP4SinkError +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload +from discord.utils import MISSING -class MP4Sink(Sink): +from .core import FFmpegSink + +if TYPE_CHECKING: + from typing_extensions import Self + + from discord.voice import VoiceData + +__all__ = ("MP4Sink",) + + +class MP4Sink(FFmpegSink): """A special sink for .mp4 files. .. versionadded:: 2.0 + + Parameters + ---------- + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. verionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error ocurrs with this sink. + + .. versionadded:: 2.7 """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - - self.encoding = "mp4" - self.vc = None - self.audio_data = {} - - def format_audio(self, audio): - """Formats the recorded audio. - - Raises - ------ - MP4SinkError - Audio may only be formatted after recording is finished. - MP4SinkError - Formatting the audio failed. - """ - if self.vc.recording: - raise MP4SinkError( - "Audio may only be formatted after recording is finished." - ) - mp4_file = f"{time.time()}.tmp" - args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "mp4", - mp4_file, - ] - if os.path.exists(mp4_file): - os.remove( - mp4_file - ) # process will get stuck asking whether to overwrite, if file already exists. - try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE - ) - except FileNotFoundError: - raise MP4SinkError("ffmpeg was not found.") from None - except subprocess.SubprocessError as exc: - raise MP4SinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - process.communicate(audio.file.read()) - - with open(mp4_file, "rb") as f: - audio.file = io.BytesIO(f.read()) - audio.file.seek(0) - os.remove(mp4_file) - - audio.on_format(self.encoding) + @overload + def __init__( + self, + *, + filename: str, + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + @overload + def __init__( + self, + *, + buffer: IO[bytes], + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + def __init__( + self, + *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, + executable: str = "ffmpeg", + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f mp4 -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 7b531464bd..62b0717be7 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -22,75 +22,97 @@ DEALINGS IN THE SOFTWARE. """ -import io -import subprocess +from __future__ import annotations -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import OGGSinkError +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload +from discord.utils import MISSING -class OGGSink(Sink): +from .core import FFmpegSink + +if TYPE_CHECKING: + from typing_extensions import Self + + from discord.voice import VoiceData + +__all__ = ("OGGSink",) + + +class OGGSink(FFmpegSink): """A special sink for .ogg files. .. versionadded:: 2.0 + + Parameters + ---------- + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. verionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error ocurrs with this sink. + + .. versionadded:: 2.7 """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - - self.encoding = "ogg" - self.vc = None - self.audio_data = {} - - def format_audio(self, audio): - """Formats the recorded audio. - - Raises - ------ - OGGSinkError - Audio may only be formatted after recording is finished. - OGGSinkError - Formatting the audio failed. - """ - if self.vc.recording: - raise OGGSinkError( - "Audio may only be formatted after recording is finished." - ) - args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "ogg", - "pipe:1", - ] - try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - except FileNotFoundError: - raise OGGSinkError("ffmpeg was not found.") from None - except subprocess.SubprocessError as exc: - raise OGGSinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - out = process.communicate(audio.file.read())[0] - out = io.BytesIO(out) - out.seek(0) - audio.file = out - audio.on_format(self.encoding) + @overload + def __init__( + self, + *, + filename: str, + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + @overload + def __init__( + self, + *, + buffer: IO[bytes], + executable: str = ..., + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + def __init__( + self, + *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, + executable: str = "ffmpeg", + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f ogg -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index c587da349a..5ae17b2df4 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -22,7 +22,18 @@ DEALINGS IN THE SOFTWARE. """ -from .core import Filters, Sink, default_filters +from __future__ import annotations + +import io +from typing import TYPE_CHECKING + +from .core import Sink + +if TYPE_CHECKING: + from discord import abc + from discord.voice import VoiceData + +__all__ = ("PCMSink",) class PCMSink(Sink): @@ -31,15 +42,10 @@ class PCMSink(Sink): .. versionadded:: 2.0 """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + def __init__(self) -> None: + super().__init__(dest=None) - self.encoding = "pcm" - self.vc = None - self.audio_data = {} + self.buffer: io.BytesIO = io.BytesIO() - def format_audio(self, audio): - return + def write(self, user: abc.User | None, data: VoiceData) -> None: + self.buffer.write(data.pcm) diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index 37f5aac933..b200ff3956 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -22,48 +22,108 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +import io +import logging import wave +from typing import TYPE_CHECKING + +from discord.file import File +from discord.opus import Decoder + +from .core import Sink -from .core import Filters, Sink, default_filters -from .errors import WaveSinkError +if TYPE_CHECKING: + from discord import abc + from discord.voice import VoiceData + +_log = logging.getLogger(__name__) + +__all__ = ( + "WaveSink", + "WavSink", +) class WaveSink(Sink): - """A special sink for .wav(wave) files. + """A special sink for .wav(e) files. .. versionadded:: 2.0 + + Parameters + ---------- + destination: :class:`str` | :term:`py:bytes-like object` + The destination in which the data should be saved into. + + If this is a filename, then it is saved into that. Else, treats + it like a buffer. + + .. versionadded:: 2.7 + channels: :class:`int` + The amount of channels. + + .. versionadded:: 2.7 + sample_width: :class:`int` + The sample width to "n" bytes. + + .. versionadded:: 2.7 + sampling_rate: :class:`int` + The frame rate. A non-integral input is rounded to the nearest int. + + .. versionadded:: 2.7 """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - - self.encoding = "wav" - self.vc = None - self.audio_data = {} - - def format_audio(self, audio): - """Formats the recorded audio. - - Raises - ------ - WaveSinkError - Audio may only be formatted after recording is finished. - WaveSinkError - Formatting the audio failed. + def __init__( + self, + destination: wave._File, + *, + channels: int = Decoder.CHANNELS, + sample_width: int = Decoder.SAMPLE_SIZE // Decoder.CHANNELS, + sampling_rate: int = Decoder.SAMPLING_RATE, + ) -> None: + super().__init__() + + self._destination: wave._File = destination + self._file: wave.Wave_write = wave.open(destination, "wb") + self._file.setnchannels(channels) + self._file.setsampwidth(sample_width) + self._file.setframerate(sampling_rate) + + def is_opus(self) -> bool: + return False + + def write(self, user: abc.User | None, data: VoiceData) -> None: + self._file.writeframes(data.pcm) + + def to_file( + self, filename: str, /, *, description: str | None = None, spoiler: bool = False + ) -> File | None: + """Returns the :class:`discord.File` of this sink. + + .. warning:: + + This should be used only after the sink has stopped recording. """ - if self.vc.recording: - raise WaveSinkError( - "Audio may only be formatted after recording is finished." + + f = wave.open(self._destination, "rb") + data = f.readframes(f.getnframes()) + return File( + io.BytesIO(data), filename, description=description, spoiler=spoiler + ) + + def cleanup(self) -> None: + try: + self._file.close() + except Exception as exc: + _log.warning( + "An error ocurred while closing the wave writing file on cleanup", + exc_info=exc, ) - data = audio.file - with wave.open(data, "wb") as f: - f.setnchannels(self.vc.decoder.CHANNELS) - f.setsampwidth(self.vc.decoder.SAMPLE_SIZE // self.vc.decoder.CHANNELS) - f.setframerate(self.vc.decoder.SAMPLING_RATE) - data.seek(0) - audio.on_format(self.encoding) +WavSink = WaveSink +"""An alias for :class:`~.WaveSink`. + +.. versionadded:: 2.7 +""" diff --git a/discord/state.py b/discord/state.py index 56bdcbd684..aa2749a14c 100644 --- a/discord/state.py +++ b/discord/state.py @@ -89,7 +89,7 @@ from .types.poll import Poll as PollPayload from .types.sticker import GuildSticker as GuildStickerPayload from .types.user import User as UserPayload - from .voice_client import VoiceClient + from .voice import VoiceClient T = TypeVar("T") CS = TypeVar("CS", bound="ConnectionState") @@ -1860,7 +1860,8 @@ def parse_voice_state_update(self, data) -> None: if int(data["user_id"]) == self_id: voice = self._get_voice_client(guild.id) if voice is not None: - coro = voice.on_voice_state_update(data) + payload = RawVoiceStateUpdateEvent(data=data, state=self) + coro = voice.on_voice_state_update(payload) asyncio.create_task( logging_coroutine( coro, info="Voice Protocol voice state update handler" @@ -1897,8 +1898,9 @@ def parse_voice_server_update(self, data) -> None: key_id = int(data["channel_id"]) vc = self._get_voice_client(key_id) + payload = RawVoiceServerUpdateEvent(data=data, state=self) if vc is not None: - coro = vc.on_voice_server_update(data) + coro = vc.on_voice_server_update(payload) asyncio.create_task( logging_coroutine( coro, info="Voice Protocol voice server update handler" diff --git a/discord/types/raw_models.py b/discord/types/raw_models.py index 1a7feee059..7ec8a446ca 100644 --- a/discord/types/raw_models.py +++ b/discord/types/raw_models.py @@ -33,6 +33,8 @@ from .snowflake import Snowflake from .threads import Thread, ThreadMember from .user import User +from .voice import VoiceServerUpdate as VoiceServerUpdateEvent +from .voice import VoiceState as VoiceStateEvent class _MessageEventOptional(TypedDict, total=False): diff --git a/discord/types/voice.py b/discord/types/voice.py index 68d99ccd48..307f98cee3 100644 --- a/discord/types/voice.py +++ b/discord/types/voice.py @@ -40,7 +40,7 @@ ] -class _VoiceState(TypedDict): +class VoiceState(TypedDict): member: NotRequired[MemberWithUser] self_stream: NotRequired[bool] user_id: Snowflake @@ -51,15 +51,12 @@ class _VoiceState(TypedDict): self_mute: bool self_video: bool suppress: bool + request_to_speak_timestamp: str | None + channel_id: Snowflake | None + guild_id: NotRequired[Snowflake] -class GuildVoiceState(_VoiceState): - channel_id: Snowflake - - -class VoiceState(_VoiceState, total=False): - channel_id: Snowflake | None - guild_id: Snowflake +GuildVoiceState = VoiceState class VoiceRegion(TypedDict): diff --git a/discord/utils.py b/discord/utils.py index c42a51cbd8..c50df31b26 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -335,7 +335,7 @@ def deprecated( stacklevel: int = 3, *, use_qualname: bool = True, -) -> Callable[[Callable[[P], T]], Callable[[P], T]]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """A decorator implementation of :func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when the decorated function is called. @@ -360,7 +360,7 @@ def deprecated( will display as ``login``. Defaults to ``True``. """ - def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]: + def actual_decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) def decorated(*args: P.args, **kwargs: P.kwargs) -> T: warn_deprecated( diff --git a/discord/voice/__init__.py b/discord/voice/__init__.py new file mode 100644 index 0000000000..d6d13ab0df --- /dev/null +++ b/discord/voice/__init__.py @@ -0,0 +1,13 @@ +""" +discord.voice +~~~~~~~~~~~~~ + +Voice support for the Discord API. + +:copyright: (c) 2015-2021 Rapptz & 2021-present Pycord Development +:license: MIT, see LICENSE for more details. +""" + +from ._types import * +from .client import * +from .packets import * diff --git a/discord/voice/_types.py b/discord/voice/_types.py new file mode 100644 index 0000000000..71015900f2 --- /dev/null +++ b/discord/voice/_types.py @@ -0,0 +1,166 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from discord import abc + from discord.client import Client + from discord.raw_models import ( + RawVoiceServerUpdateEvent, + RawVoiceStateUpdateEvent, + ) + + P = ParamSpec("P") + R = TypeVar("R") + +ClientT = TypeVar("ClientT", bound="Client", covariant=True) + +__all__ = ("VoiceProtocol",) + + +class VoiceProtocol(Generic[ClientT]): + """A class that represents the Discord voice protocol. + + .. warning:: + + If you are an end user, you **should not construct this manually** but instead + take it from the return type in :meth:`abc.Connectable.connect `. + The parameters and methods being documented here is so third party libraries can refer to it + when implementing their own VoiceProtocol types. + + This is an abstract class. The library provides a concrete implementation + under :class:`VoiceClient`. + + This class allows you to implement a protocol to allow for an external + method of sending voice, such as Lavalink_ or a native library implementation. + + These classes are passed to :meth:`abc.Connectable.connect `. + + .. _Lavalink: https://github.com/freyacodes/Lavalink + + Parameters + ---------- + client: :class:`Client` + The client (or its subclasses) that started the connection request. + channel: :class:`abc.Connectable` + The voice channel that is being connected to. + """ + + def __init__(self, client: ClientT, channel: abc.Connectable) -> None: + self.client: ClientT = client + self.channel: abc.Connectable = channel + + async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: + """|coro| + + A method called when the client's voice state has changed. This corresponds + to the ``VOICE_STATE_UPDATE`` event. + + Parameters + ---------- + data: :class:`RawVoiceStateUpdateEvent` + The voice state payload. + + .. versionchanged:: 2.7 + This now gets passed a `RawVoiceStateUpdateEvent` object instead of a :class:`dict`, but + accessing keys via ``data[key]`` or ``data.get(key)`` is still supported, but deprecated. + """ + raise NotImplementedError + + async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: + """|coro| + + A method called when the client's intially connecting to voice. This corresponds + to the ``VOICE_SERVER_UPDATE`` event. + + Parameters + ---------- + data: :class:`RawVoiceServerUpdateEvent` + The voice server payload. + + .. versionchanged:: 2.7 + This now gets passed a `RawVoiceServerUpdateEvent` object instead of a :class:`dict`, but + accessing keys via ``data[key]`` or ``data.get(key)`` is still supported, but deprecated. + """ + raise NotImplementedError + + async def connect(self, *, timeout: float, reconnect: bool) -> None: + """|coro| + + A method called to initialise the connection. + + The library initialises this class and calls ``__init__``, and then :meth:`connect` when attempting + to start a connection to the voice. If an error ocurrs, it calls :meth:`disconnect`, so if you need + to implement any cleanup, you should manually call it in :meth:`disconnect` as the library will not + do so for you. + + Within this method, to start the voice connection flow, it is recommened to use :meth:`Guild.change_voice_state` + to start the flow. After which :meth:`on_voice_server_update` and :meth:`on_voice_state_update` will be called, + although this could vary and cause unexpected behaviour, but that falls under Discord's way of handling the voice + connection. + + Parameters + ---------- + timeout: :class:`float` + The timeout for the connection. + reconnect: :class:`bool` + Whether reconnection is expected. + """ + raise NotImplementedError + + async def disconnect(self, *, force: bool) -> None: + """|coro| + + A method called to terminate the voice connection. + + This can be either called manually when forcing a disconnection, or when an exception in :meth:`connect` ocurrs. + + It is recommended to call :meth:`cleanup` here. + + Parameters + ---------- + force: :class:`bool` + Whether the disconnection was forced. + """ + + def cleanup(self) -> None: + """This method *must* be called to ensure proper clean-up during a disconnect. + + It is advisable to call this from within :meth:`disconnect` when you are completely + done with the voice protocol instance. + + This method removes it from the internal state cache that keeps track of the currently + alive voice clients. Failure to clean-up will cause subsequent connections to report that + it's still connected. + + **The library will NOT automatically call this for you**, unlike :meth:`connect` and :meth:`disconnect`. + """ + key, _ = self.channel._get_voice_client_key() + self.client._connection._remove_voice_client(key) diff --git a/discord/voice/client.py b/discord/voice/client.py new file mode 100644 index 0000000000..179aa42818 --- /dev/null +++ b/discord/voice/client.py @@ -0,0 +1,769 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import datetime +import logging +import struct +import warnings +from typing import TYPE_CHECKING, Any, Literal, overload + +from discord import opus +from discord.enums import SpeakingState, try_enum +from discord.errors import ClientException +from discord.player import AudioPlayer, AudioSource +from discord.sinks.core import Sink +from discord.sinks.errors import RecordingException +from discord.utils import MISSING + +from ._types import VoiceProtocol +from .enums import OpCodes +from .receive import AudioReader +from .state import VoiceConnectionState + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from discord import abc + from discord.client import Client + from discord.guild import Guild, VocalGuildChannel + from discord.member import Member + from discord.opus import APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Encoder + from discord.raw_models import ( + RawVoiceServerUpdateEvent, + RawVoiceStateUpdateEvent, + ) + from discord.state import ConnectionState + from discord.types.voice import SupportedModes + from discord.user import ClientUser, User + + from .gateway import VoiceWebSocket + from .receive.reader import AfterCallback + + P = ParamSpec("P") + +_log = logging.getLogger(__name__) + +has_nacl: bool + +try: + import nacl.secret + import nacl.utils + + has_nacl = True +except ImportError: + has_nacl = False + +__all__ = ("VoiceClient",) + + +class VoiceClient(VoiceProtocol): + """Represents a Discord voice connection. + + You do not create these, you typically get them from e.g. + :meth:`VoiceChannel.connect`. + + Attributes + ---------- + session_id: :class:`str` + The voice connection session ID. + token: :class:`str` + The voice connection token. + endpoint: :class:`str` + The endpoint we are connecting to. + channel: Union[:class:`VoiceChannel`, :class:`StageChannel`] + The channel we are connected to. + + Warning + ------- + In order to use PCM based AudioSources, you must have the opus library + installed on your system and loaded through :func:`opus.load_opus`. + Otherwise, your AudioSources must be opus encoded (e.g. using :class:`FFmpegOpusAudio`) + or the library will not be able ot transmit audio. + """ + + channel: VocalGuildChannel + + def __init__( + self, + client: Client, + channel: abc.Connectable, + ) -> None: + if not has_nacl: + raise RuntimeError( + "PyNaCl library is needed in order to use voice related features, " + 'you can run "pip install py-cord[voice]" to install all voice-related ' + "dependencies." + ) + + super().__init__(client, channel) + state = client._connection + + self.server_id: int = MISSING + self.socket = MISSING + self.loop: asyncio.AbstractEventLoop = state.loop + self._state: ConnectionState = state + + self.sequence: int = 0 + self.timestamp: int = 0 + self._player: AudioPlayer | None = None + self._player_future: asyncio.Future[None] | None = None + self.encoder: Encoder = MISSING + self._incr_nonce: int = 0 + + self._connection: VoiceConnectionState = self.create_connection_state() + + self._ssrc_to_id: dict[int, int] = {} + self._id_to_ssrc: dict[int, int] = {} + self._event_listeners: dict[str, list] = {} + self._reader: AudioReader = MISSING + + warn_nacl: bool = not has_nacl + supported_modes: tuple[SupportedModes, ...] = ( + "aead_xchacha20_poly1305_rtpsize", + "xsalsa20_poly1305_lite", + "xsalsa20_poly1305_suffix", + "xsalsa20_poly1305", + ) + + @property + def guild(self) -> Guild: + """Returns the guild the channel we're connecting to is bound to.""" + return self.channel.guild + + @property + def user(self) -> ClientUser: + """The user connected to voice (i.e. ourselves)""" + return self._state.user # type: ignore + + @property + def session_id(self) -> str | None: + return self._connection.session_id + + @property + def token(self) -> str | None: + return self._connection.token + + @property + def endpoint(self) -> str | None: + return self._connection.endpoint + + @property + def ssrc(self) -> int: + return self._connection.ssrc + + @property + def mode(self) -> SupportedModes: + return self._connection.mode + + @property + def secret_key(self) -> list[int]: + return self._connection.secret_key + + @property + def ws(self) -> VoiceWebSocket: + return self._connection.ws + + @property + def timeout(self) -> float: + return self._connection.timeout + + def checked_add(self, attr: str, value: int, limit: int) -> None: + val = getattr(self, attr) + if val + value > limit: + setattr(self, attr, 0) + else: + setattr(self, attr, val + value) + + def create_connection_state(self) -> VoiceConnectionState: + return VoiceConnectionState(self, hook=self._recv_hook) + + async def _recv_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: + op = msg["op"] + data = msg.get("d", {}) + + if op == OpCodes.ready: + self._add_ssrc(self.guild.me.id, data["ssrc"]) + elif op == OpCodes.speaking: + uid = int(data["user_id"]) + ssrc = data["ssrc"] + + self._add_ssrc(uid, ssrc) + + member = self.guild.get_member(uid) + state = try_enum(SpeakingState, data["speaking"]) + self.dispatch("member_speaking_state_update", member, ssrc, state) + elif op == OpCodes.clients_connect: + uids = list(map(int, data["user_ids"])) + + for uid in uids: + member = self.guild.get_member(uid) + if not member: + _log.warning( + "Skipping member referencing ID %d on member_connect", uid + ) + continue + self.dispatch("member_connect", member) + elif op == OpCodes.client_disconnect: + uid = int(data["user_id"]) + ssrc = self._id_to_ssrc.get(uid) + + if self._reader and ssrc is not None: + _log.debug("Destroying decoder for user %d, ssrc=%d", uid, ssrc) + self._reader.packet_router.destroy_decoder(ssrc) + + self._remove_ssrc(user_id=uid) + member = self.guild.get_member(uid) + self.dispatch("member_disconnect", member, ssrc) + + # maybe handle video and such things? + + async def _run_event( + self, coro, event_name: str, *args: Any, **kwargs: Any + ) -> None: + try: + await coro(*args, **kwargs) + except asyncio.CancelledError: + pass + except Exception: + _log.exception("Error calling %s", event_name) + + def _schedule_event( + self, coro, event_name: str, *args: Any, **kwargs: Any + ) -> asyncio.Task: + wrapped = self._run_event(coro, event_name, *args, **kwargs) + return self.client.loop.create_task( + wrapped, name=f"voice-receiver-event-dispatch: {event_name}" + ) + + def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None: + _log.debug("Dispatching voice_client event %s", event) + + event_name = f"on_{event}" + for coro in self._event_listeners.get(event_name, []): + task = self._schedule_event(coro, event_name, *args, **kwargs) + self._connection._dispatch_task_set.add(task) + task.add_done_callback(self._connection._dispatch_task_set.discard) + + self._dispatch_sink(event, *args, **kwargs) + self.client.dispatch(event, *args, **kwargs) + + async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: + old_channel_id = self.channel.id if self.channel else None + await self._connection.voice_state_update(data) + + if data.channel_id is None: + return + + if self._reader and data.channel_id != old_channel_id: + _log.debug("Destroying voice receive decoders in guild %s", self.guild.id) + self._reader.packet_router.destroy_all_decoders() + + async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: + await self._connection.voice_server_update(data) + + def _dispatch_sink(self, event: str, /, *args: Any, **kwargs: Any) -> None: + if self._reader: + self._reader.event_router.dispatch(event, *args, **kwargs) + + def _add_ssrc(self, user_id: int, ssrc: int) -> None: + self._ssrc_to_id[ssrc] = user_id + self._id_to_ssrc[user_id] = ssrc + + if self._reader: + self._reader.packet_router.set_user_id(ssrc, user_id) + + def _remove_ssrc(self, *, user_id: int) -> None: + ssrc = self._id_to_ssrc.pop(user_id, None) + + if ssrc: + self._reader.speaking_timer.drop_ssrc(ssrc) + self._ssrc_to_id.pop(ssrc, None) + + async def connect( + self, + *, + reconnect: bool, + timeout: float, + self_deaf: bool = False, + self_mute: bool = False, + ) -> None: + await self._connection.connect( + reconnect=reconnect, + timeout=timeout, + self_deaf=self_deaf, + self_mute=self_mute, + resume=False, + ) + + def wait_until_connected(self, timeout: float | None = 30.0) -> bool: + self._connection.wait_for(timeout=timeout) + return self._connection.is_connected() + + @property + def latency(self) -> float: + """Latency between a HEARTBEAT and a HEARBEAT_ACK in seconds. + + This chould be referred to as the Discord Voice WebSocket latency and is + and analogue of user's voice latencies as seen in the Discord client. + + .. versionadded:: 1.4 + """ + ws = self.ws + return float("inf") if not ws else ws.latency + + @property + def average_latency(self) -> float: + """Average of most recent 20 HEARBEAT latencies in seconds. + + .. versionadded:: 1.4 + """ + ws = self.ws + return float("inf") if not ws else ws.average_latency + + @property + def privacy_code(self) -> str | None: + """Returns the current voice session's privacy code, only available if the call has upgraded to use the + DAVE protocol + """ + session = self._connection.dave_session + return session and session.voice_privacy_code + + async def disconnect(self, *, force: bool = False) -> None: + """|coro| + + Disconnects this voice client from voice. + """ + + self.stop() + await self._connection.disconnect(force=force, wait=True) + self.cleanup() + + async def move_to( + self, channel: abc.Snowflake | None, *, timeout: float | None = 30.0 + ) -> None: + """|coro| + + moves you to a different voice channel. + + Parameters + ---------- + channel: Optional[:class:`abc.Snowflake`] + The channel to move to. If this is ``None``, it is an equivalent of calling :meth:`.disconnect`. + timeout: Optional[:class:`float`] + The maximum time in seconds to wait for the channel move to be completed, defaults to 30. + If it is ``None``, then there is no timeout. + + Raises + ------ + asyncio.TimeoutError + Waiting for channel move timed out. + """ + await self._connection.move_to(channel, timeout) + + def is_connected(self) -> bool: + """Whether the voice client is connected to voice.""" + return self._connection.is_connected() + + def is_playing(self) -> bool: + """INdicates if we're playing audio.""" + return self._player is not None and self._player.is_playing() + + def is_paused(self) -> bool: + """Indicates if we're playing audio, but if we're paused.""" + return self._player is not None and self._player.is_paused() + + # audio related + + def _get_voice_packet(self, data: Any) -> bytes: + + session = self._connection.dave_session + packet = session.encrypt_opus(data) if session and session.ready else data + + header = bytearray(12) + + # formulate rtp header + header[0] = 0x80 + header[1] = 0x78 + struct.pack_into(">H", header, 2, self.sequence) + struct.pack_into(">I", header, 4, self.timestamp) + struct.pack_into(">I", header, 8, self.ssrc) + + encrypt_packet = getattr(self, f"_encrypt_{self.mode}") + return encrypt_packet(header, packet) + + # encryption methods + + def _encrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + nonce = bytearray(24) + nonce[:12] = header + return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + + def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE) + return header + box.encrypt(bytes(data), nonce).ciphertext + nonce + + def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + nonce = bytearray(24) + nonce[:4] = struct.pack(">I", self._incr_nonce) + self.checked_add("_incr_nonce", 1, 4294967295) + + return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] + + def _encrypt_aead_xchacha20_poly1305_rtpsize( + self, header: bytes, data: Any + ) -> bytes: + box = nacl.secret.Aead(bytes(self.secret_key)) + nonce = bytearray(24) + nonce[:4] = struct.pack(">I", self._incr_nonce) + self.checked_add("_incr_nonce", 1, 4294967295) + return ( + header + + box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext + + nonce[:4] + ) + + @overload + def play( + self, + source: AudioSource, + *, + after: AfterCallback | None = ..., + application: APPLICATION_CTL = ..., + bitrate: int = ..., + fec: bool = ..., + expected_packet_loss: float = ..., + bandwidth: BAND_CTL = ..., + signal_type: SIGNAL_CTL = ..., + wait_finish: Literal[False] = ..., + ) -> None: ... + + @overload + def play( + self, + source: AudioSource, + *, + after: AfterCallback = ..., + application: APPLICATION_CTL = ..., + bitrate: int = ..., + fec: bool = ..., + expected_packet_loss: float = ..., + bandwidth: BAND_CTL = ..., + signal_type: SIGNAL_CTL = ..., + wait_finish: Literal[True], + ) -> asyncio.Future[None]: ... + + def play( + self, + source: AudioSource, + *, + after: AfterCallback | None = None, + application: APPLICATION_CTL = "audio", + bitrate: int = 128, + fec: bool = True, + expected_packet_loss: float = 0.15, + bandwidth: BAND_CTL = "full", + signal_type: SIGNAL_CTL = "auto", + wait_finish: bool = False, + ) -> None | asyncio.Future[None]: + """Plays an :class:`AudioSource`. + + The finalizer, ``after`` is called after the source has been exhausted + or an error occurred. + + IF an error happens while the audio player is running, the exception is + caught and the audio player is then stopped. If no after callback is passed, + any caught exception will be displayed as if it were raised. + + Parameters + ---------- + source: :class:`AudioSource` + The audio source we're reading from. + after: Callable[[Optional[:class:`Exception`]], Any] + The finalizer that is called after the stream is exhausted. + This function must have a single parameter, ``error``, that + denotes an optional exception that was raised during playing. + application: :class:`str` + The intended application encoder application type. Must be one of + ``audio``, ``voip``, or ``lowdelay``. Defaults to ``audio``. + bitrate: :class:`int` + The encoder's bitrate. Must be between ``16`` and ``512``. Defaults + to ``128``. + fec: :class:`bool` + Configures the encoder's use of inband forward error correction (fec). + Defaults to ``True``. + expected_packet_loss: :class:`float` + How much packet loss percentage is expected from the encoder. This requires ``fec`` + to be set to ``True``. Defaults to ``0.15``. + bandwidth: :class:`str` + The encoder's bandpass. Must be one of ``narrow``, ``medium``, ``wide``, + ``superwide``, or ``full``. Defaults to ``full``. + signal_type: :class:`str` + The type of signal being encoded. Must be one of ``auto``, ``voice``, ``music``. + Defaults to ``auto``. + wait_finish: :class:`bool` + If ``True``, then an awaitable is returned that waits for the audio source to be + exhausted, and will return an optional exception that could have been raised. + + If ``False``, ``None`` is returned and the function does not block. + + .. versionadded:: 2.5 + + Raises + ------ + ClientException + Already playing audio, or not connected to voice. + TypeError + Source is not a :class:`AudioSource`, or after is not a callable. + OpusNotLoaded + Source is not opus encoded and opus is not loaded. + """ + + if not self.is_connected(): + raise ClientException("Not connected to voice") + if self.is_playing(): + raise ClientException("Already playing audio") + if not isinstance(source, AudioSource): + raise TypeError( + f"Source must be an AudioSource, not {source.__class__.__name__}", + ) + if not self.encoder and not source.is_opus(): + self.encoder = opus.Encoder( + application=application, + bitrate=bitrate, + fec=fec, + expected_packet_loss=expected_packet_loss, + bandwidth=bandwidth, + signal_type=signal_type, + ) + + future = None + if wait_finish: + self._player_future = future = self.loop.create_future() + after_callback = after + + def _after(exc: Exception | None) -> None: + if callable(after_callback): + after_callback(exc) + future.set_result(exc) + + after = _after + + self._player = AudioPlayer(source, self, after=after) + self._player.start() + return future + + def stop(self) -> None: + """Stops playing audio, if applicable.""" + if self._player: + self._player.stop() + if self._player_future: + for cb, _ in self._player_future._callbacks: + self._player_future.remove_done_callback(cb) + self._player_future.set_result(None) + if self._reader is not MISSING: + self._reader.stop() + self._reader = MISSING + + self._player = None + self._player_future = None + + def pause(self) -> None: + """Pauses the audio playing.""" + if self._player: + self._player.pause() + + def resume(self) -> None: + """Resumes the audio playing.""" + if self._player: + self._player.resume() + + @property + def source(self) -> AudioSource | None: + """The audio source being player, if playing. + + This property can also be used to change the audio source currently being played. + """ + return self._player and self._player.source + + @source.setter + def source(self, value: AudioSource) -> None: + if not isinstance(value, AudioSource): + raise TypeError(f"expected AudioSource, not {value.__class__.__name__}") + + if self._player is None: + raise ValueError("the client is not playing anything") + + self._player.set_source(value) + + def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: + """Sends an audio packet composed of the ``data``. + + You must be connected to play audio. + + Parameters + ---------- + data: :class:`bytes` + The :term:`py:bytes-like object` denoting PCM or Opus voice data. + encode: :class:`bool` + Indicates if ``data`` should be encoded into Opus. + + Raises + ------ + ClientException + You are not connected. + opus.OpusError + Encoding the data failed. + """ + + self.checked_add("sequence", 1, 65535) + if encode: + encoded = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME) + else: + encoded = data + + packet = self._get_voice_packet(encoded) + try: + self._connection.send_packet(packet) + except OSError: + _log.debug( + "A packet has been dropped (seq: %s, timestamp: %s)", + self.sequence, + self.timestamp, + ) + + self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295) + + def elapsed(self) -> datetime.timedelta: + """Returns the elapsed time of the playing audio.""" + if self._player: + return datetime.timedelta(milliseconds=self._player.played_frames() * 20) + return datetime.timedelta() + + def start_recording( + self, + sink: Sink, + callback: AfterCallback | None = None, + *args: Any, + sync_start: bool = MISSING, + ) -> None: + r"""Start recording the audio from the current connected channel to the provided sink. + + .. versionadded:: 2.0 + .. versionchanged:: 2.7 + You can now have multiple concurrent recording sinks in the same voice client. + + Parameters + ---------- + sink: :class:`~.Sink` + A Sink in which all audio packets will be processed in. + callback: Callable[[:class:`Exception` | None], Any] + A function which is called after the bot has stopped recording. This must take exactly one positonal(-only) + parameter, ``exception``, which is the exception that was raised during the recording of the Sink. + + .. versionchanged:: 2.7 + This parameter is now optional, and must take exactly one parameter, ``exception``. + \*args: + The arguments to pass to the callback coroutine. + + .. deprecated:: 2.7 + Passing custom arguments to the callback is now deprecated and ignored. + sync_start: :class:`bool` + If ``True``, the recordings of subsequent users will start with silence. + This is useful for recording audio just as it was heard. + + .. deprecated:: 2.7 + This parameter is now ignored and deprecated. + + Raises + ------ + RecordingException + Not connected to a voice channel + TypeError + You did not provide a Sink object. + """ + + if not self.is_connected(): + raise RecordingException("not connected to a voice channel") + if not isinstance(sink, Sink): + raise TypeError(f"expected a Sink object, got {sink.__class__.__name__}") + + if self.is_recording(): + raise ClientException("Already recording audio") + + if len(args) > 0: + warnings.warn( + "'args' parameter is deprecated since 2.7 and will be removed in 3.0" + ) + if sync_start is not MISSING: + warnings.warn( + "'sync_tart' parameter is deprecated since 2.7 and will be removed in 3.0" + ) + + self._reader = AudioReader(sink, self, after=callback) + self._reader.start() + + def stop_recording(self) -> None: + """Stops the recording of the provided ``sink``, or all recording sinks. + + .. versionadded:: 2.0 + + Raises + ------ + RecordingException + You are not recording. + """ + if self._reader is not MISSING: + self._reader.stop() + self._reader = MISSING + else: + raise RecordingException("You are not recording") + + def is_recording(self) -> bool: + """Whether the current client is recording in any sink.""" + return self._reader and self._reader.is_listening() + + def is_speaking(self, member: Member | User) -> bool | None: + """Whether a user is speaking. + + This is an approximate calculation and may have outdated or wrong data. + + If the member speaking status has not been yet saved, it returns ``None``. + + .. versionadded:: 2.7 + """ + ssrc = self._id_to_ssrc.get(member.id) + if ssrc is None: + return None + if self._reader: + return self._reader.speaking_timer.get_speaking(ssrc) diff --git a/discord/voice/enums.py b/discord/voice/enums.py new file mode 100644 index 0000000000..76f78c768d --- /dev/null +++ b/discord/voice/enums.py @@ -0,0 +1,79 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from discord.enums import Enum + + +class OpCodes(Enum): + identify = 0 + select_protocol = 1 + ready = 2 + heartbeat = 3 + session_description = 4 + speaking = 5 + heartbeat_ack = 6 + resume = 7 + hello = 8 + resumed = 9 + clients_connect = 11 + client_connect = 12 + client_disconnect = 13 + + # dave protocol stuff + dave_prepare_transition = 21 + dave_execute_transition = 22 + dave_transition_ready = 23 + dave_prepare_epoch = 24 + mls_external_sender_package = 25 + mls_key_package = 26 + mls_proposals = 27 + mls_commit_welcome = 28 + mls_commit_transition = 29 + mls_welcome = 30 + mls_invalid_commit_welcome = 31 + + def __eq__(self, other: object) -> bool: + if isinstance(other, int): + return self.value == other + elif isinstance(other, self.__class__): + return self is other + return NotImplemented + + def __int__(self) -> int: + return self.value + + +class ConnectionFlowState(Enum): + disconnected = 0 + set_guild_voice_state = 1 + got_voice_state_update = 2 + got_voice_server_update = 3 + got_both_voice_updates = 4 + websocket_connected = 5 + got_websocket_ready = 6 + got_ip_discovery = 7 + connected = 8 diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py new file mode 100644 index 0000000000..e622b07cd7 --- /dev/null +++ b/discord/voice/gateway.py @@ -0,0 +1,512 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import logging +import struct +import threading +import time +from collections import deque +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any + +import aiohttp +import davey + +from discord import utils +from discord.enums import SpeakingState +from discord.errors import ConnectionClosed +from discord.gateway import DiscordWebSocket +from discord.gateway import KeepAliveHandler as KeepAliveHandlerBase + +from .enums import OpCodes + +if TYPE_CHECKING: + from _typeshed import ConvertibleToInt + from typing_extensions import Self + + from .state import VoiceConnectionState + +_log = logging.getLogger(__name__) + + +class KeepAliveHandler(KeepAliveHandlerBase): + if TYPE_CHECKING: + ws: VoiceWebSocket + + def __init__( + self, + *args: Any, + ws: VoiceWebSocket, + interval: float | None = None, + **kwargs: Any, + ) -> None: + daemon: bool = kwargs.pop("daemon", True) + name: str = kwargs.pop("name", f"voice-keep-alive-handler:{id(self):#x}") + super().__init__( + *args, + **kwargs, + name=name, + daemon=daemon, + ws=ws, + interval=interval, + ) + + self.msg: str = "Keeping shard ID %s voice websocket alive with timestamp %s." + self.block_msg: str = ( + "Shard ID %s voice heartbeat blocked for more than %s seconds." + ) + self.behing_msg: str = ( + "High socket latency, shard ID %s heartbeat is %.1fs behind." + ) + self.recent_ack_latencies: deque[float] = deque(maxlen=20) + + def get_payload(self) -> dict[str, Any]: + return { + "op": int(OpCodes.heartbeat), + "d": { + "t": int(time.time() * 1000), + "seq_ack": self.ws.seq_ack, + }, + } + + def ack(self) -> None: + ack_time = time.perf_counter() + self._last_ack = ack_time + self._last_recv = ack_time + self.latency = ack_time - self._last_send + self.recent_ack_latencies.append(self.latency) + + +class VoiceWebSocket(DiscordWebSocket): + def __init__( + self, + socket: aiohttp.ClientWebSocketResponse, + loop: asyncio.AbstractEventLoop, + state: VoiceConnectionState, + *, + hook: Callable[..., Coroutine[Any, Any, Any]] | None = None, + ) -> None: + self.ws: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop + self._keep_alive: KeepAliveHandler | None = None + self._close_code: int | None = None + self.secret_key: list[int] | None = None + self.seq_ack: int = -1 + self.state: VoiceConnectionState = state + self.ssrc_map: dict[str, dict[str, Any]] = {} + self.known_users: dict[int, Any] = {} + + if hook: + self._hook = hook or state.ws_hook # type: ignore + + @property + def token(self) -> str | None: + return self.state.token + + @token.setter + def token(self, value: str | None) -> None: + self.state.token = value + + @property + def session_id(self) -> str | None: + return self.state.session_id + + @session_id.setter + def session_id(self, value: str | None) -> None: + self.state.session_id = value + + @property + def self_id(self) -> int: + return self._connection.self_id + + async def _hook(self, *args: Any) -> Any: + pass + + async def send_as_bytes(self, op: ConvertibleToInt, data: bytes) -> None: + packet = bytes([int(op)]) + data + _log.debug( + "Sending voice websocket binary frame: op: %s size: %d", op, len(data) + ) + await self.ws.send_bytes(packet) + + async def send_as_json(self, data: Any) -> None: + _log.debug("Sending voice websocket frame: %s.", data) + if data.get("op", None) == OpCodes.identify: + _log.info("Identifying ourselves: %s", data) + await self.ws.send_str(utils._to_json(data)) + + send_heartbeat = send_as_json + + async def resume(self) -> None: + payload = { + "op": int(OpCodes.resume), + "d": { + "token": self.token, + "server_id": str(self.state.server_id), + "session_id": self.session_id, + "seq_ack": self.seq_ack, + }, + } + await self.send_as_json(payload) + + async def received_message(self, msg: Any, /): + _log.debug("Voice websocket frame received: %s", msg) + op = msg["op"] + data = msg.get("d", {}) # this key should ALWAYS be given, but guard anyways + self.seq_ack = msg.get("seq", self.seq_ack) # keep the seq_ack updated + state = self.state + + if op == OpCodes.ready: + await self.ready(data) + elif op == OpCodes.heartbeat_ack: + if not self._keep_alive: + _log.error( + "Received a heartbeat ACK but no keep alive handler was set.", + ) + return + self._keep_alive.ack() + elif op == OpCodes.resumed: + _log.info( + f"Voice connection on channel ID {self.state.channel_id} (guild {self.state.guild_id}) was " + "successfully RESUMED.", + ) + elif op == OpCodes.session_description: + state.mode = data["mode"] + state.dave_protocol_version = data["dave_protocol_version"] + await self.load_secret_key(data) + await state.reinit_dave_session() + elif op == OpCodes.hello: + interval = data["heartbeat_interval"] / 1000.0 + self._keep_alive = KeepAliveHandler( + ws=self, + interval=min(interval, 5), + ) + self._keep_alive.start() + elif state.dave_session: + if op == OpCodes.dave_prepare_transition: + _log.info( + "Preparing to upgrade to a DAVE connection for channel %s for transition %d proto version %d", + state.channel_id, + data["transition_id"], + data["protocol_version"], + ) + state.dave_pending_transition = data + + transition_id = data["transition_id"] + + if transition_id == 0: + await state.execute_dave_transition(data["transition_id"]) + else: + if data["protocol_version"] == 0 and state.dave_session: + state.dave_session.set_passthrough_mode(True, 120) + await self.send_dave_transition_ready(transition_id) + elif op == OpCodes.dave_execute_transition: + _log.info( + "Upgrading to DAVE connection for channel %s", state.channel_id + ) + await state.execute_dave_transition(data["transition_id"]) + elif op == OpCodes.dave_prepare_epoch: + epoch = data["epoch"] + _log.debug( + "Preparing for DAVE epoch in channel %s: %s", + state.channel_id, + epoch, + ) + # if epoch is 1 then a new MLS group is to be created for the proto version + if epoch == 1: + state.dave_protocol_version = data["protocol_version"] + await state.reinit_dave_session() + else: + _log.debug("Unhandled op code: %s with data %s", op, data) + + await utils.maybe_coroutine(self._hook, self, msg) + + async def received_binary_message(self, msg: bytes) -> None: + self.seq_ack = struct.unpack_from(">H", msg, 0)[0] + op = msg[2] + _log.debug( + "Voice websocket binary frame received: %d bytes, seq: %s, op: %s", + len(msg), + self.seq_ack, + op, + ) + + state = self.state + + if not state.dave_session: + return + + if op == OpCodes.mls_external_sender_package: + state.dave_session.set_external_sender(msg[3:]) + elif op == OpCodes.mls_proposals: + op_type = msg[3] + result = state.dave_session.process_proposals( + ( + davey.ProposalsOperationType.append + if op_type == 0 + else davey.ProposalsOperationType.revoke + ), + msg[4:], + ) + + if isinstance(result, davey.CommitWelcome): + data = ( + (result.commit + result.welcome) + if result.welcome + else result.commit + ) + _log.debug("Sending MLS key package with data: %s", data) + await self.send_as_bytes( + OpCodes.mls_commit_welcome, + data, + ) + _log.debug("Processed MLS proposals for current dave session: %r", result) + elif op == OpCodes.mls_commit_transition: + transt_id = struct.unpack_from(">H", msg, 3)[0] + try: + state.dave_session.process_commit(msg[5:]) + if transt_id != 0: + state.dave_pending_transition = { + "transition_id": transt_id, + "protocol_version": state.dave_protocol_version, + } + _log.debug( + "Sending DAVE transition ready from MLS commit transition with data: %s", + state.dave_pending_transition, + ) + await self.send_dave_transition_ready(transt_id) + _log.debug("Processed MLS commit for transition %s", transt_id) + except Exception as exc: + _log.debug( + "An exception ocurred while processing a MLS commit, this should be safe to ignore: %s", + exc, + ) + await state.recover_dave_from_invalid_commit(transt_id) + elif op == OpCodes.mls_welcome: + transt_id = struct.unpack_from(">H", msg, 3)[0] + try: + state.dave_session.process_welcome(msg[5:]) + if transt_id != 0: + state.dave_pending_transition = { + "transition_id": transt_id, + "protocol_version": state.dave_protocol_version, + } + _log.debug( + "Sending DAVE transition ready from MLS welcome with data: %s", + state.dave_pending_transition, + ) + await self.send_dave_transition_ready(transt_id) + _log.debug("Processed MLS welcome for transition %s", transt_id) + except Exception as exc: + _log.debug( + "An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s", + exc, + ) + await state.recover_dave_from_invalid_commit(transt_id) + + async def ready(self, data: dict[str, Any]) -> None: + state = self.state + + state.ssrc = data["ssrc"] + state.voice_port = data["port"] + state.endpoint_ip = data["ip"] + + _log.debug( + f"Connecting to {state.endpoint_ip} (port {state.voice_port}).", + ) + + await self.loop.sock_connect( + state.socket, + (state.endpoint_ip, state.voice_port), + ) + + _log.debug( + "Connected socket to %s (port %s)", + state.endpoint_ip, + state.voice_port, + ) + + state.ip, state.port = await self.get_ip() + + modes = [mode for mode in data["modes"] if mode in self.state.supported_modes] + _log.debug("Received available voice connection modes: %s", modes) + + mode = modes[0] + await self.select_protocol(state.ip, state.port, mode) + _log.debug("Selected voice protocol %s for this connection", mode) + + async def select_protocol(self, ip: str, port: int, mode: str) -> None: + payload = { + "op": int(OpCodes.select_protocol), + "d": { + "protocol": "udp", + "data": { + "address": ip, + "port": port, + "mode": mode, + }, + }, + } + await self.send_as_json(payload) + + async def get_ip(self) -> tuple[str, int]: + state = self.state + packet = bytearray(74) + struct.pack_into(">H", packet, 0, 1) # 1 = Send + struct.pack_into(">H", packet, 2, 70) # 70 = Length + struct.pack_into(">I", packet, 4, state.ssrc) + + _log.debug( + f"Sending IP discovery packet for voice in channel {state.channel_id} (guild {state.guild_id})" + ) + await self.loop.sock_sendall(state.socket, packet) + + fut: asyncio.Future[bytes] = self.loop.create_future() + + def get_ip_packet(data: bytes) -> None: + if data[1] == 0x02 and len(data) == 74: + self.loop.call_soon_threadsafe(fut.set_result, data) + + fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet)) + state.add_socket_listener(get_ip_packet) + recv = await fut + + _log.debug("Received IP discovery packet with data %s", recv) + + ip_start = 8 + ip_end = recv.index(0, ip_start) + ip = recv[ip_start:ip_end].decode("ascii") + port = struct.unpack_from(">H", recv, len(recv) - 2)[0] + _log.debug("Detected IP %s with port %s", ip, port) + + return ip, port + + @property + def latency(self) -> float: + heartbeat = self._keep_alive + return float("inf") if heartbeat is None else heartbeat.latency + + @property + def average_latency(self) -> float: + heartbeat = self._keep_alive + if heartbeat is None or not heartbeat.recent_ack_latencies: + return float("inf") + return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) + + async def load_secret_key(self, data: dict[str, Any]) -> None: + _log.debug( + f"Received secret key for voice connection in channel {self.state.channel_id} (guild {self.state.guild_id})" + ) + self.secret_key = self.state.secret_key = data["secret_key"] + await self.speak(SpeakingState.none) + + async def poll_event(self) -> None: + msg = await asyncio.wait_for(self.ws.receive(), timeout=30) + + if msg.type is aiohttp.WSMsgType.TEXT: + _log.debug("Received text payload: %s", msg.data) + await self.received_message(utils._from_json(msg.data)) + elif msg.type is aiohttp.WSMsgType.BINARY: + _log.debug("Received binary payload: size: %d", len(msg.data)) + await self.received_binary_message(msg.data) + elif msg.type is aiohttp.WSMsgType.ERROR: + _log.debug("Received %s", msg) + raise ConnectionClosed(self.ws, shard_id=None) from msg.data + elif msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + _log.debug("Received %s", msg) + raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) + + async def close(self, code: int = 1000) -> None: + if self._keep_alive: + self._keep_alive.stop() + + self._close_code = code + await self.ws.close(code=self._close_code) + + async def speak(self, state: SpeakingState = SpeakingState.voice) -> None: + await self.send_as_json( + { + "op": int(OpCodes.speaking), + "d": { + "speaking": int(state), + "delay": 0, + }, + }, + ) + + @classmethod + async def from_state( + cls, + state: VoiceConnectionState, + *, + resume: bool = False, + hook: Callable[..., Coroutine[Any, Any, Any]] | None = None, + seq_ack: int = -1, + ) -> Self: + gateway = f"wss://{state.endpoint}/?v=8" + client = state.client + http = client._state.http + socket = await http.ws_connect(gateway, compress=15) + ws = cls(socket, loop=client.loop, hook=hook, state=state) + ws.gateway = gateway + ws.seq_ack = seq_ack + ws._max_heartbeat_timeout = 60.0 + ws.thread_id = threading.get_ident() + + if resume: + await ws.resume() + else: + await ws.identify() + return ws + + async def identify(self) -> None: + state = self.state + payload = { + "op": int(OpCodes.identify), + "d": { + "server_id": str(state.server_id), + "user_id": str(state.user.id), + "session_id": self.session_id, + "token": self.token, + "max_dave_protocol_version": state.max_dave_proto_version, + }, + } + await self.send_as_json(payload) + + async def send_dave_transition_ready(self, transition_id: int) -> None: + payload = { + "op": int(OpCodes.dave_transition_ready), + "d": { + "transition_id": transition_id, + }, + } + await self.send_as_json(payload) diff --git a/discord/voice/packets/__init__.py b/discord/voice/packets/__init__.py new file mode 100644 index 0000000000..243189df61 --- /dev/null +++ b/discord/voice/packets/__init__.py @@ -0,0 +1,61 @@ +""" +discord.voice.packets +~~~~~~~~~~~~~~~~~~~~~ + +Sink packet handlers. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .core import Packet +from .rtp import ( + FakePacket, + ReceiverReportPacket, + RTCPPacket, + RTPPacket, + SenderReportPacket, + SilencePacket, +) + +if TYPE_CHECKING: + from discord import Member, User + +__all__ = ( + "Packet", + "RTPPacket", + "RTCPPacket", + "FakePacket", + "ReceiverReportPacket", + "SenderReportPacket", + "SilencePacket", + "VoiceData", +) + + +class VoiceData: + """Represents an audio data from a source. + + .. versionadded:: 2.7 + + Attributes + ---------- + packet: :class:`~discord.sinks.Packet` + The packet this source data contains. + source: :class:`~discord.User` | :class:`~discord.Member` | None + The user that emitted this audio source. + pcm: :class:`bytes` + The PCM bytes of this source. + """ + + def __init__( + self, packet: Packet, source: User | Member | None, *, pcm: bytes | None = None + ) -> None: + self.packet: Packet = packet + self.source: User | Member | None = source + self.pcm: bytes = pcm if pcm else b"" + + @property + def opus(self) -> bytes | None: + self.packet.decrypted_data diff --git a/discord/voice/packets/core.py b/discord/voice/packets/core.py new file mode 100644 index 0000000000..8f7a3960ec --- /dev/null +++ b/discord/voice/packets/core.py @@ -0,0 +1,90 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Final + +OPUS_SILENCE: Final = b"\xf8\xff\xfe" + + +class Packet: + """Represents an audio stream bytes packet. + + Attributes + ---------- + data: :class:`bytes` + The bytes data of this packet. This has not been decoded. + """ + + if TYPE_CHECKING: + ssrc: int + sequence: int + timestamp: int + type: int + decrypted_data: bytes + + def __init__(self, data: bytes) -> None: + self.data: bytes = data + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}> data={len(self.data)} bytes>" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + if self.ssrc != other.ssrc: + raise TypeError( + f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})" + ) + return self.sequence == other.sequence and self.timestamp == other.timestamp + + def __gt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + if self.ssrc != other.ssrc: + raise TypeError( + f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})" + ) + return self.sequence > other.sequence and self.timestamp > other.timestamp + + def __lt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + if self.ssrc != other.ssrc: + raise TypeError( + f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})" + ) + return self.sequence < other.sequence and self.timestamp < other.timestamp + + def is_silence(self) -> bool: + data = getattr(self, "decrypted_data", None) + return data == OPUS_SILENCE + + def __hash__(self) -> int: + return hash(self.data) diff --git a/discord/voice/packets/rtp.py b/discord/voice/packets/rtp.py new file mode 100644 index 0000000000..842a69fe2a --- /dev/null +++ b/discord/voice/packets/rtp.py @@ -0,0 +1,306 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import struct +from collections import namedtuple +from typing import TYPE_CHECKING, Any, Literal + +from .core import OPUS_SILENCE, Packet + +if TYPE_CHECKING: + from typing_extensions import Final + +MAX_UINT_32 = 0xFFFFFFFF +MAX_UINT_16 = 0xFFFF + +RTP_PACKET_TYPE_VOICE = 120 + + +def decode(data: bytes) -> Packet: + if not data[0] >> 6 == 2: + raise ValueError(f"Invalid packet header 0b{data[0]:0>8b}") + return _rtcp_map.get(data[1], RTPPacket)(data) + + +class FakePacket(Packet): + data = b"" + decrypted_data: bytes = b"" + extension_data: dict = {} + + def __init__( + self, + ssrc: int, + sequence: int, + timestamp: int, + ) -> None: + self.ssrc = ssrc + self.sequence = sequence + self.timestamp = timestamp + + def __bool__(self) -> Literal[False]: + return False + + +class SilencePacket(Packet): + decrypted_data: Final = OPUS_SILENCE + extension_data: Final[dict[int, Any]] = {} + sequence: int = -1 + + def __init__(self, ssrc: int, timestamp: int) -> None: + self.ssrc = ssrc + self.timestamp = timestamp + + def is_silence(self) -> bool: + return True + + +class RTPPacket(Packet): + """Represents an RTP packet. + + .. versionadded:: 2.7 + + Attributes + ---------- + data: :class:`bytes` + The raw data of the packet. + """ + + _hstruct = struct.Struct(">xxHII") + _ext_header = namedtuple("Extension", "profile length values") + _ext_magic = b"\xbe\xde" + + def __init__(self, data: bytes) -> None: + super().__init__(data) + + self.version: int = data[0] >> 6 + self.padding: bool = bool(data[0] & 0b00100000) + self.extended: bool = bool(data[0] & 0b00010000) + self.cc: int = data[0] & 0b00001111 + + self.marker: bool = bool(data[1] & 0b10000000) + self.payload: int = data[1] & 0b01111111 + + sequence, timestamp, ssrc = self._hstruct.unpack_from(data) + self.sequence = sequence + self.timestamp = timestamp + self.ssrc = ssrc + + self.csrcs: tuple[int, ...] = () + self.extension = None + self.extension_data: dict[int, bytes] = {} + + self.header = data[:12] + self.data = data[12:] + self.decrypted_data: bytes | None = None + + self.nonce: bytes = b"" + self._rtpsize: bool = False + + if self.cc: + fmt = ">%sI" % self.cc + offset = struct.calcsize(fmt) + 12 + self.csrcs = struct.unpack(fmt, data[12:offset]) + self.data = data[offset:] + + def adjust_rtpsize(self) -> None: + """Automatically adjusts this packet header and data based on the rtpsize format.""" + + self._rtpsize = True + self.nonce = self.data[-4:] + + if not self.extended: + self.data = self.data[:-4] + return + + self.header += self.data[:4] + self.data = self.data[4:-4] + + def update_extended_header(self, data: bytes) -> int: + """Updates the extended header using ``data`` and returns the pd offset.""" + + if not self.extended: + return 0 + + if self._rtpsize: + data = self.header[-4:] + data + + profile, length = struct.unpack_from(">2sH", data) + + if profile == self._ext_magic: + self._parse_bede_header(data, length) + + values = struct.unpack(">%sI" % length, data[4 : 4 + length * 4]) + self.extension = self._ext_header(profile, length, values) + + offset = 4 + length * 4 + if self._rtpsize: + offset -= 4 + + return offset + + def _parse_bede_header(self, data: bytes, length: int) -> None: + offset = 4 + n = 0 + + while n < length: + next_byte = data[offset : offset + 1] + + if next_byte == b"\x00": + offset += 1 + continue + + header = struct.unpack(">B", next_byte)[0] + el_id = header >> 4 + el_len = 1 + (header & 0b0000_1111) + + self.extension_data[el_id] = data[offset + 1 : offset + 1 + el_len] + offset += 1 + el_len + n += 1 + + def __repr__(self) -> str: + return ( + "" + ) + + +class RTCPPacket(Packet): + _header = struct.Struct(">BBH") + _ssrc_fmt = struct.Struct(">I") + type = None + + def __init__(self, data: bytes) -> None: + super().__init__(data) + self.length: int + head, _, self.length = self._header.unpack_from(data) + + self.version: int = head >> 6 + self.padding: bool = bool(head & 0b00100000) + setattr(self, "report_count", head & 0b00011111) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} version={self.version} padding={self.padding} length={self.length}>" + + @classmethod + def from_data(cls, data: bytes) -> Packet: + _, ptype, _ = cls._header.unpack_from(data) + return _rtcp_map[ptype](data) + + +def _parse_low(x: int, bitlen: int = 32) -> float: + return x / 2.0**bitlen + + +def _to_low(x: float, bitlen: int = 32) -> int: + return int(x * 2.0**bitlen) + + +class SenderReportPacket(RTCPPacket): + _info_fmt = struct.Struct(">5I") + _report_fmt = struct.Struct(">IB3x4I") + _24bit_int_fmt = struct.Struct(">4xI") + _info = namedtuple("RRSenderInfo", "ntp_ts rtp_ts packet_count octet_count") + _report = namedtuple( + "RReport", "ssrc perc_loss total_lost last_seq jitter lsr dlsr" + ) + type = 200 + + if TYPE_CHECKING: + report_count: int + + def __init__(self, data: bytes) -> None: + super().__init__(data) + + self.ssrc = self._ssrc_fmt.unpack_from(data, 4)[0] + self.info = self._read_sender_info(data, 8) + + _report = self._report + reports: list[_report] = [] + for x in range(self.report_count): + offset = 28 + 24 * x + reports.append(self._read_report(data, offset)) + + self.reports: tuple[_report, ...] = tuple(reports) + self.extension = None + if len(data) > 28 + 24 * self.report_count: + self.extension = data[28 + 24 * self.report_count :] + + def _read_sender_info(self, data: bytes, offset: int) -> _info: + nhigh, nlow, rtp_ts, pcount, ocount = self._info_fmt.unpack_from(data, offset) + ntotal = nhigh + _parse_low(nlow) + return self._info(ntotal, rtp_ts, pcount, ocount) + + def _read_report(self, data: bytes, offset: int) -> _report: + ssrc, flost, seq, jit, lsr, dlsr = self._report_fmt.unpack_from(data, offset) + clost = self._24bit_int_fmt.unpack_from(data, offset)[0] & 0xFFFFFF + return self._report(ssrc, flost, clost, seq, jit, lsr, dlsr) + + +class ReceiverReportPacket(RTCPPacket): + _report_fmt = struct.Struct(">IB3x4I") + _24bit_int_fmt = struct.Struct(">4xI") + _report = namedtuple( + "RReport", "ssrc perc_loss total_loss last_seq jitter lsr dlsr" + ) + type = 201 + + reports: tuple[_report, ...] + + if TYPE_CHECKING: + report_count: int + + def __init__(self, data: bytes) -> None: + super().__init__(data) + self.ssrc: int = self._ssrc_fmt.unpack_from(data, 4)[0] + + _report = self._report + reports: list[_report] = [] + for x in range(self.report_count): + offset = 8 + 24 * x + reports.append(self._read_report(data, offset)) + + self.reports = tuple(reports) + + self.extension: bytes | None = None + if len(data) > 8 + 24 * self.report_count: + self.extension = data[8 + 24 * self.report_count :] + + def _read_report(self, data: bytes, offset: int) -> _report: + ssrc, flost, seq, jit, lsr, dlsr = self._report_fmt.unpack_from(data, offset) + clost = self._24bit_int_fmt.unpack_from(data, offset)[0] & 0xFFFFFF + return self._report(ssrc, flost, clost, seq, jit, lsr, dlsr) + + +_rtcp_map = { + 200: SenderReportPacket, + 201: ReceiverReportPacket, +} diff --git a/discord/voice/receive/__init__.py b/discord/voice/receive/__init__.py new file mode 100644 index 0000000000..b8d4662310 --- /dev/null +++ b/discord/voice/receive/__init__.py @@ -0,0 +1,2 @@ +from .reader import AudioReader +from .router import PacketRouter, SinkEventRouter diff --git a/discord/voice/receive/reader.py b/discord/voice/receive/reader.py new file mode 100644 index 0000000000..0329b5371a --- /dev/null +++ b/discord/voice/receive/reader.py @@ -0,0 +1,494 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import logging +import threading +import time +from collections.abc import Callable +from operator import itemgetter +from typing import TYPE_CHECKING, Any, Literal + +import davey + +from discord.opus import PacketDecoder + +from ..packets.rtp import ReceiverReportPacket, decode +from .router import PacketRouter, SinkEventRouter + +try: + import nacl.secret + from nacl.exceptions import CryptoError +except ImportError as exc: + raise RuntimeError( + "can't use voice receiver without PyNaCl installed, please install it with the 'py-cord[voice]' extra." + ) from exc + + +if TYPE_CHECKING: + from discord.member import Member + from discord.sinks import Sink + from discord.types.voice import SupportedModes + + from ..client import VoiceClient + from ..packets import RTPPacket + + AfterCallback = Callable[[Exception | None], Any] + DecryptRTP = Callable[[RTPPacket], bytes] + DecryptRTCP = Callable[[bytes], bytes] + SpeakingEvent = Literal["member_speaking_start", "member_speaking_stop"] + EncryptionBox = nacl.secret.SecretBox | nacl.secret.Aead + +_log = logging.getLogger(__name__) + +__all__ = ("AudioReader",) + + +def is_rtcp(data: bytes) -> bool: + return 200 <= data[1] <= 204 + + +class AudioReader: + def __init__( + self, sink: Sink, client: VoiceClient, *, after: AfterCallback | None = None + ) -> None: + if after is not None and not callable(after): + raise TypeError( + f"expected a callable for the 'after' parameter, got {after.__class__.__name__!r} instead" + ) + + self.sink: Sink = sink + self.client: VoiceClient = client + self.after: AfterCallback | None = after + + self.sink._client = client + + self.active: bool = False + self.error: Exception | None = None + self.packet_router: PacketRouter = PacketRouter(self.sink, self) + self.event_router: SinkEventRouter = SinkEventRouter(self.sink, self) + self.decryptor: PacketDecryptor = PacketDecryptor( + client.mode, bytes(client.secret_key), client + ) + self.speaking_timer: SpeakingTimer = SpeakingTimer(self) + self.keep_alive: UDPKeepAlive = UDPKeepAlive(client) + + def is_listening(self) -> bool: + return self.active + + def update_secret_key(self, secret_key: bytes) -> None: + self.decryptor.update_secret_key(secret_key) + + def start(self) -> None: + if self.active: + _log.debug("Reader is already running", exc_info=True) + return + + self.speaking_timer.start() + self.event_router.start() + self.packet_router.start() + self.client._connection.add_socket_listener(self.callback) + self.keep_alive.start() + self.active = True + + def stop(self) -> None: + if not self.active: + _log.debug("Reader is not active", exc_info=True) + return + + self.client._connection.remove_socket_listener(self.callback) + self.active = False + self.speaking_timer.notify() + + threading.Thread( + target=self._stop, name=f"voice-receiver-audio-reader-stop:{id(self):#x}" + ).start() + + def _stop(self) -> None: + try: + self.packet_router.stop() + except Exception as exc: + self.error = exc + _log.exception("An error ocurred while stopping packet router.") + + try: + self.event_router.stop() + except Exception as exc: + self.error = exc + _log.exception("An error ocurred while stopping event router.") + + self.speaking_timer.stop() + self.keep_alive.stop() + + if self.after: + try: + self.after(self.error) + except Exception: + _log.exception( + "An error ocurred while calling the after callback on audio reader" + ) + + for sink in self.sink.root.walk_children(with_self=True): + try: + sink.cleanup() + except Exception as exc: + _log.exception("Error calling cleanup() for %s", sink, exc_info=exc) + + def set_sink(self, sink: Sink) -> Sink: + old_sink = self.sink + old_sink._client = None + sink._client = self.client + self.packet_router.set_sink(sink) + self.sink = sink + return old_sink + + def _is_ip_discovery_packet(self, data: bytes) -> bool: + return len(data) == 74 and data[1] == 0x02 + + def callback(self, packet_data: bytes) -> None: + + packet = rtp_packet = rtcp_packet = None + + try: + if not is_rtcp(packet_data): + packet = rtp_packet = decode(packet_data) + packet.decrypted_data = self.decryptor.decrypt_rtp(packet) # type: ignore + else: + packet = rtcp_packet = decode(packet_data) + + if not isinstance(packet, ReceiverReportPacket): + _log.info( + "Received unexpected rtcp packet type=%s, %s", + packet.type, + type(packet), + ) + except CryptoError as exc: + _log.error("CryptoError while decoding a voice packet", exc_info=exc) + return + except Exception as exc: + if self._is_ip_discovery_packet(packet_data): + _log.debug("Received an IP Discovery Packet, ignoring...") + return + _log.exception( + "An exception ocurred while decoding voice packets", exc_info=exc + ) + finally: + if self.error: + self.stop() + return + if not packet: + return + + if rtcp_packet: + self.packet_router.feed_rtcp(rtcp_packet) # type: ignore + elif rtp_packet: + ssrc = rtp_packet.ssrc + + if ssrc not in self.client._connection.user_ssrc_map: + if rtp_packet.is_silence(): + return + else: + _log.info( + "Received a packet for unknown SSRC %s: %s", ssrc, rtp_packet + ) + + self.speaking_timer.notify(ssrc) + + try: + self.packet_router.feed_rtp(rtp_packet) # type: ignore + except Exception as exc: + _log.exception( + "An error ocurred while processing RTP packet %s", rtp_packet + ) + self.error = exc + self.stop() + + +class PacketDecryptor: + supported_modes: list[SupportedModes] = [ + "aead_xchacha20_poly1305_rtpsize", + "xsalsa20_poly1305", + "xsalsa20_poly1305_lite", + "xsalsa20_poly1305_suffix", + ] + + def __init__( + self, mode: SupportedModes, secret_key: bytes, client: VoiceClient + ) -> None: + self.mode: SupportedModes = mode + self.client: VoiceClient = client + + try: + self._decryptor_rtp: DecryptRTP = getattr(self, "_decrypt_rtp_" + mode) + self._decryptor_rtcp: DecryptRTCP = getattr(self, "_decrypt_rtcp_" + mode) + except AttributeError as exc: + raise NotImplementedError(mode) from exc + + self.box: EncryptionBox = self._make_box(secret_key) + + def _make_box(self, secret_key: bytes) -> EncryptionBox: + if self.mode.startswith("aead"): + return nacl.secret.Aead(secret_key) + else: + return nacl.secret.SecretBox(secret_key) + + def decrypt_rtp(self, packet: RTPPacket) -> bytes: + state = self.client._connection + dave = state.dave_session + data = self._decryptor_rtp(packet) + + if dave is not None and dave.ready and packet.ssrc in state.user_ssrc_map: + return dave.decrypt( + state.user_ssrc_map[packet.ssrc], davey.MediaType.audio, data + ) + return data + + def decrypt_rtcp(self, packet: bytes) -> bytes: + data = self._decryptor_rtcp(packet) + # TODO: guess how to get the SSRC so we can use dave + return data + + def update_secret_key(self, secret_key: bytes) -> None: + self.box = self._make_box(secret_key) + + def _decrypt_rtp_xsalsa20_poly1305(self, packet: RTPPacket) -> bytes: + nonce = bytearray(24) + nonce[:12] = packet.header + result = self.box.decrypt(bytes(packet.data), bytes(nonce)) + + if packet.extended: + offset = packet.update_extended_header(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_xsalsa20_poly1305(self, data: bytes) -> bytes: + nonce = bytearray(24) + nonce[:8] = data[:8] + result = self.box.decrypt(data[8:], bytes(nonce)) + + return data[:8] + result + + def _decrypt_rtp_xsalsa20_poly1305_suffix(self, packet: RTPPacket) -> bytes: + nonce = packet.data[-24:] + voice_data = packet.data[:-24] + result = self.box.decrypt(bytes(voice_data), bytes(nonce)) + + if packet.extended: + offset = packet.update_extended_header(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes: + nonce = data[-24:] + header = data[:8] + result = self.box.decrypt(data[8:-24], nonce) + + return header + result + + def _decrypt_rtp_xsalsa20_poly1305_lite(self, packet: RTPPacket) -> bytes: + nonce = bytearray(24) + nonce[:4] = packet.data[-4:] + voice_data = packet.data[:-4] + result = self.box.decrypt(bytes(voice_data), bytes(nonce)) + + if packet.extended: + offset = packet.update_extended_header(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: + nonce = bytearray(24) + nonce[:4] = data[-4:] + header = data[:8] + result = self.box.decrypt(data[8:-4], bytes(nonce)) + + return header + result + + def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, packet: RTPPacket) -> bytes: + packet.adjust_rtpsize() + + nonce = bytearray(24) + nonce[:4] = packet.nonce + voice_data = packet.data + + # Blob vomit + assert isinstance(self.box, nacl.secret.Aead) + result = self.box.decrypt(bytes(voice_data), bytes(packet.header), bytes(nonce)) + + if packet.extended: + offset = packet.update_extended_header(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: + nonce = bytearray(24) + nonce[:4] = data[-4:] + header = data[:8] + + assert isinstance(self.box, nacl.secret.Aead) + result = self.box.decrypt(data[8:-4], bytes(header), bytes(nonce)) + + return header + result + + +class SpeakingTimer(threading.Thread): + def __init__(self, reader: AudioReader) -> None: + super().__init__( + daemon=True, + name=f"voice-receiver-speaking-timer:{id(self):#x}", + ) + + self.reader: AudioReader = reader + self.client: VoiceClient = reader.client + self.speaking_timeout_delay: float = 0.2 + self.last_speaking_state: dict[int, bool] = {} + self.speaking_cache: dict[int, float] = {} + self.speaking_timer_event: threading.Event = threading.Event() + self._end_thread: threading.Event = threading.Event() + + def _lookup_member(self, ssrc: int) -> Member | None: + id = self.client._connection.user_ssrc_map.get(ssrc) + if not self.client.guild: + return None + return self.client.guild.get_member(id) if id else None + + def maybe_dispatch_speaking_start(self, ssrc: int) -> None: + tlast = self.speaking_cache.get(ssrc) + if tlast is None or tlast + self.speaking_timeout_delay < time.perf_counter(): + self.dispatch("member_speaking_start", ssrc) + + def dispatch(self, event: SpeakingEvent, ssrc: int) -> None: + member = self._lookup_member(ssrc) + if not member: + return None + self.client._dispatch_sink(event, member) + + def notify(self, ssrc: int | None = None) -> None: + if ssrc is not None: + self.last_speaking_state[ssrc] = True + self.maybe_dispatch_speaking_start(ssrc) + self.speaking_cache[ssrc] = time.perf_counter() + + self.speaking_timer_event.set() + self.speaking_timer_event.clear() + + def drop_ssrc(self, ssrc: int) -> None: + self.speaking_cache.pop(ssrc, None) + state = self.last_speaking_state.pop(ssrc, None) + if state: + self.dispatch("member_speaking_stop", ssrc) + self.notify() + + def get_speaking(self, ssrc: int) -> bool | None: + return self.last_speaking_state.get(ssrc) + + def stop(self) -> None: + self._end_thread.set() + self.notify() + + def run(self) -> None: + _i1 = itemgetter(1) + + def get_next_entry(): + cache = sorted(self.speaking_cache.items(), key=_i1) + for ssrc, tlast in cache: + if self.last_speaking_state.get(ssrc): + return ssrc, tlast + return None, None + + self.speaking_timer_event.wait() + while not self._end_thread.is_set(): + if not self.speaking_cache: + self.speaking_timer_event.wait() + + tnow = time.perf_counter() + ssrc, tlast = get_next_entry() + + if ssrc is None or tlast is None: + self.speaking_timer_event.wait() + continue + + self.speaking_timer_event.wait(tlast + self.speaking_timeout_delay - tnow) + + if time.perf_counter() < tlast + self.speaking_timeout_delay: + continue + + self.dispatch("member_speaking_stop", ssrc) + self.last_speaking_state[ssrc] = False + + +class UDPKeepAlive(threading.Thread): + delay: int = 5000 + + def __init__(self, client: VoiceClient) -> None: + super().__init__( + daemon=True, + name=f"voice-receiver-udp-keep-alive:{id(self):#x}", + ) + + self.client: VoiceClient = client + self.last_time: float = 0 + self.counter: int = 0 + self._end_thread: threading.Event = threading.Event() + + def run(self) -> None: + self.client.wait_until_connected() + + while not self._end_thread.is_set(): + vc = self.client + + try: + packet = self.counter.to_bytes(8, "big") + except OverflowError: + self.counter = 0 + continue + + try: + vc._connection.socket.sendto( + packet, (vc._connection.endpoint_ip, vc._connection.voice_port) + ) + except Exception as exc: + _log.debug( + "Error while sending udp keep alive to socket %s at %s:%s", + vc._connection.socket, + vc._connection.endpoint_ip, + vc._connection.voice_port, + exc_info=exc, + ) + vc.wait_until_connected() + if vc.is_connected(): + continue + break + else: + self.counter += 1 + time.sleep(self.delay) + + def stop(self) -> None: + self._end_thread.set() diff --git a/discord/voice/receive/router.py b/discord/voice/receive/router.py new file mode 100644 index 0000000000..0d2445540b --- /dev/null +++ b/discord/voice/receive/router.py @@ -0,0 +1,231 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import logging +import queue +import threading +from collections import deque +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from discord.opus import PacketDecoder + +from ..utils.multidataevent import MultiDataEvent + +if TYPE_CHECKING: + from discord.sinks import Sink + + from ..packets import RTCPPacket, RTPPacket + from .reader import AudioReader + + EventCB = Callable[..., Any] + EventData = tuple[str, tuple[Any, ...], dict[str, Any]] + +_log = logging.getLogger(__name__) + + +class PacketRouter(threading.Thread): + def __init__(self, sink: Sink, reader: AudioReader) -> None: + super().__init__( + daemon=True, + name=f"voice-receiver-packet-router:{id(self):#x}", + ) + + self.sink: Sink = sink + self.decoders: dict[int, PacketDecoder] = {} + self.reader: AudioReader = reader + self.waiter: MultiDataEvent[PacketDecoder] = MultiDataEvent() + + self._lock: threading.RLock = threading.RLock() + self._end_thread: threading.Event = threading.Event() + self._dropped_ssrcs: deque[int] = deque(maxlen=16) + + def feed_rtp(self, packet: RTPPacket) -> None: + if packet.ssrc in self._dropped_ssrcs: + _log.debug("Ignoring packet from dropped ssrc %s", packet.ssrc) + + with self._lock: + decoder = self.get_decoder(packet.ssrc) + if decoder is not None: + decoder.push_packet(packet) + + def feed_rtcp(self, packet: RTCPPacket) -> None: + guild = self.sink.client.guild if self.sink.client else None + event_router = self.reader.event_router + event_router.dispatch("rtcp_packet", packet, guild) + + def get_decoder(self, ssrc: int) -> PacketDecoder | None: + with self._lock: + decoder = self.decoders.get(ssrc) + if decoder is None: + decoder = self.decoders[ssrc] = PacketDecoder(self, ssrc) + return decoder + + def set_sink(self, sink: Sink) -> None: + with self._lock: + self.sink = sink + + def set_user_id(self, ssrc: int, user_id: int) -> None: + with self._lock: + if ssrc in self._dropped_ssrcs: + self._dropped_ssrcs.remove(ssrc) + + decoder = self.decoders.get(ssrc) + if decoder is not None: + decoder.set_user_id(user_id) + + def destroy_decoder(self, ssrc: int) -> None: + with self._lock: + decoder = self.decoders.pop(ssrc, None) + if decoder is not None: + self._dropped_ssrcs.append(ssrc) + decoder.destroy() + + def destroy_all_decoders(self) -> None: + with self._lock: + for ssrc in self.decoders.keys(): + self.destroy_decoder(ssrc) + + def stop(self) -> None: + self._end_thread.set() + self.waiter.notify() + + def run(self) -> None: + try: + self._do_run() + except Exception as exc: + _log.exception("Error in %s loop", self) + self.reader.error = exc + finally: + self.reader.client.stop_recording() + self.waiter.clear() + + def _do_run(self) -> None: + while not self._end_thread.is_set(): + self.waiter.wait() + + with self._lock: + for decoder in self.waiter.items: + data = decoder.pop_data() + if data is not None: + self.sink.write(data.source, data) + + +class SinkEventRouter(threading.Thread): + def __init__(self, sink: Sink, reader: AudioReader) -> None: + super().__init__( + daemon=True, name=f"voice-receiver-sink-event-router:{id(self):#x}" + ) + + self.sink: Sink = sink + self.reader: AudioReader = reader + + self._event_listeners: dict[str, list[EventCB]] = {} + self._buffer: queue.SimpleQueue[EventData] = queue.SimpleQueue() + self._lock = threading.RLock() + self._end_thread = threading.Event() + + self.register_events() + + def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None: + _log.debug("Dispatch voice event %s", event) + self._buffer.put_nowait((event, args, kwargs)) + + def set_sink(self, sink: Sink) -> None: + with self._lock: + self.unregister_events() + self.sink = sink + self.register_events() + + def register_events(self) -> None: + with self._lock: + self._register_listeners(self.sink) + for child in self.sink.walk_children(): + self._register_listeners(child) + + def unregister_events(self) -> None: + with self._lock: + self._unregister_listeners(self.sink) + for child in self.sink.walk_children(): + self._unregister_listeners(child) + + def _register_listeners(self, sink: Sink) -> None: + _log.debug("Registering events for %s: %s", sink, sink.__sink_listeners__) + + for name, method_name in sink.__sink_listeners__: + func = getattr(sink, method_name) + _log.debug("Registering event: %r (callback at %r)", name, method_name) + + if name in self._event_listeners: + self._event_listeners[name].append(func) + else: + self._event_listeners[name] = [func] + + def _unregister_listeners(self, sink: Sink) -> None: + for name, method_name in sink.__sink_listeners__: + func = getattr(sink, method_name) + + if name in self._event_listeners: + try: + self._event_listeners[name].remove(func) + except ValueError: + pass + + def _dispatch_to_listeners(self, event: str, *args: Any, **kwargs: Any) -> None: + for listener in self._event_listeners.get(f"on_{event}", []): + try: + listener(*args, **kwargs) + except Exception as exc: + _log.exception( + "Unhandled exception while dispatching event %s (args: %s; kwargs: %s)", + event, + args, + kwargs, + exc_info=exc, + ) + + def stop(self) -> None: + self._end_thread.set() + + def run(self) -> None: + try: + self._do_run() + except Exception as exc: + _log.exception("Error in sink event router", exc_info=exc) + self.reader.error = exc + self.reader.client.stop_listening() + + def _do_run(self) -> None: + while not self._end_thread.is_set(): + try: + event, args, kwargs = self._buffer.get(timeout=0.5) + except queue.Empty: + continue + else: + with self._lock: + with self.reader.packet_router._lock: + self._dispatch_to_listeners(event, *args, **kwargs) diff --git a/discord/voice/state.py b/discord/voice/state.py new file mode 100644 index 0000000000..9b17e8e5e8 --- /dev/null +++ b/discord/voice/state.py @@ -0,0 +1,980 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import logging +import select +import socket +import threading +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any + +from discord import opus, utils +from discord.backoff import ExponentialBackoff +from discord.enums import SpeakingState, try_enum +from discord.errors import ConnectionClosed +from discord.object import Object +from discord.sinks import RawData, Sink + +try: + import davey +except ImportError: + import warnings + + warnings.warn_explicit() + +from .enums import ConnectionFlowState, OpCodes +from .gateway import VoiceWebSocket + +if TYPE_CHECKING: + from discord import abc + from discord.guild import Guild + from discord.member import VoiceState + from discord.raw_models import RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent + from discord.state import ConnectionState + from discord.types.voice import SupportedModes + from discord.user import ClientUser + + from .client import VoiceClient + +MISSING = utils.MISSING +SocketReaderCallback = Callable[[bytes], Any] +_log = logging.getLogger(__name__) +_recv_log = logging.getLogger("discord.voice.receiver") +DAVE_PROTOCOL_VERSION = davey.DAVE_PROTOCOL_VERSION + + +class SocketReader(threading.Thread): + def __init__( + self, + state: VoiceConnectionState, + name: str, + buffer_size: int, + *, + start_paused: bool = True, + ) -> None: + super().__init__( + daemon=True, + name=name, + ) + + self.buffer_size: int = buffer_size + self.state: VoiceConnectionState = state + self.start_paused: bool = start_paused + self._callbacks: list[SocketReaderCallback] = [] + self._running: threading.Event = threading.Event() + self._end: threading.Event = threading.Event() + self._idle_paused: bool = True + self._started: threading.Event = threading.Event() + self._warned_wait: bool = False + + def is_running(self) -> bool: + return self._started.is_set() + + def register(self, callback: SocketReaderCallback) -> None: + self._callbacks.append(callback) + if self._idle_paused: + self._idle_paused = False + self._running.set() + + def unregister(self, callback: SocketReaderCallback) -> None: + try: + self._callbacks.remove(callback) + except ValueError: + pass + else: + if not self._callbacks and self._running.is_set(): + self._idle_paused = True + self._running.clear() + + def pause(self) -> None: + self._idle_paused = False + self._running.clear() + + def is_paused(self) -> bool: + return self._idle_paused or ( + not self._running.is_set() and not self._end.is_set() + ) + + def resume(self, *, force: bool = False) -> None: + if self._running.is_set(): + return + + if not force and not self._callbacks: + self._idle_paused = True + return + + self._idle_paused = False + self._running.set() + + def stop(self) -> None: + self._started.clear() + self._end.set() + self._running.set() + + def run(self) -> None: + self._started.set() + self._end.clear() + self._running.set() + + if self.start_paused: + self.pause() + + try: + self._do_run() + except Exception: + _log.exception( + "An error ocurred while running the socket reader %s", + self.name, + ) + finally: + self.stop() + self._started.clear() + self._running.clear() + self._callbacks.clear() + + def _do_run(self) -> None: + while not self._end.is_set(): + if not self._running.is_set(): + if not self._warned_wait: + _log.warning( + "Socket reader %s is waiting to be set as running", self.name + ) + self._warned_wait = True + self._running.wait() + continue + + if self._warned_wait: + _log.info("Socket reader %s was set as running", self.name) + self._warned_wait = False + + try: + readable, _, _ = select.select([self.state.socket], [], [], 30) + except (ValueError, TypeError, OSError) as e: + _log.debug( + "Select error handling socket in reader, this should be safe to ignore: %s: %s", + e.__class__.__name__, + e, + ) + continue + + if not readable: + continue + + try: + data = self.state.socket.recv(self.buffer_size) + except OSError: + _log.debug( + "Error reading from socket in %s, this should be safe to ignore", + self, + exc_info=True, + ) + else: + for cb in self._callbacks: + try: + task = asyncio.ensure_future( + self.state.loop.create_task( + utils.maybe_coroutine(cb, data) + ), + loop=self.state.loop, + ) + self.state._dispatch_task_set.add(task) + task.add_done_callback(self.state._dispatch_task_set.discard) + except Exception: + _log.exception( + "Error while calling %s in %s", + cb, + self, + ) + + +class SocketEventReader(SocketReader): + def __init__( + self, state: VoiceConnectionState, *, start_paused: bool = True + ) -> None: + super().__init__( + state, + f"voice-socket-event-reader:{id(self):#x}", + 2048, + start_paused=start_paused, + ) + + +class VoiceConnectionState: + def __init__( + self, + client: VoiceClient, + *, + hook: ( + Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None + ) = None, + ) -> None: + self.client: VoiceClient = client + self.hook = hook + self.loop: asyncio.AbstractEventLoop = client.loop + + self.timeout: float = 30.0 + self.reconnect: bool = True + self.self_deaf: bool = False + self.self_mute: bool = False + self.endpoint: str | None = None + self.endpoint_ip: str | None = None + self.server_id: int | None = None + self.ip: str | None = None + self.port: int | None = None + self.voice_port: int | None = None + self.secret_key: list[int] = MISSING + self.ssrc: int = MISSING + self.mode: SupportedModes = MISSING + self.socket: socket.socket = MISSING + self.ws: VoiceWebSocket = MISSING + self.session_id: str | None = None + self.token: str | None = None + + self._connection: ConnectionState = client._state + self._state: ConnectionFlowState = ConnectionFlowState.disconnected + self._expecting_disconnect: bool = False + self._connected = threading.Event() + self._state_event = asyncio.Event() + self._disconnected = asyncio.Event() + self._runner: asyncio.Task[None] | None = None + self._connector: asyncio.Task[None] | None = None + self._socket_reader = SocketEventReader(self) + self._socket_reader.start() + self.recording_done_callbacks: list[ + tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]] + ] = [] + self._dispatch_task_set: set[asyncio.Task] = set() + + if not self._connection.self_id: + raise RuntimeError("client self ID is not set") + if not self.channel_id: + raise RuntimeError("client channel being connected to is not set") + + self.dave_session: davey.DaveSession | None = None + self.dave_protocol_version: int = 0 + self.dave_pending_transition: dict[str, int] | None = None + self.downgraded_dave = False + + @property + def user_ssrc_map(self) -> dict[int, int]: + return self.client._id_to_ssrc + + @property + def max_dave_proto_version(self) -> int: + return davey.DAVE_PROTOCOL_VERSION + + @property + def state(self) -> ConnectionFlowState: + return self._state + + @state.setter + def state(self, state: ConnectionFlowState) -> None: + if state is not self._state: + _log.debug("State changed from %s to %s", self._state.name, state.name) + + self._state = state + self._state_event.set() + self._state_event.clear() + + if state is ConnectionFlowState.connected: + self._connected.set() + else: + self._connected.clear() + + @property + def guild(self) -> Guild: + return self.client.guild + + @property + def user(self) -> ClientUser: + return self.client.user + + @property + def channel_id(self) -> int | None: + return self.client.channel and self.client.channel.id + + @property + def guild_id(self) -> int: + return self.guild.id + + @property + def supported_modes(self) -> tuple[SupportedModes, ...]: + return self.client.supported_modes + + @property + def self_voice_state(self) -> VoiceState | None: + return self.guild.me.voice + + def is_connected(self) -> bool: + return self.state is ConnectionFlowState.connected + + def _inside_runner(self) -> bool: + return self._runner is not None and asyncio.current_task() == self._runner + + async def voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: + channel_id = data.channel_id + + if channel_id is None: + self._disconnected.set() + + if self._expecting_disconnect: + self._expecting_disconnect = False + else: + _log.debug("We have been disconnected from voice") + await self.disconnect() + return + + self.session_id = data["session_id"] + + if self.state in ( + ConnectionFlowState.set_guild_voice_state, + ConnectionFlowState.got_voice_server_update, + ): + if self.state is ConnectionFlowState.set_guild_voice_state: + self.state = ConnectionFlowState.got_voice_state_update + + if channel_id != self.client.channel.id: + # moved from channel + self._update_voice_channel(channel_id) + else: + self.state = ConnectionFlowState.got_both_voice_updates + return + + if self.state is ConnectionFlowState.connected: + self._update_voice_channel(channel_id) + + elif self.state is not ConnectionFlowState.disconnected: + if channel_id != self.client.channel.id: + _log.info("We were moved from the channel while connecting...") + + self._update_voice_channel(channel_id) + await self.soft_disconnect( + with_state=ConnectionFlowState.got_voice_state_update + ) + await self.connect( + reconnect=self.reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + wait=False, + ) + else: + _log.debug("Ignoring unexpected VOICE_STATEUPDATE event") + + async def voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: + previous_token = self.token + previous_server_id = self.server_id + previous_endpoint = self.endpoint + + self.token = data.token + self.server_id = data.guild_id + endpoint = data.endpoint + + if self.token is None or endpoint is None: + _log.warning( + "Awaiting endpoint... This requires waiting. " + "If timeout occurred considering raising the timeout and reconnecting." + ) + return + + # strip the prefix off since we add it later + self.endpoint = endpoint.removeprefix("wss://") + + if self.state in ( + ConnectionFlowState.set_guild_voice_state, + ConnectionFlowState.got_voice_state_update, + ): + self.endpoint_ip = MISSING + self._create_socket() + + if self.state is ConnectionFlowState.set_guild_voice_state: + self.state = ConnectionFlowState.got_voice_server_update + else: + self.state = ConnectionFlowState.got_both_voice_updates + + elif self.state is ConnectionFlowState.connected: + _log.debug("Voice server update, closing old voice websocket") + await self.ws.close(4014) # 4014 = main gw dropped + self.state = ConnectionFlowState.got_voice_server_update + + elif self.state is not ConnectionFlowState.disconnected: + if ( + previous_token == self.token + and previous_server_id == self.server_id + and previous_endpoint == self.endpoint + ): + return + + _log.debug("Unexpected VOICE_SERVER_UPDATE event received, handling...") + + await self.soft_disconnect( + with_state=ConnectionFlowState.got_voice_server_update + ) + await self.connect( + reconnect=self.reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + wait=False, + ) + self._create_socket() + + async def connect( + self, + *, + reconnect: bool, + timeout: float, + self_deaf: bool, + self_mute: bool, + resume: bool, + wait: bool = True, + ) -> None: + if self._connector: + self._connector.cancel() + self._connector = None + + if self._runner: + self._runner.cancel() + self._runner = None + + self.timeout = timeout + self.reconnect = reconnect + self._connector = self.client.loop.create_task( + self._wrap_connect( + reconnect, + timeout, + self_deaf, + self_mute, + resume, + ), + name=f"voice-connector:{id(self):#x}", + ) + + if wait: + await self._connector + + async def _wrap_connect( + self, + reconnect: bool, + timeout: float, + self_deaf: bool, + self_mute: bool, + resume: bool, + ) -> None: + try: + await self._connect(reconnect, timeout, self_deaf, self_mute, resume) + except asyncio.CancelledError: + _log.debug("Cancelling voice connection") + await self.soft_disconnect() + raise + except asyncio.TimeoutError: + _log.info("Timed out while connecting to voice") + await self.disconnect() + raise + except Exception: + _log.exception("Error while connecting to voice... disconnecting") + await self.disconnect() + raise + + async def _inner_connect( + self, reconnect: bool, self_deaf: bool, self_mute: bool, resume: bool + ) -> None: + for i in range(5): + _log.info("Starting voice handshake (connection attempt %s)", i + 1) + + await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute) + if self.state is ConnectionFlowState.disconnected: + self.state = ConnectionFlowState.set_guild_voice_state + + await self._wait_for_state(ConnectionFlowState.got_both_voice_updates) + + _log.info("Voice handshake complete. Endpoint found: %s", self.endpoint) + + try: + self.ws = await self._connect_websocket(resume) + await self._handshake_websocket() + break + except ConnectionClosed: + if reconnect: + wait = 1 + i * 2 + _log.exception( + "Failed to connect to voice... Retrying in %s seconds", wait + ) + await self.disconnect(cleanup=False) + await asyncio.sleep(wait) + continue + else: + await self.disconnect() + raise + + async def _connect( + self, + reconnect: bool, + timeout: float, + self_deaf: bool, + self_mute: bool, + resume: bool, + ) -> None: + _log.info(f"Connecting to voice {self.client.channel.id}") + + await asyncio.wait_for( + self._inner_connect( + reconnect=reconnect, + self_deaf=self_deaf, + self_mute=self_mute, + resume=resume, + ), + timeout=timeout, + ) + _log.info("Voice connection completed") + + if not self._runner: + self._runner = self.client.loop.create_task( + self._poll_ws(reconnect), + name=f"voice-ws-poller:{id(self):#x}", + ) + + async def disconnect( + self, *, force: bool = True, cleanup: bool = True, wait: bool = False + ) -> None: + if not force and not self.is_connected(): + return + + _log.debug( + "Attempting a voice disconnect for channel %s (guild %s)", + self.channel_id, + self.guild_id, + ) + try: + await self._voice_disconnect() + if self.ws: + await self.ws.close() + except Exception: + _log.debug( + "Ignoring exception while disconnecting from voice", exc_info=True + ) + finally: + self.state = ConnectionFlowState.disconnected + self._socket_reader.pause() + + if cleanup: + self._socket_reader.stop() + self.client.stop() + + self._connected.set() + self._connected.clear() + + if self.socket: + self.socket.close() + + self.ip = MISSING + self.port = MISSING + + if wait and not self._inside_runner(): + try: + await asyncio.wait_for( + self._disconnected.wait(), timeout=self.timeout + ) + except TimeoutError: + _log.debug("Timed out waiting for voice disconnect confirmation") + except asyncio.CancelledError: + pass + + if cleanup: + self.client.cleanup() + + async def soft_disconnect( + self, + *, + with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates, + ) -> None: + _log.debug("Soft disconnecting from voice") + + if self._runner: + self._runner.cancel() + self._runner = None + + try: + if self.ws: + await self.ws.close() + except Exception: + _log.debug( + "Ignoring exception while soft disconnecting from voice", exc_info=True + ) + finally: + self.state = with_state + self._socket_reader.pause() + + if self.socket: + self.socket.close() + + self.ip = MISSING + self.port = MISSING + + async def move_to( + self, channel: abc.Snowflake | None, timeout: float | None + ) -> None: + if channel is None: + await self.disconnect(wait=True) + return + + if self.client.channel and channel.id == self.client.channel.id: + return + + previous_state = self.state + await self._move_to(channel) + + last_state = self.state + + try: + await self.wait_for(timeout=timeout) + except asyncio.TimeoutError: + _log.warning( + "Timed out trying to move to channel %s in guild %s", + channel.id, + self.guild.id, + ) + if self.state is last_state: + _log.debug( + "Reverting state %s to previous state: %s", + last_state.name, + previous_state.name, + ) + self.state = previous_state + + def wait_for( + self, + state: ConnectionFlowState = ConnectionFlowState.connected, + timeout: float | None = None, + ) -> Any: + if state is ConnectionFlowState.connected: + return self._connected.wait(timeout) + return self._wait_for_state(state, timeout=timeout) + + def send_packet(self, packet: bytes) -> None: + self.socket.sendall(packet) + + def add_socket_listener(self, callback: SocketReaderCallback) -> None: + _log.debug("Registering a socket listener callback %s", callback) + self._socket_reader.register(callback) + + def remove_socket_listener(self, callback: SocketReaderCallback) -> None: + _log.debug("Unregistering a socket listener callback %s", callback) + self._socket_reader.unregister(callback) + + async def _wait_for_state( + self, + *states: ConnectionFlowState, + timeout: float | None = None, + ) -> None: + if not states: + raise ValueError + + while True: + if self.state in states: + return + + _, pending = await asyncio.wait( + [ + asyncio.ensure_future(self._state_event.wait()), + ], + timeout=timeout, + ) + if pending: + # if we're here, it means that the state event + # has timed out, so just raise the exception + raise asyncio.TimeoutError + + async def _voice_connect( + self, *, self_deaf: bool = False, self_mute: bool = False + ) -> None: + channel = self.client.channel + await channel.guild.change_voice_state( + channel=channel, self_deaf=self_deaf, self_mute=self_mute + ) + + async def _voice_disconnect(self) -> None: + _log.info( + "Terminating voice handshake for channel %s (guild %s)", + self.client.channel.id, + self.client.guild.id, + ) + + self.state = ConnectionFlowState.disconnected + await self.client.channel.guild.change_voice_state(channel=None) + self._expecting_disconnect = True + self._disconnected.clear() + + async def _connect_websocket(self, resume: bool) -> VoiceWebSocket: + seq_ack = -1 + if self.ws is not MISSING: + seq_ack = self.ws.seq_ack + ws = await VoiceWebSocket.from_state( + self, resume=resume, hook=self.hook, seq_ack=seq_ack + ) + self.state = ConnectionFlowState.websocket_connected + return ws + + async def _handshake_websocket(self) -> None: + while not self.ip: + await self.ws.poll_event() + + self.state = ConnectionFlowState.got_ip_discovery + while self.ws.secret_key is None: + await self.ws.poll_event() + + self.state = ConnectionFlowState.connected + + def _create_socket(self) -> None: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.setblocking(False) + self._socket_reader.resume() + + async def _poll_ws(self, reconnect: bool) -> None: + backoff = ExponentialBackoff() + + while True: + try: + await self.ws.poll_event() + except asyncio.CancelledError: + return + except (ConnectionClosed, asyncio.TimeoutError) as exc: + if isinstance(exc, ConnectionClosed): + # 1000 - normal closure - not resumable + # 4014 - externally disconnected - not resumable + # 4015 - voice server crashed - resumable + # 4021 - ratelimited, not reconnect - not resumable + # 4022 - call terminated, similar to 4014 - not resumable + + if exc.code == 1000: + if not self._expecting_disconnect: + _log.info( + "Disconnecting from voice manually, close code %d", + exc.code, + ) + await self.disconnect() + break + elif exc.code in (4014, 4022): + if self._disconnected.is_set(): + _log.info( + "Disconnectinf from voice by Discord, close code %d", + exc.code, + ) + await self.disconnect() + break + + _log.info( + "Disconnecting from voice by force... potentially reconnecting..." + ) + successful = await self._potential_reconnect() + if not successful: + _log.info( + "Reconnect was unsuccessful, disconnecting from voice normally" + ) + if self.state is not ConnectionFlowState.disconnected: + await self.disconnect() + break + else: + # we have successfully resumed so just keep polling events + continue + elif exc.code == 4021: + _log.warning( + "We are being rate limited while attempting to connect to voice. Disconnecting...", + ) + if self.state is not ConnectionFlowState.disconnected: + await self.disconnect() + break + elif exc.code == 4015: + _log.info( + "Disconnected from voice due to a Discord-side issue, attempting to reconnect and resume...", + ) + + try: + await self._connect( + reconnect=reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=True, + ) + except asyncio.TimeoutError: + _log.info( + "Could not resume the voice connection... Disconnecting..." + ) + if self.state is not ConnectionFlowState.disconnected: + await self.disconnect() + break + except Exception: + _log.exception( + "An exception was raised while attempting a reconnect and resume... Disconnecting...", + exc_info=True, + ) + if self.state is not ConnectionFlowState.disconnected: + await self.disconnect() + break + else: + _log.info( + "Successfully reconnected and resume the voice connection" + ) + continue + else: + _log.debug( + "Not handling close code %s (%s)", + exc.code, + exc.reason or "No reason was provided", + ) + + if not reconnect: + await self.disconnect() + raise + + retry = backoff.delay() + _log.exception( + "Disconnected from voice... Reconnecting in %.2fs", + retry, + ) + await asyncio.sleep(retry) + await self.disconnect(cleanup=False) + + try: + await self._connect( + reconnect=reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + ) + except asyncio.TimeoutError: + _log.warning("Could not connect to voice... Retrying...") + continue + + async def _potential_reconnect(self) -> bool: + try: + await self._wait_for_state( + ConnectionFlowState.got_voice_server_update, + ConnectionFlowState.got_both_voice_updates, + ConnectionFlowState.disconnected, + timeout=self.timeout, + ) + except asyncio.TimeoutError: + return False + else: + if self.state is ConnectionFlowState.disconnected: + return False + + previous_ws = self.ws + + try: + self.ws = await self._connect_websocket(False) + await self._handshake_websocket() + except (ConnectionClosed, asyncio.TimeoutError): + return False + else: + return True + finally: + await previous_ws.close() + + async def _move_to(self, channel: abc.Snowflake) -> None: + await self.client.channel.guild.change_voice_state(channel=channel) + self.state = ConnectionFlowState.set_guild_voice_state + + def _update_voice_channel(self, channel_id: int | None) -> None: + self.client.channel = channel_id and self.guild.get_channel(channel_id) # type: ignore + + async def reinit_dave_session(self) -> None: + if self.dave_protocol_version > 0: + if self.dave_session: + self.dave_session.reinit( + self.dave_protocol_version, self.user.id, self.channel_id + ) + else: + self.dave_session = davey.DaveSession( + self.dave_protocol_version, + self.user.id, + self.channel_id, + ) + + await self.ws.send_as_bytes( + OpCodes.mls_key_package, + self.dave_session.get_serialized_key_package(), + ) + elif self.dave_session: + self.dave_session.reset() + self.dave_session.set_passthrough_mode(True, 10) + + async def recover_dave_from_invalid_commit(self, transition: int) -> None: + payload = { + "op": int(OpCodes.mls_invalid_commit_welcome), + "d": {"transition_id": transition}, + } + await self.ws.send_as_json(payload) + await self.reinit_dave_session() + + async def execute_dave_transition(self, transition: int) -> None: + _log.debug("Executing DAVE transition with id %s", transition) + + if not self.dave_pending_transition: + _log.warning( + "Attempted to execute a transition without having a pending transition for id %s, " + "this is a Discord bug.", + transition, + ) + return + + pending_transition = self.dave_pending_transition["transition_id"] + pending_proto = self.dave_pending_transition["protocol_version"] + + session = self.dave_session + + if transition == pending_transition: + old_version = self.dave_protocol_version + self.dave_protocol_version = pending_proto + + if ( + old_version != self.dave_protocol_version + and self.dave_protocol_version == 0 + ): + _log.warning( + "DAVE was downgraded, voice client non-e2ee session has been deprecated since 2.7" + ) + self.downgraded_dave = True + elif transition > 0 and self.downgraded_dave: + self.downgraded_dave = False + if session: + session.set_passthrough_mode(True, 10) + _log.info("Upgraded voice session to use DAVE") + else: + _log.debug( + "Received an execute transition id %s when expected was %s, ignoring", + transition, + pending_proto, + ) + + self.dave_pending_transition = None diff --git a/discord/voice/utils/__init__.py b/discord/voice/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/discord/voice/utils/buffer.py b/discord/voice/utils/buffer.py new file mode 100644 index 0000000000..f12fb0ae6e --- /dev/null +++ b/discord/voice/utils/buffer.py @@ -0,0 +1,215 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import heapq +import logging +import threading +from typing import Protocol, TypeVar + +from ..packets import Packet +from .wrapped import add_wrapped, gap_wrapped + +__all__ = ( + "Buffer", + "JitterBuffer", +) + + +T = TypeVar("T") +PacketT = TypeVar("PacketT", bound=Packet) +_log = logging.getLogger(__name__) + + +class Buffer(Protocol[T]): + def __len__(self) -> int: ... + def push(self, item: T) -> None: ... + def pop(self) -> T | None: ... + def peek(self) -> T | None: ... + def flush(self) -> list[T]: ... + def reset(self) -> None: ... + + +class BaseBuff(Buffer[PacketT]): + def __init__(self) -> None: + self._buffer: list[PacketT] = [] + + def __len__(self) -> int: + return len(self._buffer) + + def push(self, item: PacketT) -> None: + self._buffer.append(item) + + def pop(self) -> PacketT | None: + return self._buffer.pop() + + def peek(self) -> PacketT | None: + return self._buffer[-1] if self._buffer else None + + def flush(self) -> list[PacketT]: + buf = self._buffer.copy() + self._buffer.clear() + return buf + + def reset(self) -> None: + self._buffer.clear() + + +class JitterBuffer(BaseBuff[PacketT]): + _threshold: int = 10000 + + def __init__( + self, max_size: int = 10, *, pref_size: int = 1, prefill: int = 1 + ) -> None: + if max_size < 1: + raise ValueError(f"max_size must be greater than 1, not {max_size}") + + if not 0 <= pref_size <= max_size: + raise ValueError(f"pref_size must be between 0 and max_size ({max_size})") + + self.max_size: int = max_size + self.pref_size: int = pref_size + self.prefill: int = prefill + self._prefill: int = prefill + self._last_tx_seq: int = -1 + self._has_item: threading.Event = threading.Event() + # self._lock: threading.Lock = threading.Lock() + self._buffer: list[Packet] = [] + + def _push(self, packet: Packet) -> None: + heapq.heappush(self._buffer, packet) + + def _pop(self) -> Packet: + return heapq.heappop(self._buffer) + + def _get_packet_if_ready(self) -> Packet | None: + return self._buffer[0] if len(self._buffer) > self.pref_size else None + + def _pop_if_ready(self) -> Packet | None: + return self._pop() if len(self._buffer) > self.pref_size else None + + def _update_has_item(self) -> None: + prefilled = self._prefill == 0 + packet_ready = len(self._buffer) > self.pref_size + + if not prefilled or not packet_ready: + self._has_item.clear() + return + + next_packet = self._buffer[0] + sequential = add_wrapped(self._last_tx_seq, 1) == next_packet.sequence + positive_seq = self._last_tx_seq >= 0 + + if ( + (sequential and positive_seq) + or not positive_seq + or len(self._buffer) >= self.max_size + ): + self._has_item.set() + else: + self._has_item.clear() + + def _cleanup(self) -> None: + while len(self._buffer) > self.max_size: + heapq.heappop(self._buffer) + + def push(self, packet: Packet) -> bool: + seq = packet.sequence + + if ( + gap_wrapped(self._last_tx_seq, seq) > self._threshold + and self._last_tx_seq != -1 + ): + _log.debug("Dropping old packet %s", packet) + return False + + self._push(packet) + + if self._prefill > 0: + self._prefill -= 1 + + self._cleanup() + self._update_has_item() + return True + + def pop(self, *, timeout: float | None = 0) -> Packet | None: + ok = self._has_item.wait(timeout) + if not ok: + return None + + if self._prefill > 0: + return None + + packet = self._pop_if_ready() + + if packet is not None: + self._last_tx_seq = packet.sequence + + self._update_has_item() + return packet + + def peek(self, *, all: bool = False) -> Packet | None: + if not self._buffer: + return None + + if all: + return self._buffer[0] + else: + return self._get_packet_if_ready() + + def peek_next(self) -> Packet | None: + packet = self.peek(all=True) + + if packet is None: + return None + + if ( + packet.sequence == add_wrapped(self._last_tx_seq, 1) + or self._last_tx_seq < 0 + ): + return packet + + def gap(self) -> int: + if self._buffer and self._last_tx_seq > 0: + return gap_wrapped(self._last_tx_seq, self._buffer[0].sequence) + return 0 + + def flush(self) -> list[Packet]: + packets = sorted(self._buffer) + self._buffer.clear() + + if packets: + self._last_tx_seq = packets[-1].sequence + + self._prefill = self.prefill + self._has_item.clear() + return packets + + def reset(self) -> None: + self._buffer.clear() + self._has_item.clear() + self._prefill = self.prefill + self._last_tx_seq = -1 diff --git a/discord/voice/utils/multidataevent.py b/discord/voice/utils/multidataevent.py new file mode 100644 index 0000000000..e0079b5a54 --- /dev/null +++ b/discord/voice/utils/multidataevent.py @@ -0,0 +1,79 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import threading +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class MultiDataEvent(Generic[T]): + """ + Something like the inverse of a Condition. A 1-waiting-on-N type of object, + with accompanying data object for convenience. + """ + + def __init__(self): + self._items: list[T] = [] + self._ready: threading.Event = threading.Event() + + @property + def items(self) -> list[T]: + """A shallow copy of the currently ready objects.""" + return self._items.copy() + + def is_ready(self) -> bool: + return self._ready.is_set() + + def _check_ready(self) -> None: + if self._items: + self._ready.set() + else: + self._ready.clear() + + def notify(self) -> None: + self._ready.set() + self._check_ready() + + def wait(self, timeout: float | None = None) -> bool: + self._check_ready() + return self._ready.wait(timeout) + + def register(self, item: T) -> None: + self._items.append(item) + self._ready.set() + + def unregister(self, item: T) -> None: + try: + self._items.remove(item) + except ValueError: + pass + self._check_ready() + + def clear(self) -> None: + self._items.clear() + self._ready.clear() diff --git a/discord/voice/utils/wrapped.py b/discord/voice/utils/wrapped.py new file mode 100644 index 0000000000..fb2aafc630 --- /dev/null +++ b/discord/voice/utils/wrapped.py @@ -0,0 +1,32 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + + +def gap_wrapped(a: int, b: int, *, wrap: int = 65536) -> int: + return (b - (a + 1) + wrap) % wrap + + +def add_wrapped(a: int, b: int, *, wrap: int = 65536) -> int: + return (a + b) % wrap diff --git a/discord/voice_client.py b/discord/voice_client.py deleted file mode 100644 index cb475a96ff..0000000000 --- a/discord/voice_client.py +++ /dev/null @@ -1,1080 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. - -Some documentation to refer to: - -- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. -- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. -- We pull the session_id from VOICE_STATE_UPDATE. -- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. -- Then we initiate the voice web socket (vWS) pointing to the endpoint. -- We send opcode 0 with the user_id, server_id, session_id and token using the vWS. -- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval. -- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. -- Then we send our IP and port via vWS with opcode 1. -- When that's all done, we receive opcode 4 from the vWS. -- Finally we can transmit data to endpoint:port. -""" - -from __future__ import annotations - -import asyncio -import datetime -import logging -import select -import socket -import struct -import threading -import time -from typing import TYPE_CHECKING, Any, Callable, Literal, overload - -from . import opus, utils -from .backoff import ExponentialBackoff -from .errors import ClientException, ConnectionClosed -from .gateway import * -from .player import AudioPlayer, AudioSource -from .sinks import RawData, RecordingException, Sink -from .utils import MISSING - -if TYPE_CHECKING: - from . import abc - from .client import Client - from .guild import Guild - from .opus import Encoder - from .state import ConnectionState - from .types.voice import GuildVoiceState as GuildVoiceStatePayload - from .types.voice import SupportedModes - from .types.voice import VoiceServerUpdate as VoiceServerUpdatePayload - from .user import ClientUser - - -has_nacl: bool - -try: - import nacl.secret # type: ignore - - has_nacl = True -except ImportError: - has_nacl = False - -__all__ = ( - "VoiceProtocol", - "VoiceClient", -) - - -_log = logging.getLogger(__name__) - - -class VoiceProtocol: - """A class that represents the Discord voice protocol. - - This is an abstract class. The library provides a concrete implementation - under :class:`VoiceClient`. - - This class allows you to implement a protocol to allow for an external - method of sending voice, such as Lavalink_ or a native library implementation. - - These classes are passed to :meth:`abc.Connectable.connect `. - - .. _Lavalink: https://github.com/freyacodes/Lavalink - - Parameters - ---------- - client: :class:`Client` - The client (or its subclasses) that started the connection request. - channel: :class:`abc.Connectable` - The voice channel that is being connected to. - """ - - def __init__(self, client: Client, channel: abc.Connectable) -> None: - self.client: Client = client - self.channel: abc.Connectable = channel - - async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: - """|coro| - - An abstract method that is called when the client's voice state - has changed. This corresponds to ``VOICE_STATE_UPDATE``. - - Parameters - ---------- - data: :class:`dict` - The raw `voice state payload`__. - - .. _voice_state_update_payload: https://discord.com/developers/docs/resources/voice#voice-state-object - - __ voice_state_update_payload_ - """ - raise NotImplementedError - - async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: - """|coro| - - An abstract method that is called when initially connecting to voice. - This corresponds to ``VOICE_SERVER_UPDATE``. - - Parameters - ---------- - data: :class:`dict` - The raw `voice server update payload`__. - - .. _voice_server_update_payload: https://discord.com/developers/docs/topics/gateway#voice-server-update-voice-server-update-event-fields - - __ voice_server_update_payload_ - """ - raise NotImplementedError - - async def connect(self, *, timeout: float, reconnect: bool) -> None: - """|coro| - - An abstract method called when the client initiates the connection request. - - When a connection is requested initially, the library calls the constructor - under ``__init__`` and then calls :meth:`connect`. If :meth:`connect` fails at - some point then :meth:`disconnect` is called. - - Within this method, to start the voice connection flow it is recommended to - use :meth:`Guild.change_voice_state` to start the flow. After which, - :meth:`on_voice_server_update` and :meth:`on_voice_state_update` will be called. - The order that these two are called is unspecified. - - Parameters - ---------- - timeout: :class:`float` - The timeout for the connection. - reconnect: :class:`bool` - Whether reconnection is expected. - """ - raise NotImplementedError - - async def disconnect(self, *, force: bool) -> None: - """|coro| - - An abstract method called when the client terminates the connection. - - See :meth:`cleanup`. - - Parameters - ---------- - force: :class:`bool` - Whether the disconnection was forced. - """ - raise NotImplementedError - - def cleanup(self) -> None: - """This method *must* be called to ensure proper clean-up during a disconnect. - - It is advisable to call this from within :meth:`disconnect` when you are - completely done with the voice protocol instance. - - This method removes it from the internal state cache that keeps track of - currently alive voice clients. Failure to clean-up will cause subsequent - connections to report that it's still connected. - """ - key_id, _ = self.channel._get_voice_client_key() - self.client._connection._remove_voice_client(key_id) - - -class VoiceClient(VoiceProtocol): - """Represents a Discord voice connection. - - You do not create these, you typically get them from - e.g. :meth:`VoiceChannel.connect`. - - Attributes - ---------- - session_id: :class:`str` - The voice connection session ID. - token: :class:`str` - The voice connection token. - endpoint: :class:`str` - The endpoint we are connecting to. - channel: :class:`abc.Connectable` - The voice channel connected to. - loop: :class:`asyncio.AbstractEventLoop` - The event loop that the voice client is running on. - - Warning - ------- - In order to use PCM based AudioSources, you must have the opus library - installed on your system and loaded through :func:`opus.load_opus`. - Otherwise, your AudioSources must be opus encoded (e.g. using :class:`FFmpegOpusAudio`) - or the library will not be able to transmit audio. - """ - - endpoint_ip: str - voice_port: int - secret_key: list[int] - ssrc: int - - def __init__(self, client: Client, channel: abc.Connectable): - if not has_nacl: - raise RuntimeError("PyNaCl library needed in order to use voice") - - super().__init__(client, channel) - state = client._connection - self.token: str = MISSING - self.socket = MISSING - self.loop: asyncio.AbstractEventLoop = state.loop - self._state: ConnectionState = state - # this will be used in the AudioPlayer thread - self._connected: threading.Event = threading.Event() - - self._handshaking: bool = False - self._potentially_reconnecting: bool = False - self._voice_state_complete: asyncio.Event = asyncio.Event() - self._voice_server_complete: asyncio.Event = asyncio.Event() - - self.mode: str = MISSING - self._connections: int = 0 - self.sequence: int = 0 - self.timestamp: int = 0 - self.timeout: float = 0 - self._runner: asyncio.Task = MISSING - self._player: AudioPlayer | None = None - self.encoder: Encoder = MISSING - self.decoder = None - self._lite_nonce: int = 0 - self.ws: DiscordVoiceWebSocket = MISSING - - self.paused = False - self.recording = False - self.user_timestamps = {} - self.sink = None - self.starting_time = None - self.stopping_time = None - self.temp_queued_data: dict[int, list] = {} - - warn_nacl = not has_nacl - supported_modes: tuple[SupportedModes, ...] = ( - "xsalsa20_poly1305_lite", - "xsalsa20_poly1305_suffix", - "xsalsa20_poly1305", - "aead_xchacha20_poly1305_rtpsize", - ) - - @property - def guild(self) -> Guild | None: - """The guild we're connected to, if applicable.""" - return getattr(self.channel, "guild", None) - - @property - def user(self) -> ClientUser: - """The user connected to voice (i.e. ourselves).""" - return self._state.user - - def checked_add(self, attr, value, limit): - val = getattr(self, attr) - if val + value > limit: - setattr(self, attr, 0) - else: - setattr(self, attr, val + value) - - # connection related - - async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: - self.session_id = data["session_id"] - channel_id = data["channel_id"] - - if not self._handshaking or self._potentially_reconnecting: - # If we're done handshaking then we just need to update ourselves - # If we're potentially reconnecting due to a 4014, then we need to differentiate - # a channel move and an actual force disconnect - if channel_id is None: - # We're being disconnected so cleanup - await self.disconnect() - else: - guild = self.guild - self.channel = channel_id and guild and guild.get_channel(int(channel_id)) # type: ignore - else: - self._voice_state_complete.set() - - async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: - if self._voice_server_complete.is_set(): - _log.info("Ignoring extraneous voice server update.") - return - - self.token = data.get("token") - self.server_id = int(data["guild_id"]) - endpoint = data.get("endpoint") - - if endpoint is None or self.token is None: - _log.warning( - "Awaiting endpoint... This requires waiting. " - "If timeout occurred considering raising the timeout and reconnecting." - ) - return - - self.endpoint = endpoint.removeprefix("wss://") - # This gets set later - self.endpoint_ip = MISSING - - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.socket.setblocking(False) - - if not self._handshaking: - # If we're not handshaking then we need to terminate our previous connection in the websocket - await self.ws.close(4000) - return - - self._voice_server_complete.set() - - async def voice_connect(self) -> None: - await self.channel.guild.change_voice_state(channel=self.channel) - - async def voice_disconnect(self) -> None: - _log.info( - "The voice handshake is being terminated for Channel ID %s (Guild ID %s)", - self.channel.id, - self.guild.id, - ) - await self.channel.guild.change_voice_state(channel=None) - - def prepare_handshake(self) -> None: - self._voice_state_complete.clear() - self._voice_server_complete.clear() - self._handshaking = True - _log.info( - "Starting voice handshake... (connection attempt %d)", self._connections + 1 - ) - self._connections += 1 - - def finish_handshake(self) -> None: - _log.info("Voice handshake complete. Endpoint found %s", self.endpoint) - self._handshaking = False - self._voice_server_complete.clear() - self._voice_state_complete.clear() - - async def connect_websocket(self) -> DiscordVoiceWebSocket: - ws = await DiscordVoiceWebSocket.from_client(self) - self._connected.clear() - while ws.secret_key is None: - await ws.poll_event() - self._connected.set() - return ws - - async def connect(self, *, reconnect: bool, timeout: float) -> None: - _log.info("Connecting to voice...") - self.timeout = timeout - - for i in range(5): - self.prepare_handshake() - - # This has to be created before we start the flow. - futures = [ - self._voice_state_complete.wait(), - self._voice_server_complete.wait(), - ] - - # Start the connection flow - await self.voice_connect() - - try: - await utils.sane_wait_for(futures, timeout=timeout) - except asyncio.TimeoutError: - await self.disconnect(force=True) - raise - - self.finish_handshake() - - try: - self.ws = await self.connect_websocket() - break - except (ConnectionClosed, asyncio.TimeoutError): - if reconnect: - _log.exception("Failed to connect to voice... Retrying...") - await asyncio.sleep(1 + i * 2.0) - await self.voice_disconnect() - continue - else: - raise - - if self._runner is MISSING: - self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) - - async def potential_reconnect(self) -> bool: - # Attempt to stop the player thread from playing early - self._connected.clear() - self.prepare_handshake() - self._potentially_reconnecting = True - try: - # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected - await asyncio.wait_for( - self._voice_server_complete.wait(), timeout=self.timeout - ) - except asyncio.TimeoutError: - self._potentially_reconnecting = False - await self.disconnect(force=True) - return False - - self.finish_handshake() - self._potentially_reconnecting = False - try: - self.ws = await self.connect_websocket() - except (ConnectionClosed, asyncio.TimeoutError): - return False - else: - return True - - @property - def latency(self) -> float: - """Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. - - This could be referred to as the Discord Voice WebSocket latency and is - an analogue of user's voice latencies as seen in the Discord client. - - .. versionadded:: 1.4 - """ - ws = self.ws - return float("inf") if not ws else ws.latency - - @property - def average_latency(self) -> float: - """Average of most recent 20 HEARTBEAT latencies in seconds. - - .. versionadded:: 1.4 - """ - ws = self.ws - return float("inf") if not ws else ws.average_latency - - async def poll_voice_ws(self, reconnect: bool) -> None: - backoff = ExponentialBackoff() - while True: - try: - await self.ws.poll_event() - except (ConnectionClosed, asyncio.TimeoutError) as exc: - if isinstance(exc, ConnectionClosed): - # The following close codes are undocumented, so I will document them here. - # 1000 - normal closure (obviously) - # 4014 - voice channel has been deleted. - # 4015 - voice server has crashed, we should resume - if exc.code == 1000: - _log.info( - "Disconnecting from voice normally, close code %d.", - exc.code, - ) - await self.disconnect() - break - if exc.code == 4014: - _log.info( - "Disconnected from voice by force... potentially" - " reconnecting." - ) - successful = await self.potential_reconnect() - if successful: - continue - - _log.info( - "Reconnect was unsuccessful, disconnecting from voice" - " normally..." - ) - await self.disconnect() - break - if exc.code == 4015: - _log.info("Disconnected from voice, trying to resume...") - - try: - await self.ws.resume() - except asyncio.TimeoutError: - _log.info( - "Could not resume the voice connection... Disconnection..." - ) - if self._connected.is_set(): - await self.disconnect(force=True) - else: - _log.info("Successfully resumed voice connection") - continue - - if not reconnect: - await self.disconnect() - raise - - retry = backoff.delay() - _log.exception( - "Disconnected from voice... Reconnecting in %.2fs.", retry - ) - self._connected.clear() - await asyncio.sleep(retry) - await self.voice_disconnect() - try: - await self.connect(reconnect=True, timeout=self.timeout) - except asyncio.TimeoutError: - # at this point we've retried 5 times... let's continue the loop. - _log.warning("Could not connect to voice... Retrying...") - continue - - async def disconnect(self, *, force: bool = False) -> None: - """|coro| - - Disconnects this voice client from voice. - """ - if not force and not self.is_connected(): - return - - self.stop() - self._connected.clear() - - try: - if self.ws: - await self.ws.close() - - await self.voice_disconnect() - finally: - self.cleanup() - if self.socket: - self.socket.close() - - async def move_to(self, channel: abc.Connectable) -> None: - """|coro| - - Moves you to a different voice channel. - - Parameters - ---------- - channel: :class:`abc.Connectable` - The channel to move to. Must be a voice channel. - """ - await self.channel.guild.change_voice_state(channel=channel) - - def is_connected(self) -> bool: - """Indicates if the voice client is connected to voice.""" - return self._connected.is_set() - - # audio related - - def _get_voice_packet(self, data): - header = bytearray(12) - - # Formulate rtp header - header[0] = 0x80 - header[1] = 0x78 - struct.pack_into(">H", header, 2, self.sequence) - struct.pack_into(">I", header, 4, self.timestamp) - struct.pack_into(">I", header, 8, self.ssrc) - - encrypt_packet = getattr(self, f"_encrypt_{self.mode}") - return encrypt_packet(header, data) - - def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes: - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - nonce = bytearray(24) - nonce[:12] = header - - return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext - - def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data) -> bytes: - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE) - - return header + box.encrypt(bytes(data), nonce).ciphertext + nonce - - def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes: - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - nonce = bytearray(24) - - nonce[:4] = struct.pack(">I", self._lite_nonce) - self.checked_add("_lite_nonce", 1, 4294967295) - - return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] - - def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes: - # Required as of Nov 18 2024 - box = nacl.secret.Aead(bytes(self.secret_key)) - nonce = bytearray(24) - - nonce[:4] = struct.pack(">I", self._lite_nonce) - self.checked_add("_lite_nonce", 1, 4294967295) - - return ( - header - + box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext - + nonce[:4] - ) - - def _decrypt_xsalsa20_poly1305(self, header, data): - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - - nonce = bytearray(24) - nonce[:12] = header - - return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) - - def _decrypt_xsalsa20_poly1305_suffix(self, header, data): - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - - nonce_size = nacl.secret.SecretBox.NONCE_SIZE - nonce = data[-nonce_size:] - - return self.strip_header_ext(box.decrypt(bytes(data[:-nonce_size]), nonce)) - - def _decrypt_xsalsa20_poly1305_lite(self, header, data): - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - - nonce = bytearray(24) - nonce[:4] = data[-4:] - data = data[:-4] - - return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) - - def _decrypt_aead_xchacha20_poly1305_rtpsize(self, header, data): - # Required as of Nov 18 2024 - box = nacl.secret.Aead(bytes(self.secret_key)) - - nonce = bytearray(24) - nonce[:4] = data[-4:] - data = data[:-4] - - return self.strip_header_ext( - box.decrypt(bytes(data), bytes(header), bytes(nonce)) - ) - - @staticmethod - def strip_header_ext(data): - if len(data) > 4 and data[0] == 0xBE and data[1] == 0xDE: - _, length = struct.unpack_from(">HH", data) - offset = 4 + length * 4 - data = data[offset:] - return data - - def get_ssrc(self, user_id): - return {info["user_id"]: ssrc for ssrc, info in self.ws.ssrc_map.items()}[ - user_id - ] - - @overload - def play( - self, - source: AudioSource, - *, - after: Callable[[Exception | None], Any] | None = None, - wait_finish: Literal[False] = False, - ) -> None: ... - - @overload - def play( - self, - source: AudioSource, - *, - after: Callable[[Exception | None], Any] | None = None, - wait_finish: Literal[True], - ) -> asyncio.Future: ... - - def play( - self, - source: AudioSource, - *, - after: Callable[[Exception | None], Any] | None = None, - wait_finish: bool = False, - ) -> None | asyncio.Future: - """Plays an :class:`AudioSource`. - - The finalizer, ``after`` is called after the source has been exhausted - or an error occurred. - - If an error happens while the audio player is running, the exception is - caught and the audio player is then stopped. If no after callback is - passed, any caught exception will be displayed as if it were raised. - - Parameters - ---------- - source: :class:`AudioSource` - The audio source we're reading from. - after: Callable[[Optional[:class:`Exception`]], Any] - The finalizer that is called after the stream is exhausted. - This function must have a single parameter, ``error``, that - denotes an optional exception that was raised during playing. - wait_finish: bool - If True, an awaitable will be returned, which can be used to wait for - audio to stop playing. This awaitable will return an exception if raised, - or None when no exception is raised. - - If False, None is returned and the function does not block. - - .. versionadded:: v2.5 - - Raises - ------ - ClientException - Already playing audio or not connected. - TypeError - Source is not a :class:`AudioSource` or after is not a callable. - OpusNotLoaded - Source is not opus encoded and opus is not loaded. - """ - - if not self.is_connected(): - raise ClientException("Not connected to voice.") - - if self.is_playing(): - raise ClientException("Already playing audio.") - - if not isinstance(source, AudioSource): - raise TypeError( - f"source must be an AudioSource not {source.__class__.__name__}" - ) - - if not self.encoder and not source.is_opus(): - self.encoder = opus.Encoder() - - future = None - if wait_finish: - future = asyncio.Future() - after_callback = after - - def _after(exc: Exception | None): - if callable(after_callback): - after_callback(exc) - - future.set_result(exc) - - after = _after - - self._player = AudioPlayer(source, self, after=after) - self._player.start() - return future - - def unpack_audio(self, data): - """Takes an audio packet received from Discord and decodes it into pcm audio data. - If there are no users talking in the channel, `None` will be returned. - - You must be connected to receive audio. - - .. versionadded:: 2.0 - - Parameters - ---------- - data: :class:`bytes` - Bytes received by Discord via the UDP connection used for sending and receiving voice data. - """ - if data[1] != 0x78: - # We Should Ignore Any Payload Types We Do Not Understand - # Ref RFC 3550 5.1 payload type - # At Some Point We Noted That We Should Ignore Only Types 200 - 204 inclusive. - # They Were Marked As RTCP: Provides Information About The Connection - # This Was Too Broad Of A Whitelist, It Is Unclear If This Is Too Narrow Of A Whitelist - return - if self.paused: - return - - data = RawData(data, self) - - if data.decrypted_data == b"\xf8\xff\xfe": # Frame of silence - return - - self.decoder.decode(data) - - def start_recording(self, sink, callback, *args, sync_start: bool = False): - """The bot will begin recording audio from the current voice channel it is in. - This function uses a thread so the current code line will not be stopped. - Must be in a voice channel to use. - Must not be already recording. - - .. versionadded:: 2.0 - - Parameters - ---------- - sink: :class:`.Sink` - A Sink which will "store" all the audio data. - callback: :ref:`coroutine ` - A function which is called after the bot has stopped recording. - *args: - Args which will be passed to the callback function. - sync_start: :class:`bool` - If True, the recordings of subsequent users will start with silence. - This is useful for recording audio just as it was heard. - - Raises - ------ - RecordingException - Not connected to a voice channel. - RecordingException - Already recording. - RecordingException - Must provide a Sink object. - """ - if not self.is_connected(): - raise RecordingException("Not connected to voice channel.") - if self.recording: - raise RecordingException("Already recording.") - if not isinstance(sink, Sink): - raise RecordingException("Must provide a Sink object.") - - self.empty_socket() - - self.decoder = opus.DecodeManager(self) - self.decoder.start() - self.recording = True - self.sync_start = sync_start - self.sink = sink - sink.init(self) - - t = threading.Thread( - target=self.recv_audio, - args=( - sink, - callback, - *args, - ), - ) - t.start() - - def stop_recording(self): - """Stops the recording. - Must be already recording. - - .. versionadded:: 2.0 - - Raises - ------ - RecordingException - Not currently recording. - """ - if not self.recording: - raise RecordingException("Not currently recording audio.") - self.decoder.stop() - self.recording = False - self.paused = False - - def toggle_pause(self): - """Pauses or unpauses the recording. - Must be already recording. - - .. versionadded:: 2.0 - - Raises - ------ - RecordingException - Not currently recording. - """ - if not self.recording: - raise RecordingException("Not currently recording audio.") - self.paused = not self.paused - - def empty_socket(self): - while True: - ready, _, _ = select.select([self.socket], [], [], 0.0) - if not ready: - break - for s in ready: - s.recv(4096) - - def recv_audio(self, sink, callback, *args): - # Gets data from _recv_audio and sorts - # it by user, handles pcm files and - # silence that should be added. - - self.user_timestamps: dict[int, tuple[int, int, float]] = {} - self.starting_time = time.perf_counter() - self.first_packet_timestamp: float - while self.recording: - ready, _, err = select.select([self.socket], [], [self.socket], 0.01) - if not ready: - if err: - print(f"Socket error: {err}") - continue - - try: - data = self.socket.recv(4096) - except OSError: - self.stop_recording() - continue - - self.unpack_audio(data) - - self.stopping_time = time.perf_counter() - self.sink.cleanup() - callback = asyncio.run_coroutine_threadsafe(callback(sink, *args), self.loop) - result = callback.result() - - if result is not None: - print(result) - - def recv_decoded_audio(self, data: RawData): - # Add silence when they were not being recorded. - data.user_id = self.ws.ssrc_map.get(data.ssrc, {}).get("user_id") - - if data.user_id is None: - _log.debug( - f"DEBUG: received packet with SSRC {data.ssrc} not linked to a user_id." - f"Queueing for later processing." - ) - self.temp_queued_data.setdefault(data.ssrc, []).append(data) - return - elif data.ssrc in self.temp_queued_data: - _log.debug( - "DEBUG: We got %d packet(s) in queue for SSRC %d", - len(self.temp_queued_data[data.ssrc]), - data.ssrc, - ) - queued_packets = self.temp_queued_data.pop(data.ssrc) - for q_packet in queued_packets: - q_packet.user_id = data.user_id - self._process_audio_packet(q_packet) - - self._process_audio_packet(data) - - def _process_audio_packet(self, data: RawData): - if data.user_id not in self.user_timestamps: # First packet from user - if ( - not self.user_timestamps or not self.sync_start - ): # First packet from anyone - self.first_packet_timestamp = data.receive_time - silence = 0 - - else: # Previously received a packet from someone else - silence = ( - (data.receive_time - self.first_packet_timestamp) * 48000 - ) - 960 - - else: # Previously received a packet from user - prev_ssrc = self.user_timestamps[data.user_id][0] - prev_timestamp = self.user_timestamps[data.user_id][1] - prev_receive_time = self.user_timestamps[data.user_id][2] - - if data.ssrc != prev_ssrc: - _log.info( - f"Received audio data from USER_ID {data.user_id} with a previous SSRC {prev_ssrc} and new " - f"SSRC {data.ssrc}." - ) - dRT = (data.receive_time - prev_receive_time) * 1000 - silence = max(0, int(dRT / (1000 / 48000))) - 960 - else: - dRT = ( - data.receive_time - prev_receive_time - ) * 48000 # delta receive time - dT = data.timestamp - prev_timestamp # delta timestamp - diff = abs(100 - dT * 100 / dRT) - if ( - diff > 60 and dT != 960 - ): # If the difference in change is more than 60% threshold - silence = dRT - 960 - else: - silence = dT - 960 - - self.user_timestamps.update( - {data.user_id: (data.ssrc, data.timestamp, data.receive_time)} - ) - - data.decoded_data = ( - struct.pack(" bool: - """Indicates if we're currently playing audio.""" - return self._player is not None and self._player.is_playing() - - def is_paused(self) -> bool: - """Indicates if we're playing audio, but if we're paused.""" - return self._player is not None and self._player.is_paused() - - def stop(self) -> None: - """Stops playing audio.""" - if self._player: - self._player.stop() - self._player = None - - def pause(self) -> None: - """Pauses the audio playing.""" - if self._player: - self._player.pause() - - def resume(self) -> None: - """Resumes the audio playing.""" - if self._player: - self._player.resume() - - @property - def source(self) -> AudioSource | None: - """The audio source being played, if playing. - - This property can also be used to change the audio source currently being played. - """ - return self._player.source if self._player else None - - @source.setter - def source(self, value: AudioSource) -> None: - if not isinstance(value, AudioSource): - raise TypeError(f"expected AudioSource not {value.__class__.__name__}.") - - if self._player is None: - raise ValueError("Not playing anything.") - - self._player._set_source(value) - - def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: - """Sends an audio packet composed of the data. - - You must be connected to play audio. - - Parameters - ---------- - data: :class:`bytes` - The :term:`py:bytes-like object` denoting PCM or Opus voice data. - encode: :class:`bool` - Indicates if ``data`` should be encoded into Opus. - - Raises - ------ - ClientException - You are not connected. - opus.OpusError - Encoding the data failed. - """ - - self.checked_add("sequence", 1, 65535) - if encode: - if not self.encoder: - self.encoder = opus.Encoder() - encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME) - else: - encoded_data = data - packet = self._get_voice_packet(encoded_data) - try: - self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) - except BlockingIOError: - _log.warning( - "A packet has been dropped (seq: %s, timestamp: %s)", - self.sequence, - self.timestamp, - ) - - self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295) - - def elapsed(self) -> datetime.timedelta: - """Returns the elapsed time of the playing audio.""" - if self._player: - return datetime.timedelta(milliseconds=self._player.played_frames() * 20) - return datetime.timedelta() diff --git a/docs/api/sinks.rst b/docs/api/sinks.rst index 0ee8ca3a73..e5a199489d 100644 --- a/docs/api/sinks.rst +++ b/docs/api/sinks.rst @@ -6,25 +6,28 @@ Sinks Core ---- -.. autoclass:: discord.sinks.Filters +.. autoclass:: discord.sinks.Sink :members: -.. autoclass:: discord.sinks.Sink +.. autoclass:: discord.sinks.RawData :members: -.. autoclass:: discord.sinks.AudioData +.. autoclass:: discord.sinks.SinkHandler :members: -.. autoclass:: discord.sinks.RawData +.. autoclass:: discord.sinks.SinkFilter :members: -Sink Classes ------------- +Default Sinks +------------- .. autoclass:: discord.sinks.WaveSink :members: +.. autoclass:: discord.sinks.WavSink + :members: + .. autoclass:: discord.sinks.MP3Sink :members: @@ -42,3 +45,84 @@ Sink Classes .. autoclass:: discord.sinks.OGGSink :members: + +Events +------ +These section outlines all the available sink events. + +.. function:: on_voice_packet_receive(user, data) + Called when a voice packet is received from a member. + + This is called **after** the filters went through. + + :param user: The user the packet is from. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param data: The RawData of the packet. + :type data: :class:`~.RawData` + +.. function:: on_unfiltered_voice_packet_receive(user, data) + Called when a voice packet is received from a member. + + Unlike ``on_voice_packet_receive``, this is called **before any filters** are called. + + :param user: The user the packet is from. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param data: The RawData of the packet. + :type data: :class:`~.RawData` + +.. function:: on_speaking_state_update(user, before, after) + Called when a member's voice state changes. + + This is called **after** the filters went through. + + :param user: The user which speaking state has changed. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param before: The user's state before it was updated. + :type before: :class:`~discord.SpeakingFlags` + :param after: The user's state after it was updated. + :type after: :class:`~discord.SpeakingFlags` + +.. function:: on_unfiltered_speaking_state_update(user, before, after) + Called when a voice packet is received from a member. + + Unlike ``on_speaking_state_update``, this is called **before any filters** are called. + + :param user: The user which speaking state has changed. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param before: The user's state before it was updated. + :type before: :class:`~discord.SpeakingFlags` + :param after: The user's state after it was updated. + :type after: :class:`~discord.SpeakingFlags` + +.. function:: on_user_connect(user, channel) + Called when a user connects to a voice channel. + + This is called **after** the filters went through. + + :param user: The user that has connected to the voice channel. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param channel: The channel the user has connected to. This usually resolved to the correct channel type, but if it fails + it defaults to a :class:`~discord.Object` object. + :type channel: :class:`~discord.abc.Snowflake` + +.. function:: on_unfiltered_user_connect(user, channel) + Called when a user connects to a voice channel. + + Unlike ``on_user_connect``, this is called **before any filters** are called. + + :param user: The user that has connected to the voice channel. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param channel: The channel the user has connected to. This usually resolved to the correct channel type, but if it fails + it defaults to a :class:`~discord.Object` object. + :type channel: :class:`~discord.abc.Snowflake` + +.. function:: on_error(event, exception, \*args, \*\*kwargs) + Called when an error ocurrs in any of the events above. The default implementation logs the exception + to stdout. + + :param event: The event in which the error ocurred. + :type event: :class:`str` + :param exception: The exception that ocurred. + :type exception: :class:`Exception` + :param \*args: The arguments that were passed to the event. + :param \*\*kwargs: The key-word arguments that were passed to the event. diff --git a/docs/api/voice.rst b/docs/api/voice.rst index 95d2b8441d..59de02185d 100644 --- a/docs/api/voice.rst +++ b/docs/api/voice.rst @@ -6,13 +6,13 @@ Voice Related Objects ------- -.. attributetable:: VoiceClient +.. attributetable:: discord.voice.VoiceClient .. autoclass:: VoiceClient() :members: :exclude-members: connect, on_voice_state_update, on_voice_server_update -.. attributetable:: VoiceProtocol +.. attributetable:: discord.voice.VoiceProtocol .. autoclass:: VoiceProtocol :members: diff --git a/pyproject.toml b/pyproject.toml index 7f2930ee65..4b0cbde12a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ packages = [ "discord.ext.pages", "discord.ext.bridge", "discord.bin", + "discord.voice", ] [tool.setuptools.dynamic] diff --git a/requirements/voice.txt b/requirements/voice.txt index 6382712eac..71df4cdb2d 100644 --- a/requirements/voice.txt +++ b/requirements/voice.txt @@ -1 +1,2 @@ PyNaCl>=1.3.0,<1.6 +davey==0.1.0rc3