From a98025a0235c81d214a0b4c78617b4e622661308 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Fri, 22 Aug 2025 00:13:16 +0200 Subject: [PATCH 01/40] first commit on voice fixes --- discord/__init__.py | 2 +- discord/gateway.py | 180 ++++++++++++++++++++------------ discord/raw_models.py | 199 +++++++++++++++++++++++++++++++++++- discord/types/raw_models.py | 4 + discord/types/voice.py | 12 +-- discord/utils.py | 4 +- discord/voice/__init__.py | 17 +++ discord/voice/_types.py | 158 ++++++++++++++++++++++++++++ discord/voice/client.py | 54 ++++++++++ discord/voice/errors.py | 92 +++++++++++++++++ discord/voice/gateway.py | 92 +++++++++++++++++ discord/voice/state.py | 87 ++++++++++++++++ discord/voice_client.py | 111 -------------------- 13 files changed, 821 insertions(+), 191 deletions(-) create mode 100644 discord/voice/__init__.py create mode 100644 discord/voice/_types.py create mode 100644 discord/voice/client.py create mode 100644 discord/voice/errors.py create mode 100644 discord/voice/gateway.py create mode 100644 discord/voice/state.py diff --git a/discord/__init__.py b/discord/__init__.py index afe3002e00..e1ea74d185 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -71,7 +71,7 @@ from .template import * from .threads import * from .user import * -from .voice_client import * +from .voice import * from .webhook import * from .welcome_screen import * from .widget import * diff --git a/discord/gateway.py b/discord/gateway.py index e968bc9858..d08b7307e7 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -26,6 +26,7 @@ from __future__ import annotations import asyncio +from collections.abc import Callable import concurrent.futures import logging import struct @@ -34,8 +35,8 @@ import time import traceback import zlib -from collections import deque, namedtuple -from typing import TYPE_CHECKING +from collections import deque +from typing import TYPE_CHECKING, Any, NamedTuple import aiohttp @@ -44,6 
+45,12 @@ from .enums import SpeakingState from .errors import ConnectionClosed, InvalidArgument +if TYPE_CHECKING: + from typing_extensions import Self + + from .client import Client + from .state import ConnectionState + _log = logging.getLogger(__name__) __all__ = ( @@ -68,26 +75,30 @@ class WebSocketClosure(Exception): """An exception to make up for the fact that aiohttp doesn't signal closure.""" -EventListener = namedtuple("EventListener", "predicate event result future") +class EventListener(NamedTuple): + predicate: Callable[[dict[str, Any]], bool] + event: str + result: Callable[[dict[str, Any]], Any] | None + future: asyncio.Future[Any] class GatewayRatelimiter: - def __init__(self, count=110, per=60.0): + def __init__(self, count: int = 110, per: float = 60.0): # The default is 110 to give room for at least 10 heartbeats per minute - self.max = count - self.remaining = count - self.window = 0.0 - self.per = per - self.lock = asyncio.Lock() - self.shard_id = None - - def is_ratelimited(self): + self.max: int = count + self.remaining: int = count + self.window: float = 0.0 + self.per: float = per + self.lock: asyncio.Lock = asyncio.Lock() + self.shard_id: int | None = None + + def is_ratelimited(self) -> bool: current = time.time() if current > self.window + self.per: return False return self.remaining == 0 - def get_delay(self): + def get_delay(self) -> float: current = time.time() if current > self.window + self.per: @@ -105,7 +116,7 @@ def get_delay(self): return 0.0 - async def block(self): + async def block(self) -> None: async with self.lock: delta = self.get_delay() if delta: @@ -118,12 +129,16 @@ async def block(self): class KeepAliveHandler(threading.Thread): - def __init__(self, *args, **kwargs): - ws = kwargs.pop("ws", None) - interval = kwargs.pop("interval", None) - shard_id = kwargs.pop("shard_id", None) + def __init__( + self, + *args: Any, + ws: DiscordWebSocket, + shard_id: int | None = None, + interval: float | None = None, + **kwargs: 
Any, + ) -> None: threading.Thread.__init__(self, *args, **kwargs) - self.ws = ws + self.ws: DiscordWebSocket = ws self._main_thread_id = ws.thread_id self.interval = interval self.daemon = True @@ -292,52 +307,63 @@ class DiscordWebSocket: HEARTBEAT_ACK = 11 GUILD_SYNC = 12 - def __init__(self, socket, *, loop): - self.socket = socket - self.loop = loop + if TYPE_CHECKING: + token: str | None + _connection: ConnectionState + _discord_parsers: dict[str, Callable[..., Any]] + call_hooks: Callable[..., Any] + gateway: str + _initial_identify: bool + shard_id: int | None + shard_count: int | None + _max_heartbeat_timeout: float + + def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: + self.socket: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop # an empty dispatcher to prevent crashes - self._dispatch = lambda *args: None + self._dispatch: Callable[..., Any] = lambda *args: None # generic event listeners - self._dispatch_listeners = [] + self._dispatch_listeners: list[EventListener] = [] # the keep alive - self._keep_alive = None - self.thread_id = threading.get_ident() + self._keep_alive: KeepAliveHandler | None = None + self.thread_id: int = threading.get_ident() # ws related stuff - self.session_id = None - self.sequence = None - self.resume_gateway_url = None - self._zlib = zlib.decompressobj() - self._buffer = bytearray() - self._close_code = None - self._rate_limiter = GatewayRatelimiter() + self.session_id: str | None = None + self.sequence: int | None = None + self.resume_gateway_url: str | None = None + self._zlib: zlib._Decompress = zlib.decompressobj() + self._buffer: bytearray = bytearray() + self._close_code: int | None = None + self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() @property - def open(self): + def open(self) -> bool: return not self.socket.closed - def is_ratelimited(self): + def is_ratelimited(self) -> bool: return 
self._rate_limiter.is_ratelimited() - def debug_log_receive(self, data, /): + def debug_log_receive(self, data: dict[str, Any], /) -> None: self._dispatch("socket_raw_receive", data) - def log_receive(self, _, /): + def log_receive(self, _: dict[str, Any], /) -> None: pass @classmethod async def from_client( cls, - client, + client: Client, *, - initial=False, - gateway=None, - shard_id=None, - session=None, - sequence=None, - resume=False, - ): + initial: bool = False, + gateway: str | None = None, + shard_id: int | None = None, + session: str | None = None, + sequence: int | None = None, + resume: bool = False, + ) -> Self: """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -379,7 +405,12 @@ async def from_client( await ws.resume() return ws - def wait_for(self, event, predicate, result=None): + def wait_for( + self, + event: str, + predicate: Callable[[dict[str, Any]], bool], + result: Callable[[dict[str, Any]], Any] | None = None, + ) -> asyncio.Future[Any]: """Waits for a DISPATCH'd event that meets the predicate. 
Parameters @@ -406,7 +437,7 @@ def wait_for(self, event, predicate, result=None): self._dispatch_listeners.append(entry) return future - async def identify(self): + async def identify(self) -> None: """Sends the IDENTIFY packet.""" payload = { "op": self.IDENTIFY, @@ -419,7 +450,6 @@ async def identify(self): }, "compress": True, "large_threshold": 250, - "v": 3, }, } @@ -444,7 +474,7 @@ async def identify(self): await self.send_as_json(payload) _log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id) - async def resume(self): + async def resume(self) -> None: """Sends the RESUME packet.""" payload = { "op": self.RESUME, @@ -458,7 +488,7 @@ async def resume(self): await self.send_as_json(payload) _log.info("Shard ID %s has sent the RESUME payload.", self.shard_id) - async def received_message(self, msg, /): + async def received_message(self, msg: Any, /): if type(msg) is bytes: self._buffer.extend(msg) @@ -594,7 +624,7 @@ def latency(self) -> float: heartbeat = self._keep_alive return float("inf") if heartbeat is None else heartbeat.latency - def _can_handle_close(self): + def _can_handle_close(self) -> bool: code = self._close_code or self.socket.close_code is_improper_close = self._close_code is None and self.socket.close_code == 1000 return is_improper_close or code not in ( @@ -607,7 +637,7 @@ def _can_handle_close(self): 4014, ) - async def poll_event(self): + async def poll_event(self) -> None: """Polls for a DISPATCH event and handles the general gateway loop. 
Raises @@ -621,11 +651,12 @@ async def poll_event(self): await self.received_message(msg.data) elif msg.type is aiohttp.WSMsgType.BINARY: await self.received_message(msg.data) + elif msg.type is aiohttp.WSMsgType.ERROR: + _log.debug('Received an error %s', msg) elif msg.type in ( aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.ERROR, ): _log.debug("Received %s", msg) raise WebSocketClosure @@ -649,23 +680,23 @@ async def poll_event(self): self.socket, shard_id=self.shard_id, code=code ) from None - async def debug_send(self, data, /): + async def debug_send(self, data: str, /) -> None: await self._rate_limiter.block() self._dispatch("socket_raw_send", data) await self.socket.send_str(data) - async def send(self, data, /): + async def send(self, data: str, /) -> None: await self._rate_limiter.block() await self.socket.send_str(data) - async def send_as_json(self, data): + async def send_as_json(self, data: Any) -> None: try: await self.send(utils._to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def send_heartbeat(self, data): + async def send_heartbeat(self, data: Any) -> None: # This bypasses the rate limit handling code since it has a higher priority try: await self.socket.send_str(utils._to_json(data)) @@ -673,13 +704,19 @@ async def send_heartbeat(self, data): if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def change_presence(self, *, activity=None, status=None, since=0.0): + async def change_presence( + self, + *, + activity: BaseActivity | None = None, + status: str | None = None, + since: float = 0.0, + ) -> None: if activity is not None: if not isinstance(activity, BaseActivity): raise InvalidArgument("activity must derive from BaseActivity.") - activity = [activity.to_dict()] + activities = [activity.to_dict()] else: - activity = [] + 
activities = [] if status == "idle": since = int(time.time() * 1000) @@ -687,7 +724,7 @@ async def change_presence(self, *, activity=None, status=None, since=0.0): payload = { "op": self.PRESENCE, "d": { - "activities": activity, + "activities": activities, "afk": False, "since": since, "status": status, @@ -699,8 +736,15 @@ async def change_presence(self, *, activity=None, status=None, since=0.0): await self.send(sent) async def request_chunks( - self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None - ): + self, + guild_id: int, + query: str | None = None, + *, + limit: int, + user_ids: list[int] | None = None, + presences: bool = False, + nonce: str | None = None, + ) -> None: payload = { "op": self.REQUEST_MEMBERS, "d": {"guild_id": guild_id, "presences": presences, "limit": limit}, @@ -717,7 +761,13 @@ async def request_chunks( await self.send_as_json(payload) - async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): + async def voice_state( + self, + guild_id: int, + channel_id: int, + self_mute: bool = False, + self_deaf: bool = False, + ) -> None: payload = { "op": self.VOICE_STATE, "d": { @@ -731,7 +781,7 @@ async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=Fal _log.debug("Updating our voice state to %s.", payload) await self.send_as_json(payload) - async def close(self, code=4000): + async def close(self, code: int = 4000) -> None: if self._keep_alive: self._keep_alive.stop() self._keep_alive = None diff --git a/discord/raw_models.py b/discord/raw_models.py index 73da688b7f..d29f513161 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -25,14 +25,16 @@ from __future__ import annotations +from collections.abc import ItemsView, KeysView, ValuesView import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from .automod import AutoModAction, AutoModTriggerType from .enums import AuditLogAction, ChannelType, ReactionType, try_enum 
+from . import utils if TYPE_CHECKING: - from .abc import MessageableChannel + from .abc import MessageableChannel, GuildChannel from .guild import Guild from .member import Member from .message import Message @@ -59,6 +61,8 @@ ThreadUpdateEvent, TypingEvent, VoiceChannelStatusUpdateEvent, + VoiceServerUpdateEvent, + VoiceStateEvent, ) from .user import User @@ -81,12 +85,16 @@ "RawAuditLogEntryEvent", "RawVoiceChannelStatusUpdateEvent", "RawMessagePollVoteEvent", + "RawVoiceServerUpdateEvent", + "RawVoiceStateUpdateEvent", ) class _RawReprMixin: + __slots__: tuple[str, ...] + def __repr__(self) -> str: - value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__) + value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__ if not attr.startswith('_')) return f"<{self.__class__.__name__} {value}>" @@ -841,3 +849,188 @@ def __init__(self, data: MessagePollVoteEvent, added: bool) -> None: self.guild_id: int | None = int(data["guild_id"]) except KeyError: self.guild_id: int | None = None + +# this is for backwards compatibility because VoiceProtocol.on_voice_..._update +# passed the raw payload instead of a raw object. Emit deprecation warning. +class _PayloadLike(_RawReprMixin): + _raw_data: dict[str, Any] + + @utils.deprecated( + 'the attributes', + '2.7', + '3.0', + ) + def __getitem__(self, key: str) -> Any: + return self._raw_data[key] + + @utils.deprecated( + 'the attributes', + '2.7', + '3.0', + ) + def get(self, key: str, default: Any = None) -> Any: + """Gets an item from this raw event, and returns its value or ``default``. + + .. deprecated:: 2.7 + Use the attributes instead. + """ + return self._raw_data.get(key, default) + + @utils.deprecated( + 'the attributes', + '2.7', + '3.0', + ) + def items(self) -> ItemsView: + """Returns the (key, value) pairs of this raw event. + + .. deprecated:: 2.7 + Use the attributes instead. 
+ """ + return self._raw_data.items() + + @utils.deprecated( + 'the attributes', + '2.7', + '3.0', + ) + def values(self) -> ValuesView: + """Returns the values of this raw event. + + .. deprecated:: 2.7 + Use the attributes instead. + """ + return self._raw_data.values() + + @utils.deprecated( + 'the attributes', + '2.7', + '3.0', + ) + def keys(self) -> KeysView: + """Returns the keys of this raw event. + + .. deprecated:: 2.7 + Use the attributes instead. + """ + return self._raw_data.keys() + + +class RawVoiceStateUpdateEvent(_PayloadLike): + """Represents the payload for a :meth:`VoiceProtocol.on_voice_state_update` event. + + .. versionadded:: 2.7 + + Attributes + ---------- + deaf: :class:`bool` + Whether the user is guild deafened. + mute: :class:`bool` + Whether the user is guild muted. + self_mute: :class:`bool` + Whether the user has muted themselves by their own accord. + self_deaf: :class:`bool` + Whether the user has deafened themselves by their own accord. + self_stream: :class:`bool` + Whether the user is currently streaming via the 'Go Live' feature. + self_video: :class:`bool` + Whether the user is currently broadcasting video. + suppress: :class:`bool` + Whether the user is suppressed from speaking in a stage channel. + requested_to_speak_at: Optional[:class:`datetime.datetime`] + An aware datetime object that specifies when a member has requested to speak + in a stage channel. It will be ``None`` if they are not requesting to speak + anymore or have been accepted to. + afk: :class:`bool` + Whether the user is connected on the guild's AFK channel. + channel: Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]] + The voice channel that the user is currently connected to. ``None`` if the user + is not currently in a voice channel. + + There are certain scenarios in which this is impossible to be ``None``. + session_id: :class:`str` + The voice connection session ID. 
+ guild_id: Optional[:class:`int`] + The guild ID the user channel is from. + channel_id: Optional[:class:`int`] + The channel ID the user is connected to. Or ``None`` if not connected to any. + """ + + __slots__ = ( + 'session_id', + 'mute', + 'deaf', + 'self_mute', + 'self_deaf', + 'self_stream', + 'self_video', + 'suppress', + 'requested_to_speak_at', + 'afk', + 'channel', + 'guild_id', + 'channel_id', + '_state', + '_raw_data', + ) + + def __init__(self, *, data: VoiceStateEvent, state: ConnectionState) -> None: + self.session_id: str = data['session_id'] + self._state: ConnectionState = state + + self.self_mute: bool = data.get('self_mute', False) + self.self_deaf: bool = data.get('self_deaf', False) + self.mute: bool = data.get('mute', False) + self.deaf: bool = data.get('deaf', False) + self.suppress: bool = data.get('suppress', False) + self.requested_to_speak_at: datetime.datetime | None = utils.parse_time( + data.get('request_to_speak_timestamp') + ) + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.channel_id: int | None = utils._get_as_snowflake(data, 'channel_id') + self._raw_data: VoiceStateEvent = data + + @property + def guild(self) -> Guild | None: + """Returns the guild channel the user is connected to, or ``None``.""" + return self._state._get_guild(self.guild_id) + + @property + def channel(self) -> GuildChannel | None: + """Returns the channel the user is connected to, or ``None``.""" + return self._state.get_channel(self.channel_id) # type: ignore + + +class RawVoiceServerUpdateEvent(_PayloadLike): + """Represents the payload for a :meth:`VoiceProtocol.on_voice_server_update` event. + + .. versionadded:: 2.7 + + Attributes + ---------- + token: :class:`str` + The voice connection token. This should not be shared. + guild_id: :class:`int` + The guild ID this token is part from. + endpoint: Optional[:class:`str`] + The voice server host to connect to. 
+ """ + + __slots__ = ( + 'token', + 'guild_id', + 'endpoint', + '_raw_data', + '_state', + ) + + def __init__(self, *, data: VoiceServerUpdateEvent, state: ConnectionState) -> None: + self._state: ConnectionState = state + self.guild_id: int = int(data['guild_id']) + self.token: str = data['token'] + self.endpoint: str | None = data['endpoint'] + + @property + def guild(self) -> Guild | None: + """Returns the guild this server update is from.""" + return self._state._get_guild(self.guild_id) diff --git a/discord/types/raw_models.py b/discord/types/raw_models.py index 1a7feee059..473434b6de 100644 --- a/discord/types/raw_models.py +++ b/discord/types/raw_models.py @@ -33,6 +33,10 @@ from .snowflake import Snowflake from .threads import Thread, ThreadMember from .user import User +from .voice import ( + VoiceState as VoiceStateEvent, + VoiceServerUpdate as VoiceServerUpdateEvent, +) class _MessageEventOptional(TypedDict, total=False): diff --git a/discord/types/voice.py b/discord/types/voice.py index 68d99ccd48..3840782a2f 100644 --- a/discord/types/voice.py +++ b/discord/types/voice.py @@ -40,7 +40,7 @@ ] -class _VoiceState(TypedDict): +class VoiceState(TypedDict): member: NotRequired[MemberWithUser] self_stream: NotRequired[bool] user_id: Snowflake @@ -51,15 +51,9 @@ class _VoiceState(TypedDict): self_mute: bool self_video: bool suppress: bool - - -class GuildVoiceState(_VoiceState): - channel_id: Snowflake - - -class VoiceState(_VoiceState, total=False): + request_to_speak_timestamp: str | None channel_id: Snowflake | None - guild_id: Snowflake + guild_id: NotRequired[Snowflake] class VoiceRegion(TypedDict): diff --git a/discord/utils.py b/discord/utils.py index ee4fc2ccd7..e7b4ffde0d 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -325,7 +325,7 @@ def deprecated( stacklevel: int = 3, *, use_qualname: bool = True, -) -> Callable[[Callable[[P], T]], Callable[[P], T]]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """A decorator implementation of 
:func:`warn_deprecated`. This will automatically call :func:`warn_deprecated` when the decorated function is called. @@ -350,7 +350,7 @@ def deprecated( will display as ``login``. Defaults to ``True``. """ - def actual_decorator(func: Callable[[P], T]) -> Callable[[P], T]: + def actual_decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) def decorated(*args: P.args, **kwargs: P.kwargs) -> T: warn_deprecated( diff --git a/discord/voice/__init__.py b/discord/voice/__init__.py new file mode 100644 index 0000000000..ad1e47b1de --- /dev/null +++ b/discord/voice/__init__.py @@ -0,0 +1,17 @@ +""" +discord.voice +~~~~~~~~~~~~~ + +Voice support for the Discord API. + +:copyright: (c) 2015-2021 Rapptz & 2021-present Pycord Development +:license: MIT, see LICENSE for more details. +""" + +from .client import VoiceClient +from ._types import VoiceProtocol + +__all__ = ( + 'VoiceClient', + 'VoiceProtocol', +) diff --git a/discord/voice/_types.py b/discord/voice/_types.py new file mode 100644 index 0000000000..2a677edaf4 --- /dev/null +++ b/discord/voice/_types.py @@ -0,0 +1,158 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from discord import abc + from discord.client import Client + from discord.raw_models import ( + RawVoiceStateUpdateEvent, + RawVoiceServerUpdateEvent, + ) + +ClientT = TypeVar('ClientT', bound='Client', covariant=True) + + +class VoiceProtocol(Generic[ClientT]): + """A class that represents the Discord voice protocol. + + .. warning:: + + If you are an end user, you **should not construct this manually** but instead + take it from the return type in :meth:`abc.Connectable.connect <VoiceChannel.connect>`. + The parameters and methods are documented here so that third-party libraries can refer to them + when implementing their own VoiceProtocol types. + + This is an abstract class. The library provides a concrete implementation + under :class:`VoiceClient`. + + This class allows you to implement a protocol to allow for an external + method of sending voice, such as Lavalink_ or a native library implementation. + + These classes are passed to :meth:`abc.Connectable.connect <VoiceChannel.connect>`. + + .. _Lavalink: https://github.com/freyacodes/Lavalink + + Parameters + ---------- + client: :class:`Client` + The client (or its subclasses) that started the connection request. + channel: :class:`abc.Connectable` + The voice channel that is being connected to. + """ + + def __init__(self, client: ClientT, channel: abc.Connectable) -> None: + self.client: ClientT = client + self.channel: abc.Connectable = channel + + async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: + """|coro| + + A method called when the client's voice state has changed. This corresponds + to the ``VOICE_STATE_UPDATE`` event.
+ + Parameters + ---------- + data: :class:`RawVoiceStateUpdateEvent` + The voice state payload. + + .. versionchanged:: 2.7 + This now gets passed a :class:`RawVoiceStateUpdateEvent` object instead of a :class:`dict`; + accessing keys via ``data[key]`` or ``data.get(key)`` is still supported but deprecated. + """ + raise NotImplementedError + + async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: + """|coro| + + A method called when the client is initially connecting to voice. This corresponds + to the ``VOICE_SERVER_UPDATE`` event. + + Parameters + ---------- + data: :class:`RawVoiceServerUpdateEvent` + The voice server payload. + + .. versionchanged:: 2.7 + This now gets passed a :class:`RawVoiceServerUpdateEvent` object instead of a :class:`dict`; + accessing keys via ``data[key]`` or ``data.get(key)`` is still supported but deprecated. + """ + raise NotImplementedError + + async def connect(self, *, timeout: float, reconnect: bool) -> None: + """|coro| + + A method called to initialise the connection. + + The library initialises this class and calls ``__init__``, and then :meth:`connect` when attempting + to start a connection to the voice. If an error occurs, it calls :meth:`disconnect`, so if you need + to implement any cleanup, you should manually call it in :meth:`disconnect` as the library will not + do so for you. + + Within this method, it is recommended to use :meth:`Guild.change_voice_state` + to start the voice connection flow. After which :meth:`on_voice_server_update` and :meth:`on_voice_state_update` will be called, + although this could vary and cause unexpected behaviour, but that falls under Discord's way of handling the voice + connection. + + Parameters + ---------- + timeout: :class:`float` + The timeout for the connection. + reconnect: :class:`bool` + Whether reconnection is expected.
+ """ + raise NotImplementedError + + async def disconnect(self, *, force: bool) -> None: + """|coro| + + A method called to terminate the voice connection. + + This can be either called manually when forcing a disconnection, or when an exception in :meth:`connect` occurs. + + It is recommended to call :meth:`cleanup` here. + + Parameters + ---------- + force: :class:`bool` + Whether the disconnection was forced. + """ + + def cleanup(self) -> None: + """This method *must* be called to ensure proper clean-up during a disconnect. + + It is advisable to call this from within :meth:`disconnect` when you are completely + done with the voice protocol instance. + + This method removes it from the internal state cache that keeps track of the currently + alive voice clients. Failure to clean-up will cause subsequent connections to report that + it's still connected. + + **The library will NOT automatically call this for you**, unlike :meth:`connect` and :meth:`disconnect`. + """ + key, _ = self.channel._get_voice_client_key() + self.client._connection._remove_voice_client(key) diff --git a/discord/voice/client.py b/discord/voice/client.py new file mode 100644 index 0000000000..702d24c46f --- /dev/null +++ b/discord/voice/client.py @@ -0,0 +1,54 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software.
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from ._types import VoiceProtocol + + +class VoiceClient(VoiceProtocol): + """Represents a Discord voice connection. + + You do not create these, you typically get them from e.g. :meth:`VoiceChannel.connect`. + + Attributes + ---------- + session_id: :class:`str` + The voice connection session ID. You should not share this. + token: :class:`str` + The voice connection token. You should not share this. + endpoint: :class:`str` + The endpoint the current client is connected to. + channel: :class:`abc.Connectable` + The voice channel connected to. + loop: :class:`asyncio.AbstractEventLoop` + The event loop that the voice client is running on. + + Warning + ------- + In order to use PCM based AudioSources, you must have the opus library + installed on your system and loaded through :func:`opus.load_opus`. + Otherwise, your AudioSources must be opus encoded (e.g. using :class:`FFmpegOpusAudio`) + or the library will not be able to transmit audio. 
+ """ diff --git a/discord/voice/errors.py b/discord/voice/errors.py new file mode 100644 index 0000000000..e29b50a106 --- /dev/null +++ b/discord/voice/errors.py @@ -0,0 +1,92 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from aiohttp import ClientWebSocketResponse + +from discord.errors import ClientException + + +class VoiceConnectionClosed(ClientException): + """Exception that's raised when a voice websocket connection + is closed for reasons that could not be handled internally. + + Attributes + ---------- + code: :class:`int` + The close code of the websocket. + reason: :class:`str` + The reason provided for the closure. + guild_id: :class:`int` + The guild ID the client was connected to. + channel_id: :class:`int` + The channel ID the client was connected to. 
+ """ + + __slots__ = ( + 'code', + 'reason', + 'channel_id', + 'guild_id', + ) + + def __init__( + self, + socket: ClientWebSocketResponse, + channel_id: int, + guild_id: int, + *, + reason: str | None = None, + code: int | None = None, + ) -> None: + self.code: int = code or socket.close_code or -1 + self.reason: str = reason if reason is not None else "" + self.channel_id: int = channel_id + self.guild_id: int = guild_id + super().__init__( + f"The voice connection on {self.channel_id} (guild {self.guild_id}) was closed with {self.code}", + ) + + +class VoiceGuildMismatch(ClientException): + """Exception that's raised when, while connecting to a voice channel, the data + the library has differs from the one discord sends. + + Attributes + ---------- + expected: :class:`int` + The expected guild ID. This is the one the library has. + received: :class:`int` + The received guild ID. This is the one sent by discord. + """ + + __slots__ = ( + 'expected', + 'received', + ) + + def __init__(self, expt: int, recv: int) -> None: + self.expected: int = expt + self.received: int = recv diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py new file mode 100644 index 0000000000..06735de2fb --- /dev/null +++ b/discord/voice/gateway.py @@ -0,0 +1,92 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +import asyncio +from collections.abc import Callable, Coroutine +import logging +from typing import TYPE_CHECKING, Any + +import aiohttp + +from discord import utils +from discord.gateway import DiscordWebSocket + +from .errors import VoiceConnectionClosed + +if TYPE_CHECKING: + from .state import VoiceConnectionState + +_log = logging.getLogger(__name__) + + +class VoiceWebsocket(DiscordWebSocket): + if TYPE_CHECKING: + thread_id: int + gateway: str + _max_heartbeat_timeout: float + + VERSION = 8 + + def __init__( + self, + socket: aiohttp.ClientWebSocketResponse, + loop: asyncio.AbstractEventLoop, + state: VoiceConnectionState, + *, + hook: Callable[..., Coroutine[Any, Any, Any]] | None = None, + ) -> None: + self.ws: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop + self._keep_alive: VoiceKeepAliveHandler | None = None + self._close_code: int | None = None + self.secrety_key: list[int] | None = None + self.seq_ack: int = -1 + self.state: VoiceConnectionState = state + + if hook: + self._hook = hook + + def _hook(self, *args: Any) -> Any: + pass + + async def send_as_json(self, data: Any) -> None: + _log.debug('Sending voice websocket frame: %s', data) + await self.ws.send_str(utils._to_json(data)) + + send_heartbeat = send_as_json + + async def resume(self) -> None: + state = self._connection + + if not state.should_resume(): + if self.state.is_connected(): + await self.state.disconnect() + raise VoiceConnectionClosed( 
+ self.ws, + channel_id=self.state.channel_id, + guild_id=self.state.guild_id, + reason='The library attempted a resume when it was not expected', + ) diff --git a/discord/voice/state.py b/discord/voice/state.py new file mode 100644 index 0000000000..e3d1e8689c --- /dev/null +++ b/discord/voice/state.py @@ -0,0 +1,87 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
+""" +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +from .errors import VoiceGuildMismatch + +if TYPE_CHECKING: + from .client import VoiceClient + from .gateway import VoiceWebsocket + + from discord.member import VoiceState + from discord.types.voice import ( + VoiceState as VoiceStatePayload, + ) + + +class VoiceConnectionState: + def __init__( + self, + client: VoiceClient, + ws: VoiceWebsocket, + ) -> None: + self.client: VoiceClient = client + self.self_id: int = client.user.id + self.loop: asyncio.AbstractEventLoop = client.loop + + # this is used if we don't have our self-voice state available + self.self_mute: bool = client.self_muted + self.self_deaf: bool = client.self_deaf + + self._updated_server: asyncio.Event = asyncio.Event() + self._updated_state: asyncio.Event = asyncio.Event() + self.ws: VoiceWebsocket = ws + + @property + def connected(self) -> bool: + return ( + self._updated_server.is_set() and + self._updated_state.is_set() + ) + + @property + def guild_id(self) -> int: + return self.client.guild_id + + @property + def voice_guild_state(self) -> VoiceState: + return self.client.guild.me.voice + + @property + def channel_id(self) -> int: + return self.client.channel_id + + def update_state(self, payload: VoiceStatePayload) -> None: + # if we're here it means the guild is found + guild_id = int(payload['guild_id']) # type: ignore + + if self.guild_id != guild_id: + raise VoiceGuildMismatch(self.guild_id, guild_id) + + self.self_mute = payload['self_mute'] + self.self_deaf = payload['self_deaf'] diff --git a/discord/voice_client.py b/discord/voice_client.py index a60d730413..5fa3a67053 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -86,117 +86,6 @@ _log = logging.getLogger(__name__) - -class VoiceProtocol: - """A class that represents the Discord voice protocol. - - This is an abstract class. The library provides a concrete implementation - under :class:`VoiceClient`. 
- - This class allows you to implement a protocol to allow for an external - method of sending voice, such as Lavalink_ or a native library implementation. - - These classes are passed to :meth:`abc.Connectable.connect `. - - .. _Lavalink: https://github.com/freyacodes/Lavalink - - Parameters - ---------- - client: :class:`Client` - The client (or its subclasses) that started the connection request. - channel: :class:`abc.Connectable` - The voice channel that is being connected to. - """ - - def __init__(self, client: Client, channel: abc.Connectable) -> None: - self.client: Client = client - self.channel: abc.Connectable = channel - - async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: - """|coro| - - An abstract method that is called when the client's voice state - has changed. This corresponds to ``VOICE_STATE_UPDATE``. - - Parameters - ---------- - data: :class:`dict` - The raw `voice state payload`__. - - .. _voice_state_update_payload: https://discord.com/developers/docs/resources/voice#voice-state-object - - __ voice_state_update_payload_ - """ - raise NotImplementedError - - async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: - """|coro| - - An abstract method that is called when initially connecting to voice. - This corresponds to ``VOICE_SERVER_UPDATE``. - - Parameters - ---------- - data: :class:`dict` - The raw `voice server update payload`__. - - .. _voice_server_update_payload: https://discord.com/developers/docs/topics/gateway#voice-server-update-voice-server-update-event-fields - - __ voice_server_update_payload_ - """ - raise NotImplementedError - - async def connect(self, *, timeout: float, reconnect: bool) -> None: - """|coro| - - An abstract method called when the client initiates the connection request. - - When a connection is requested initially, the library calls the constructor - under ``__init__`` and then calls :meth:`connect`. 
If :meth:`connect` fails at - some point then :meth:`disconnect` is called. - - Within this method, to start the voice connection flow it is recommended to - use :meth:`Guild.change_voice_state` to start the flow. After which, - :meth:`on_voice_server_update` and :meth:`on_voice_state_update` will be called. - The order that these two are called is unspecified. - - Parameters - ---------- - timeout: :class:`float` - The timeout for the connection. - reconnect: :class:`bool` - Whether reconnection is expected. - """ - raise NotImplementedError - - async def disconnect(self, *, force: bool) -> None: - """|coro| - - An abstract method called when the client terminates the connection. - - See :meth:`cleanup`. - - Parameters - ---------- - force: :class:`bool` - Whether the disconnection was forced. - """ - raise NotImplementedError - - def cleanup(self) -> None: - """This method *must* be called to ensure proper clean-up during a disconnect. - - It is advisable to call this from within :meth:`disconnect` when you are - completely done with the voice protocol instance. - - This method removes it from the internal state cache that keeps track of - currently alive voice clients. Failure to clean-up will cause subsequent - connections to report that it's still connected. - """ - key_id, _ = self.channel._get_voice_client_key() - self.client._connection._remove_voice_client(key_id) - - class VoiceClient(VoiceProtocol): """Represents a Discord voice connection. 
From 0ea20772ace3b8cd7c305b02fcdede501b795a5c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Aug 2025 22:15:54 +0000 Subject: [PATCH 02/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/gateway.py | 11 +++- discord/raw_models.py | 107 +++++++++++++++++++----------------- discord/types/raw_models.py | 6 +- discord/voice/__init__.py | 6 +- discord/voice/_types.py | 5 +- discord/voice/client.py | 1 + discord/voice/errors.py | 13 +++-- discord/voice/gateway.py | 7 ++- discord/voice/state.py | 20 +++---- discord/voice_client.py | 1 + 10 files changed, 93 insertions(+), 84 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index d08b7307e7..91576bd4d6 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -26,7 +26,6 @@ from __future__ import annotations import asyncio -from collections.abc import Callable import concurrent.futures import logging import struct @@ -36,6 +35,7 @@ import traceback import zlib from collections import deque +from collections.abc import Callable from typing import TYPE_CHECKING, Any, NamedTuple import aiohttp @@ -318,7 +318,12 @@ class DiscordWebSocket: shard_count: int | None _max_heartbeat_timeout: float - def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + socket: aiohttp.ClientWebSocketResponse, + *, + loop: asyncio.AbstractEventLoop, + ) -> None: self.socket: aiohttp.ClientWebSocketResponse = socket self.loop: asyncio.AbstractEventLoop = loop @@ -652,7 +657,7 @@ async def poll_event(self) -> None: elif msg.type is aiohttp.WSMsgType.BINARY: await self.received_message(msg.data) elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received an error %s', msg) + _log.debug("Received an error %s", msg) elif msg.type in ( aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, diff --git a/discord/raw_models.py b/discord/raw_models.py index 
d29f513161..0771159f28 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -25,16 +25,16 @@ from __future__ import annotations -from collections.abc import ItemsView, KeysView, ValuesView import datetime +from collections.abc import ItemsView, KeysView, ValuesView from typing import TYPE_CHECKING, Any +from . import utils from .automod import AutoModAction, AutoModTriggerType from .enums import AuditLogAction, ChannelType, ReactionType, try_enum -from . import utils if TYPE_CHECKING: - from .abc import MessageableChannel, GuildChannel + from .abc import GuildChannel, MessageableChannel from .guild import Guild from .member import Member from .message import Message @@ -94,7 +94,11 @@ class _RawReprMixin: __slots__: tuple[str, ...] def __repr__(self) -> str: - value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__ if not attr.startswith('_')) + value = " ".join( + f"{attr}={getattr(self, attr)!r}" + for attr in self.__slots__ + if not attr.startswith("_") + ) return f"<{self.__class__.__name__} {value}>" @@ -850,23 +854,24 @@ def __init__(self, data: MessagePollVoteEvent, added: bool) -> None: except KeyError: self.guild_id: int | None = None + # this is for backwards compatibility because VoiceProtocol.on_voice_..._update # passed the raw payload instead of a raw object. Emit deprecation warning. class _PayloadLike(_RawReprMixin): _raw_data: dict[str, Any] @utils.deprecated( - 'the attributes', - '2.7', - '3.0', + "the attributes", + "2.7", + "3.0", ) def __getitem__(self, key: str) -> Any: return self._raw_data[key] @utils.deprecated( - 'the attributes', - '2.7', - '3.0', + "the attributes", + "2.7", + "3.0", ) def get(self, key: str, default: Any = None) -> Any: """Gets an item from this raw event, and returns its value or ``default``. 
@@ -877,9 +882,9 @@ def get(self, key: str, default: Any = None) -> Any: return self._raw_data.get(key, default) @utils.deprecated( - 'the attributes', - '2.7', - '3.0', + "the attributes", + "2.7", + "3.0", ) def items(self) -> ItemsView: """Returns the (key, value) pairs of this raw event. @@ -890,9 +895,9 @@ def items(self) -> ItemsView: return self._raw_data.items() @utils.deprecated( - 'the attributes', - '2.7', - '3.0', + "the attributes", + "2.7", + "3.0", ) def values(self) -> ValuesView: """Returns the values of this raw event. @@ -903,9 +908,9 @@ def values(self) -> ValuesView: return self._raw_data.values() @utils.deprecated( - 'the attributes', - '2.7', - '3.0', + "the attributes", + "2.7", + "3.0", ) def keys(self) -> KeysView: """Returns the keys of this raw event. @@ -957,37 +962,37 @@ class RawVoiceStateUpdateEvent(_PayloadLike): """ __slots__ = ( - 'session_id', - 'mute', - 'deaf', - 'self_mute', - 'self_deaf', - 'self_stream', - 'self_video', - 'suppress', - 'requested_to_speak_at', - 'afk', - 'channel', - 'guild_id', - 'channel_id', - '_state', - '_raw_data', + "session_id", + "mute", + "deaf", + "self_mute", + "self_deaf", + "self_stream", + "self_video", + "suppress", + "requested_to_speak_at", + "afk", + "channel", + "guild_id", + "channel_id", + "_state", + "_raw_data", ) def __init__(self, *, data: VoiceStateEvent, state: ConnectionState) -> None: - self.session_id: str = data['session_id'] + self.session_id: str = data["session_id"] self._state: ConnectionState = state - self.self_mute: bool = data.get('self_mute', False) - self.self_deaf: bool = data.get('self_deaf', False) - self.mute: bool = data.get('mute', False) - self.deaf: bool = data.get('deaf', False) - self.suppress: bool = data.get('suppress', False) + self.self_mute: bool = data.get("self_mute", False) + self.self_deaf: bool = data.get("self_deaf", False) + self.mute: bool = data.get("mute", False) + self.deaf: bool = data.get("deaf", False) + self.suppress: bool = 
data.get("suppress", False) self.requested_to_speak_at: datetime.datetime | None = utils.parse_time( - data.get('request_to_speak_timestamp') + data.get("request_to_speak_timestamp") ) - self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') - self.channel_id: int | None = utils._get_as_snowflake(data, 'channel_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") + self.channel_id: int | None = utils._get_as_snowflake(data, "channel_id") self._raw_data: VoiceStateEvent = data @property @@ -1017,18 +1022,18 @@ class RawVoiceServerUpdateEvent(_PayloadLike): """ __slots__ = ( - 'token', - 'guild_id', - 'endpoint', - '_raw_data', - '_state', + "token", + "guild_id", + "endpoint", + "_raw_data", + "_state", ) def __init__(self, *, data: VoiceServerUpdateEvent, state: ConnectionState) -> None: self._state: ConnectionState = state - self.guild_id: int = int(data['guild_id']) - self.token: str = data['token'] - self.endpoint: str | None = data['endpoint'] + self.guild_id: int = int(data["guild_id"]) + self.token: str = data["token"] + self.endpoint: str | None = data["endpoint"] @property def guild(self) -> Guild | None: diff --git a/discord/types/raw_models.py b/discord/types/raw_models.py index 473434b6de..7ec8a446ca 100644 --- a/discord/types/raw_models.py +++ b/discord/types/raw_models.py @@ -33,10 +33,8 @@ from .snowflake import Snowflake from .threads import Thread, ThreadMember from .user import User -from .voice import ( - VoiceState as VoiceStateEvent, - VoiceServerUpdate as VoiceServerUpdateEvent, -) +from .voice import VoiceServerUpdate as VoiceServerUpdateEvent +from .voice import VoiceState as VoiceStateEvent class _MessageEventOptional(TypedDict, total=False): diff --git a/discord/voice/__init__.py b/discord/voice/__init__.py index ad1e47b1de..feedaa9f52 100644 --- a/discord/voice/__init__.py +++ b/discord/voice/__init__.py @@ -8,10 +8,10 @@ :license: MIT, see LICENSE for more details. 
""" -from .client import VoiceClient from ._types import VoiceProtocol +from .client import VoiceClient __all__ = ( - 'VoiceClient', - 'VoiceProtocol', + "VoiceClient", + "VoiceProtocol", ) diff --git a/discord/voice/_types.py b/discord/voice/_types.py index 2a677edaf4..4acf832657 100644 --- a/discord/voice/_types.py +++ b/discord/voice/_types.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations from typing import TYPE_CHECKING, Generic, TypeVar @@ -30,11 +31,11 @@ from discord import abc from discord.client import Client from discord.raw_models import ( - RawVoiceStateUpdateEvent, RawVoiceServerUpdateEvent, + RawVoiceStateUpdateEvent, ) -ClientT = TypeVar('ClientT', bound='Client', covariant=True) +ClientT = TypeVar("ClientT", bound="Client", covariant=True) class VoiceProtocol(Generic[ClientT]): diff --git a/discord/voice/client.py b/discord/voice/client.py index 702d24c46f..8e1b787384 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations from ._types import VoiceProtocol diff --git a/discord/voice/errors.py b/discord/voice/errors.py index e29b50a106..be35852c22 100644 --- a/discord/voice/errors.py +++ b/discord/voice/errors.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" + from __future__ import annotations from aiohttp import ClientWebSocketResponse @@ -46,10 +47,10 @@ class VoiceConnectionClosed(ClientException): """ __slots__ = ( - 'code', - 'reason', - 'channel_id', - 'guild_id', + "code", + "reason", + "channel_id", + "guild_id", ) def __init__( @@ -83,8 +84,8 @@ class VoiceGuildMismatch(ClientException): """ __slots__ = ( - 'expected', - 'received', + "expected", + "received", ) def __init__(self, expt: int, recv: int) -> None: diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index 06735de2fb..533952ccdf 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -22,11 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine import logging +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any import aiohttp @@ -73,7 +74,7 @@ def _hook(self, *args: Any) -> Any: pass async def send_as_json(self, data: Any) -> None: - _log.debug('Sending voice websocket frame: %s', data) + _log.debug("Sending voice websocket frame: %s", data) await self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json @@ -88,5 +89,5 @@ async def resume(self) -> None: self.ws, channel_id=self.state.channel_id, guild_id=self.state.guild_id, - reason='The library attempted a resume when it was not expected', + reason="The library attempted a resume when it was not expected", ) diff --git a/discord/voice/state.py b/discord/voice/state.py index e3d1e8689c..089d18473d 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" + from __future__ import annotations import asyncio @@ -30,14 +31,12 @@ from .errors import VoiceGuildMismatch if TYPE_CHECKING: + from discord.member import VoiceState + from discord.types.voice import VoiceState as VoiceStatePayload + from .client import VoiceClient from .gateway import VoiceWebsocket - from discord.member import VoiceState - from discord.types.voice import ( - VoiceState as VoiceStatePayload, - ) - class VoiceConnectionState: def __init__( @@ -59,10 +58,7 @@ def __init__( @property def connected(self) -> bool: - return ( - self._updated_server.is_set() and - self._updated_state.is_set() - ) + return self._updated_server.is_set() and self._updated_state.is_set() @property def guild_id(self) -> int: @@ -78,10 +74,10 @@ def channel_id(self) -> int: def update_state(self, payload: VoiceStatePayload) -> None: # if we're here it means the guild is found - guild_id = int(payload['guild_id']) # type: ignore + guild_id = int(payload["guild_id"]) # type: ignore if self.guild_id != guild_id: raise VoiceGuildMismatch(self.guild_id, guild_id) - self.self_mute = payload['self_mute'] - self.self_deaf = payload['self_deaf'] + self.self_mute = payload["self_mute"] + self.self_deaf = payload["self_deaf"] diff --git a/discord/voice_client.py b/discord/voice_client.py index 5fa3a67053..347c6bbc8e 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -86,6 +86,7 @@ _log = logging.getLogger(__name__) + class VoiceClient(VoiceProtocol): """Represents a Discord voice connection. 
From 56caff57e44de7bb069dc79095ccf57b03e95685 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Tue, 26 Aug 2025 03:13:04 +0200 Subject: [PATCH 03/40] things --- discord/abc.py | 1 + discord/channel.py | 2 +- discord/gateway.py | 20 ++-- discord/member.py | 21 +++- discord/voice/client.py | 30 +----- discord/voice/enums.py | 54 ++++++++++ discord/voice/errors.py | 92 ----------------- discord/voice/gateway.py | 206 ++++++++++++++++++++++++++++++++++----- discord/voice/state.py | 160 ++++++++++++++++++++---------- discord/voice_client.py | 2 +- 10 files changed, 381 insertions(+), 207 deletions(-) create mode 100644 discord/voice/enums.py delete mode 100644 discord/voice/errors.py diff --git a/discord/abc.py b/discord/abc.py index 4f419a3942..6b90871436 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1934,6 +1934,7 @@ class Connectable(Protocol): __slots__ = () _state: ConnectionState + id: int def _get_voice_client_key(self) -> tuple[int, str]: raise NotImplementedError diff --git a/discord/channel.py b/discord/channel.py index de589a5502..c66fef3f87 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -1611,7 +1611,7 @@ def _update( self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload ) -> None: # This data will always exist - self.guild = guild + self.guild: Guild = guild self.name: str = data["name"] self.category_id: int | None = utils._get_as_snowflake(data, "parent_id") diff --git a/discord/gateway.py b/discord/gateway.py index d08b7307e7..855e393997 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -137,11 +137,17 @@ def __init__( interval: float | None = None, **kwargs: Any, ) -> None: - threading.Thread.__init__(self, *args, **kwargs) + daemon: bool = kwargs.pop('daemon', True) + name: str = kwargs.pop('name', f'keep-alive-handler:shard-{shard_id}') + super().__init__( + *args, + **kwargs, + daemon=daemon, + name=name, + ) self.ws: DiscordWebSocket = ws self._main_thread_id = 
ws.thread_id self.interval = interval - self.daemon = True self.shard_id = shard_id self.msg = "Keeping shard ID %s websocket alive with sequence %s." self.block_msg = "Shard ID %s heartbeat blocked for more than %s seconds." @@ -153,7 +159,7 @@ def __init__( self.latency = float("inf") self.heartbeat_timeout = ws._max_heartbeat_timeout - def run(self): + def run(self) -> None: while not self._stop_ev.wait(self.interval): if self._last_recv + self.heartbeat_timeout < time.perf_counter(): _log.warning( @@ -206,16 +212,16 @@ def run(self): else: self._last_send = time.perf_counter() - def get_payload(self): + def get_payload(self) -> dict[str, Any]: return {"op": self.ws.HEARTBEAT, "d": self.ws.sequence} - def stop(self): + def stop(self) -> None: self._stop_ev.set() - def tick(self): + def tick(self) -> None: self._last_recv = time.perf_counter() - def ack(self): + def ack(self) -> None: ack_time = time.perf_counter() self._last_ack = ack_time self.latency = ack_time - self._last_send diff --git a/discord/member.py b/discord/member.py index 0ff90cce04..84d481b83c 100644 --- a/discord/member.py +++ b/discord/member.py @@ -51,8 +51,9 @@ ) if TYPE_CHECKING: + from .client import Client from .abc import Snowflake - from .channel import DMChannel, StageChannel, VoiceChannel + from .channel import DMChannel, VocalGuildChannel from .flags import PublicUserFlags from .guild import Guild from .message import Message @@ -63,11 +64,9 @@ from .types.member import MemberWithUser as MemberWithUserPayload from .types.member import UserWithMember as UserWithMemberPayload from .types.user import User as UserPayload - from .types.voice import GuildVoiceState as GuildVoiceStatePayload + from .types.voice import VoiceState as GuildVoiceStatePayload from .types.voice import VoiceState as VoiceStatePayload - VocalGuildChannel = Union[VoiceChannel, StageChannel] - class VoiceState: """Represents a Discord user's voice state. 
@@ -165,6 +164,20 @@ def __repr__(self) -> str: inner = " ".join("%s=%r" % t for t in attrs) return f"<{self.__class__.__name__} {inner}>" + @classmethod + def _create_default(cls, channel: VocalGuildChannel, client: Client) -> VoiceState: + self = cls( + data={ + 'channel_id': channel.id, + 'guild_id': channel.guild.id, + 'self_deaf': False, + 'self_mute': False, + 'user_id': client._connection.self_id, # type: ignore + }, + channel=channel, + ) + return self + def flatten_user(cls): for attr, value in itertools.chain( diff --git a/discord/voice/client.py b/discord/voice/client.py index 702d24c46f..f6e384d8de 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -24,31 +24,5 @@ """ from __future__ import annotations -from ._types import VoiceProtocol - - -class VoiceClient(VoiceProtocol): - """Represents a Discord voice connection. - - You do not create these, you typically get them from e.g. :meth:`VoiceChannel.connect`. - - Attributes - ---------- - session_id: :class:`str` - The voice connection session ID. You should not share this. - token: :class:`str` - The voice connection token. You should not share this. - endpoint: :class:`str` - The endpoint the current client is connected to. - channel: :class:`abc.Connectable` - The voice channel connected to. - loop: :class:`asyncio.AbstractEventLoop` - The event loop that the voice client is running on. - - Warning - ------- - In order to use PCM based AudioSources, you must have the opus library - installed on your system and loaded through :func:`opus.load_opus`. - Otherwise, your AudioSources must be opus encoded (e.g. using :class:`FFmpegOpusAudio`) - or the library will not be able to transmit audio. 
- """ +# rn this is for typing, will be moved here in some point in the future +from discord.voice_client import VoiceClient diff --git a/discord/voice/enums.py b/discord/voice/enums.py new file mode 100644 index 0000000000..2e24f03b58 --- /dev/null +++ b/discord/voice/enums.py @@ -0,0 +1,54 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
+""" +from __future__ import annotations + +from discord.enums import Enum + + +class OpCodes(Enum): + # fmt: off + identify = 0 + select_protocol = 1 + ready = 2 + heartbeat = 3 + session_description = 4 + speaking = 5 + heartbeat_ack = 6 + resume = 7 + hello = 8 + resumed = 9 + client_connect = 10 + client_disconnect = 11 + # fmt: on + + def __eq__(self, other: object) -> bool: + if isinstance(other, int): + return self.value == other + elif isinstance(other, self.__class__): + return self is other + return NotImplemented + + def __int__(self) -> int: + return self.value diff --git a/discord/voice/errors.py b/discord/voice/errors.py deleted file mode 100644 index e29b50a106..0000000000 --- a/discord/voice/errors.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. 
-""" -from __future__ import annotations - -from aiohttp import ClientWebSocketResponse - -from discord.errors import ClientException - - -class VoiceConnectionClosed(ClientException): - """Exception that's raised when a voice websocket connection - is closed for reasons that could not be handled internally. - - Attributes - ---------- - code: :class:`int` - The close code of the websocket. - reason: :class:`str` - The reason provided for the closure. - guild_id: :class:`int` - The guild ID the client was connected to. - channel_id: :class:`int` - The channel ID the client was connected to. - """ - - __slots__ = ( - 'code', - 'reason', - 'channel_id', - 'guild_id', - ) - - def __init__( - self, - socket: ClientWebSocketResponse, - channel_id: int, - guild_id: int, - *, - reason: str | None = None, - code: int | None = None, - ) -> None: - self.code: int = code or socket.close_code or -1 - self.reason: str = reason if reason is not None else "" - self.channel_id: int = channel_id - self.guild_id: int = guild_id - super().__init__( - f"The voice connection on {self.channel_id} (guild {self.guild_id}) was closed with {self.code}", - ) - - -class VoiceGuildMismatch(ClientException): - """Exception that's raised when, while connecting to a voice channel, the data - the library has differs from the one discord sends. - - Attributes - ---------- - expected: :class:`int` - The expected guild ID. This is the one the library has. - received: :class:`int` - The received guild ID. This is the one sent by discord. 
- """ - - __slots__ = ( - 'expected', - 'received', - ) - - def __init__(self, expt: int, recv: int) -> None: - self.expected: int = expt - self.received: int = recv diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index 06735de2fb..e027b3ada6 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -25,16 +25,21 @@ from __future__ import annotations import asyncio +from collections import deque from collections.abc import Callable, Coroutine import logging +import struct +import time from typing import TYPE_CHECKING, Any import aiohttp from discord import utils -from discord.gateway import DiscordWebSocket +from discord.enums import SpeakingState +from discord.errors import ConnectionClosed +from discord.gateway import DiscordWebSocket, KeepAliveHandler as KeepAliveHandlerBase -from .errors import VoiceConnectionClosed +from .enums import OpCodes if TYPE_CHECKING: from .state import VoiceConnectionState @@ -42,14 +47,48 @@ _log = logging.getLogger(__name__) -class VoiceWebsocket(DiscordWebSocket): - if TYPE_CHECKING: - thread_id: int - gateway: str - _max_heartbeat_timeout: float - - VERSION = 8 - +class KeepAliveHandler(KeepAliveHandlerBase): + def __init__( + self, + *args: Any, + ws: VoiceWebSocket, + interval: float | None = None, + **kwargs: Any, + ) -> None: + daemon: bool = kwargs.pop('daemon', True) + name: str = kwargs.pop('name', f'voice-keep-alive-handler:{id(self):#x}') + super().__init__( + *args, + **kwargs, + name=name, + daemon=daemon, + ) + + self.ws: VoiceWebSocket = ws + self.interval: float | None = interval + self.msg: str = 'Keeping shard ID %s voice websocket alive with timestamp %s.' + self.block_msg: str = 'Shard ID %s voice heartbeat blocked for more than %s seconds.' + self.behing_msg: str = 'High socket latency, shard ID %s heartbeat is %.1fs behind.' 
+ self.recent_ack_latencies: deque[float] = deque(maxlen=20) + + def get_payload(self) -> dict[str, Any]: + return { + 'op': int(OpCodes.heartbeat), + 'd': { + 't': int(time.time() * 1000), + 'seq_ack': self.ws.seq_ack, + }, + } + + def ack(self) -> None: + ack_time = time.perf_counter() + self._last_ack = ack_time + self._last_recv = ack_time + self.latency = ack_time - self._last_send + self.recent_ack_latencies.append(self.latency) + + +class VoiceWebSocket(DiscordWebSocket): def __init__( self, socket: aiohttp.ClientWebSocketResponse, @@ -60,33 +99,150 @@ def __init__( ) -> None: self.ws: aiohttp.ClientWebSocketResponse = socket self.loop: asyncio.AbstractEventLoop = loop - self._keep_alive: VoiceKeepAliveHandler | None = None + self._keep_alive: KeepAliveHandler | None = None self._close_code: int | None = None - self.secrety_key: list[int] | None = None + self.secret_key: list[int] | None = None self.seq_ack: int = -1 + self.session_id: str | None = None self.state: VoiceConnectionState = state + self.ssrc_map: dict[str] if hook: - self._hook = hook + self._hook = hook # type: ignore - def _hook(self, *args: Any) -> Any: + async def _hook(self, *args: Any) -> Any: pass async def send_as_json(self, data: Any) -> None: - _log.debug('Sending voice websocket frame: %s', data) + _log.debug('Sending voice websocket frame: %s.', data) await self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json async def resume(self) -> None: - state = self._connection - - if not state.should_resume(): - if self.state.is_connected(): - await self.state.disconnect() - raise VoiceConnectionClosed( - self.ws, - channel_id=self.state.channel_id, - guild_id=self.state.guild_id, - reason='The library attempted a resume when it was not expected', + payload = { + 'op': int(OpCodes.resume), + 'd': { + 'token': self.token, + 'server_id': str(self.state.server_id), + 'session_id': self.session_id, + 'seq_ack': self.seq_ack, + }, + } + await self.send_as_json(payload) + + async 
def received_message(self, msg: Any, /): + _log.debug('Voice websocket frame received: %s', msg) + op = msg['op'] + data = msg.get('data', {}) # this key should ALWAYS be given, but guard anyways + self.seq_ack = data.get('seq', self.seq_ack) # keep the seq_ack updated + + if op == OpCodes.ready: + await self.ready(data) + elif op == OpCodes.heartbeat_ack: + if not self._keep_alive: + _log.error( + 'Received a heartbeat ACK but no keep alive handler was set.', + ) + return + self._keep_alive.ack() + elif op == OpCodes.resumed: + _log.info( + f'Voice connection on channel ID {self.state.channel_id} (guild {self.state.guild_id}) was ' + 'successfully RESUMED.', + ) + elif op == OpCodes.session_description: + self.state.update_session_description(data) + elif op == OpCodes.hello: + interval = data['heartbeat_interval'] / 1000.0 + self._keep_alive = KeepAliveHandler( + ws=self, + interval=min(interval, 5), ) + self._keep_alive.start() + + await self._hook(self, msg) + + async def ready(self, data: dict[str, Any]) -> None: + state = self.state + + state.ssrc = data['ssrc'] + state.voice_port = data['port'] + state.endpoint_ip = data['ip'] + + _log.debug( + f'Connecting to {state.endpoint_ip} (port {state.voice_port}).', + ) + + await self.loop.sock_connect( + state.socket, + (state.endpoint_id, state.voice_port), + ) + + state.ip, state.port = await self.get_ip() + + async def get_ip(self) -> tuple[str, int]: + state = self.state + packet = bytearray(75) + struct.pack_into('>H', packet, 0, 1) # 1 = Send + struct.pack_into('>H', packet, 2, 70) # 70 = Length + struct.pack_into('>I', packet, 4, state.ssrc) + + _log.debug(f'Sending IP discovery packet for voice in channel {state.channel_id} (guild {state.guild_id})') + await self.loop.sock_sendall(state.socket, packet) + + fut: asyncio.Future[bytes] = self.loop.create_future() + + def get_ip_packet(data: bytes) -> None: + if data[0] == 0x02 and len(data) == 74: + self.loop.call_soon_threadsafe(fut.set_result, data) + + 
fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet)) + state.add_socket_listener(get_ip_packet) + recv = await fut + + _log.debug('Received IP discovery packet with data %s', recv) + + ip_start = 8 + ip_end = recv.index(0, ip_start) + ip = recv[ip_start:ip_end].decode('ascii') + port = struct.unpack_from('>H', recv, len(recv) - 2)[0] + _log.debug('Detected IP %s with port %s', ip, port) + + return ip, port + + @property + def latency(self) -> float: + heartbeat = self._keep_alive + return float('inf') if heartbeat is None else heartbeat.latency + + @property + def average_latency(self) -> float: + heartbeat = self._keep_alive + if heartbeat is None or not heartbeat.recent_ack_latencies: + return float('inf') + return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) + + async def load_secret_key(self, data: dict[str, Any]) -> None: + _log.debug(f'Received secret key for voice connection in channel {self.state.channel_id} (guild {self.state.guild_id})') + self.secret_key = self.state.secret_key = data['secret_key'] + await self.speak(SpeakingState.none) + + async def poll_event(self) -> None: + msg = await asyncio.wait_for(self.ws.receive(), timeout=30) + + if msg.type is aiohttp.WSMsgType.TEXT: + await self.received_message(utils._from_json(msg.data)) + elif msg.type is aiohttp.WSMsgType.ERROR: + _log.debug('Received %s', msg) + raise ConnectionClosed(self.ws, shard_id=None) from msg.data + elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): + _log.debug('Received %s', msg) + raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) + + async def close(self, code: int = 1000) -> None: + if self._keep_alive: + self._keep_alive.stop() + + self._close_code = code + await self.ws.close(code=self._close_code) diff --git a/discord/voice/state.py b/discord/voice/state.py index e3d1e8689c..cd1b52e885 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ 
-25,63 +25,125 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING - -from .errors import VoiceGuildMismatch +from collections.abc import Callable, Coroutine +import logging +import select +import threading +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from .gateway import VoiceWebSocket from .client import VoiceClient - from .gateway import VoiceWebsocket - from discord.member import VoiceState - from discord.types.voice import ( - VoiceState as VoiceStatePayload, - ) +SocketReaderCallback = Callable[[bytes], Any] +_log = logging.getLogger(__name__) + + +class SocketEventReader(threading.Thread): + def __init__(self, state: VoiceConnectionState, *, start_paused: bool = True) -> None: + super().__init__( + daemon=True, + name=f'voice-socket-reader:{id(self):#x}', + ) + self.state: VoiceConnectionState = state + self.start_paused: bool = start_paused + self._callbacks: list[SocketReaderCallback] = [] + self._running: threading.Event = threading.Event() + self._end: threading.Event = threading.Event() + self._idle_paused: bool = True + + def register(self, callback: SocketReaderCallback) -> None: + self._callbacks.append(callback) + if self._idle_paused: + self._idle_paused = False + self._running.set() + + def unregister(self, callback: SocketReaderCallback) -> None: + try: + self._callbacks.remove(callback) + except ValueError: + pass + else: + if not self._callbacks and self._running.is_set(): + self._idle_paused = True + self._running.clear() + + def pause(self) -> None: + self._idle_paused = False + self._running.clear() + + def resume(self, *, force: bool = False) -> None: + if self._running.is_set(): + return + + if not force and not self._callbacks: + self._idle_paused = True + return + + self._idle_paused = False + self._running.set() + + def stop(self) -> None: + self._end.set() + self._running.set() + + def run(self) -> None: + self._end.clear() + self._running.set() + + if self.start_paused: + self.pause() + 
+ try: + self._do_run() + except Exception: + _log.exception('Error while starting socket event reader at %s', self) + finally: + self.stop() + self._running.clear() + self._callbacks.clear() + + def _do_run(self) -> None: + while not self._end.is_set(): + if not self._running.is_set(): + self._running.wait() + continue + + try: + readable, _, _ = select.select([self.state.socket], [], [], 30) + except (ValueError, TypeError, OSError) as e: + _log.debug( + 'Select error handling socket in reader, this should be safe to ignore: %s: %s', + e.__class__.__name__, + e, + ) + continue + + if not readable: + continue + + try: + data = self.state.socket.recv(2048) + except OSError: + _log.debug('Error reading from socket in %s, this should be safe to ignore.', self, exc_info=True) + else: + for cb in self._callbacks: + try: + cb(data) + except Exception: + _log.exception( + 'Error while calling %s in %s', + cb, + self, + ) class VoiceConnectionState: def __init__( self, client: VoiceClient, - ws: VoiceWebsocket, + *, + hook: Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None = None, ) -> None: - self.client: VoiceClient = client - self.self_id: int = client.user.id - self.loop: asyncio.AbstractEventLoop = client.loop - - # this is used if we don't have our self-voice state available - self.self_mute: bool = client.self_muted - self.self_deaf: bool = client.self_deaf - - self._updated_server: asyncio.Event = asyncio.Event() - self._updated_state: asyncio.Event = asyncio.Event() - self.ws: VoiceWebsocket = ws - - @property - def connected(self) -> bool: - return ( - self._updated_server.is_set() and - self._updated_state.is_set() - ) - - @property - def guild_id(self) -> int: - return self.client.guild_id - - @property - def voice_guild_state(self) -> VoiceState: - return self.client.guild.me.voice - - @property - def channel_id(self) -> int: - return self.client.channel_id - - def update_state(self, payload: VoiceStatePayload) -> None: - # if we're 
here it means the guild is found - guild_id = int(payload['guild_id']) # type: ignore - - if self.guild_id != guild_id: - raise VoiceGuildMismatch(self.guild_id, guild_id) - - self.self_mute = payload['self_mute'] - self.self_deaf = payload['self_deaf'] + ... + # TODO: finish this diff --git a/discord/voice_client.py b/discord/voice_client.py index 5fa3a67053..57df6089f0 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -230,7 +230,7 @@ async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: self._voice_server_complete.set() async def voice_connect(self) -> None: - await self.channel.guild.change_voice_state(channel=self.channel) + await self.guild.change_voice_state(channel=self.channel) async def voice_disconnect(self) -> None: _log.info( From 098a280f7dd6de96d58b5ef6cfc9898d5def5b7f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 01:15:04 +0000 Subject: [PATCH 04/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/gateway.py | 4 +- discord/member.py | 14 +++--- discord/voice/enums.py | 1 + discord/voice/gateway.py | 105 ++++++++++++++++++++++----------------- discord/voice/state.py | 27 ++++++---- 5 files changed, 86 insertions(+), 65 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index fadfa1ed2e..e4651c2229 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -137,8 +137,8 @@ def __init__( interval: float | None = None, **kwargs: Any, ) -> None: - daemon: bool = kwargs.pop('daemon', True) - name: str = kwargs.pop('name', f'keep-alive-handler:shard-{shard_id}') + daemon: bool = kwargs.pop("daemon", True) + name: str = kwargs.pop("name", f"keep-alive-handler:shard-{shard_id}") super().__init__( *args, **kwargs, diff --git a/discord/member.py b/discord/member.py index 84d481b83c..498f3322fd 100644 --- a/discord/member.py +++ b/discord/member.py @@ -30,7 +30,7 @@ import itertools import sys from 
operator import attrgetter -from typing import TYPE_CHECKING, Any, TypeVar, Union +from typing import TYPE_CHECKING, Any, TypeVar import discord.abc @@ -51,9 +51,9 @@ ) if TYPE_CHECKING: - from .client import Client from .abc import Snowflake from .channel import DMChannel, VocalGuildChannel + from .client import Client from .flags import PublicUserFlags from .guild import Guild from .message import Message @@ -168,11 +168,11 @@ def __repr__(self) -> str: def _create_default(cls, channel: VocalGuildChannel, client: Client) -> VoiceState: self = cls( data={ - 'channel_id': channel.id, - 'guild_id': channel.guild.id, - 'self_deaf': False, - 'self_mute': False, - 'user_id': client._connection.self_id, # type: ignore + "channel_id": channel.id, + "guild_id": channel.guild.id, + "self_deaf": False, + "self_mute": False, + "user_id": client._connection.self_id, # type: ignore }, channel=channel, ) diff --git a/discord/voice/enums.py b/discord/voice/enums.py index 2e24f03b58..60c3748b9d 100644 --- a/discord/voice/enums.py +++ b/discord/voice/enums.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" + from __future__ import annotations from discord.enums import Enum diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index 7f1fc75ec0..d7ec95bd96 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -26,11 +26,11 @@ from __future__ import annotations import asyncio -from collections import deque -from collections.abc import Callable, Coroutine import logging import struct import time +from collections import deque +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any import aiohttp @@ -38,7 +38,8 @@ from discord import utils from discord.enums import SpeakingState from discord.errors import ConnectionClosed -from discord.gateway import DiscordWebSocket, KeepAliveHandler as KeepAliveHandlerBase +from discord.gateway import DiscordWebSocket +from discord.gateway import KeepAliveHandler as KeepAliveHandlerBase from .enums import OpCodes @@ -56,8 +57,8 @@ def __init__( interval: float | None = None, **kwargs: Any, ) -> None: - daemon: bool = kwargs.pop('daemon', True) - name: str = kwargs.pop('name', f'voice-keep-alive-handler:{id(self):#x}') + daemon: bool = kwargs.pop("daemon", True) + name: str = kwargs.pop("name", f"voice-keep-alive-handler:{id(self):#x}") super().__init__( *args, **kwargs, @@ -67,17 +68,21 @@ def __init__( self.ws: VoiceWebSocket = ws self.interval: float | None = interval - self.msg: str = 'Keeping shard ID %s voice websocket alive with timestamp %s.' - self.block_msg: str = 'Shard ID %s voice heartbeat blocked for more than %s seconds.' - self.behing_msg: str = 'High socket latency, shard ID %s heartbeat is %.1fs behind.' + self.msg: str = "Keeping shard ID %s voice websocket alive with timestamp %s." + self.block_msg: str = ( + "Shard ID %s voice heartbeat blocked for more than %s seconds." + ) + self.behing_msg: str = ( + "High socket latency, shard ID %s heartbeat is %.1fs behind." 
+ ) self.recent_ack_latencies: deque[float] = deque(maxlen=20) def get_payload(self) -> dict[str, Any]: return { - 'op': int(OpCodes.heartbeat), - 'd': { - 't': int(time.time() * 1000), - 'seq_ack': self.ws.seq_ack, + "op": int(OpCodes.heartbeat), + "d": { + "t": int(time.time() * 1000), + "seq_ack": self.ws.seq_ack, }, } @@ -115,47 +120,47 @@ async def _hook(self, *args: Any) -> Any: pass async def send_as_json(self, data: Any) -> None: - _log.debug('Sending voice websocket frame: %s.', data) + _log.debug("Sending voice websocket frame: %s.", data) await self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json async def resume(self) -> None: payload = { - 'op': int(OpCodes.resume), - 'd': { - 'token': self.token, - 'server_id': str(self.state.server_id), - 'session_id': self.session_id, - 'seq_ack': self.seq_ack, + "op": int(OpCodes.resume), + "d": { + "token": self.token, + "server_id": str(self.state.server_id), + "session_id": self.session_id, + "seq_ack": self.seq_ack, }, } await self.send_as_json(payload) async def received_message(self, msg: Any, /): - _log.debug('Voice websocket frame received: %s', msg) - op = msg['op'] - data = msg.get('data', {}) # this key should ALWAYS be given, but guard anyways - self.seq_ack = data.get('seq', self.seq_ack) # keep the seq_ack updated + _log.debug("Voice websocket frame received: %s", msg) + op = msg["op"] + data = msg.get("data", {}) # this key should ALWAYS be given, but guard anyways + self.seq_ack = data.get("seq", self.seq_ack) # keep the seq_ack updated if op == OpCodes.ready: await self.ready(data) elif op == OpCodes.heartbeat_ack: if not self._keep_alive: _log.error( - 'Received a heartbeat ACK but no keep alive handler was set.', + "Received a heartbeat ACK but no keep alive handler was set.", ) return self._keep_alive.ack() elif op == OpCodes.resumed: _log.info( - f'Voice connection on channel ID {self.state.channel_id} (guild {self.state.guild_id}) was ' - 'successfully RESUMED.', + f"Voice 
connection on channel ID {self.state.channel_id} (guild {self.state.guild_id}) was " + "successfully RESUMED.", ) elif op == OpCodes.session_description: self.state.update_session_description(data) elif op == OpCodes.hello: - interval = data['heartbeat_interval'] / 1000.0 + interval = data["heartbeat_interval"] / 1000.0 self._keep_alive = KeepAliveHandler( ws=self, interval=min(interval, 5), @@ -167,12 +172,12 @@ async def received_message(self, msg: Any, /): async def ready(self, data: dict[str, Any]) -> None: state = self.state - state.ssrc = data['ssrc'] - state.voice_port = data['port'] - state.endpoint_ip = data['ip'] + state.ssrc = data["ssrc"] + state.voice_port = data["port"] + state.endpoint_ip = data["ip"] _log.debug( - f'Connecting to {state.endpoint_ip} (port {state.voice_port}).', + f"Connecting to {state.endpoint_ip} (port {state.voice_port}).", ) await self.loop.sock_connect( @@ -185,11 +190,13 @@ async def ready(self, data: dict[str, Any]) -> None: async def get_ip(self) -> tuple[str, int]: state = self.state packet = bytearray(75) - struct.pack_into('>H', packet, 0, 1) # 1 = Send - struct.pack_into('>H', packet, 2, 70) # 70 = Length - struct.pack_into('>I', packet, 4, state.ssrc) + struct.pack_into(">H", packet, 0, 1) # 1 = Send + struct.pack_into(">H", packet, 2, 70) # 70 = Length + struct.pack_into(">I", packet, 4, state.ssrc) - _log.debug(f'Sending IP discovery packet for voice in channel {state.channel_id} (guild {state.guild_id})') + _log.debug( + f"Sending IP discovery packet for voice in channel {state.channel_id} (guild {state.guild_id})" + ) await self.loop.sock_sendall(state.socket, packet) fut: asyncio.Future[bytes] = self.loop.create_future() @@ -202,31 +209,33 @@ def get_ip_packet(data: bytes) -> None: state.add_socket_listener(get_ip_packet) recv = await fut - _log.debug('Received IP discovery packet with data %s', recv) + _log.debug("Received IP discovery packet with data %s", recv) ip_start = 8 ip_end = recv.index(0, ip_start) - ip 
= recv[ip_start:ip_end].decode('ascii') - port = struct.unpack_from('>H', recv, len(recv) - 2)[0] - _log.debug('Detected IP %s with port %s', ip, port) + ip = recv[ip_start:ip_end].decode("ascii") + port = struct.unpack_from(">H", recv, len(recv) - 2)[0] + _log.debug("Detected IP %s with port %s", ip, port) return ip, port @property def latency(self) -> float: heartbeat = self._keep_alive - return float('inf') if heartbeat is None else heartbeat.latency + return float("inf") if heartbeat is None else heartbeat.latency @property def average_latency(self) -> float: heartbeat = self._keep_alive if heartbeat is None or not heartbeat.recent_ack_latencies: - return float('inf') + return float("inf") return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) async def load_secret_key(self, data: dict[str, Any]) -> None: - _log.debug(f'Received secret key for voice connection in channel {self.state.channel_id} (guild {self.state.guild_id})') - self.secret_key = self.state.secret_key = data['secret_key'] + _log.debug( + f"Received secret key for voice connection in channel {self.state.channel_id} (guild {self.state.guild_id})" + ) + self.secret_key = self.state.secret_key = data["secret_key"] await self.speak(SpeakingState.none) async def poll_event(self) -> None: @@ -235,10 +244,14 @@ async def poll_event(self) -> None: if msg.type is aiohttp.WSMsgType.TEXT: await self.received_message(utils._from_json(msg.data)) elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received %s', msg) + _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data - elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): - _log.debug('Received %s', msg) + elif msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) async def close(self, code: int = 
1000) -> None: diff --git a/discord/voice/state.py b/discord/voice/state.py index 356417d448..9afe9c86e1 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -25,26 +25,27 @@ from __future__ import annotations -import asyncio -from collections.abc import Callable, Coroutine import logging import select import threading +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from .gateway import VoiceWebSocket from .client import VoiceClient + from .gateway import VoiceWebSocket SocketReaderCallback = Callable[[bytes], Any] _log = logging.getLogger(__name__) class SocketEventReader(threading.Thread): - def __init__(self, state: VoiceConnectionState, *, start_paused: bool = True) -> None: + def __init__( + self, state: VoiceConnectionState, *, start_paused: bool = True + ) -> None: super().__init__( daemon=True, - name=f'voice-socket-reader:{id(self):#x}', + name=f"voice-socket-reader:{id(self):#x}", ) self.state: VoiceConnectionState = state self.start_paused: bool = start_paused @@ -98,7 +99,7 @@ def run(self) -> None: try: self._do_run() except Exception: - _log.exception('Error while starting socket event reader at %s', self) + _log.exception("Error while starting socket event reader at %s", self) finally: self.stop() self._running.clear() @@ -114,7 +115,7 @@ def _do_run(self) -> None: readable, _, _ = select.select([self.state.socket], [], [], 30) except (ValueError, TypeError, OSError) as e: _log.debug( - 'Select error handling socket in reader, this should be safe to ignore: %s: %s', + "Select error handling socket in reader, this should be safe to ignore: %s: %s", e.__class__.__name__, e, ) @@ -126,14 +127,18 @@ def _do_run(self) -> None: try: data = self.state.socket.recv(2048) except OSError: - _log.debug('Error reading from socket in %s, this should be safe to ignore.', self, exc_info=True) + _log.debug( + "Error reading from socket in %s, this should be safe to ignore.", + self, + 
exc_info=True, + ) else: for cb in self._callbacks: try: cb(data) except Exception: _log.exception( - 'Error while calling %s in %s', + "Error while calling %s in %s", cb, self, ) @@ -144,7 +149,9 @@ def __init__( self, client: VoiceClient, *, - hook: Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None = None, + hook: ( + Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None + ) = None, ) -> None: ... # TODO: finish this From 8a6ae39b33cbcaffe9666df02bd25f91d0309f4f Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Tue, 26 Aug 2025 16:38:36 +0200 Subject: [PATCH 05/40] finish voice things, start voice-recv things --- discord/opus.py | 51 +++- discord/sinks/core.py | 23 +- discord/state.py | 6 +- discord/voice/client.py | 563 ++++++++++++++++++++++++++++++++++- discord/voice/enums.py | 14 + discord/voice/errors.py | 93 ------ discord/voice/gateway.py | 44 ++- discord/voice/recorder.py | 27 ++ discord/voice/state.py | 611 +++++++++++++++++++++++++++++++++++++- 9 files changed, 1311 insertions(+), 121 deletions(-) delete mode 100644 discord/voice/errors.py create mode 100644 discord/voice/recorder.py diff --git a/discord/opus.py b/discord/opus.py index 6ea6f84308..0160f85ef5 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -43,6 +43,7 @@ if TYPE_CHECKING: T = TypeVar("T") + APPLICATION_CTL = Literal['audio', 'voip', 'lowdelay'] BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] SIGNAL_CTL = Literal["auto", "voice", "music"] @@ -61,6 +62,12 @@ class SignalCtl(TypedDict): music: int +class ApplicationCtl(TypedDict): + audio: int + voip: int + lowdelay: int + + __all__ = ( "Encoder", "Decoder", @@ -74,6 +81,7 @@ class SignalCtl(TypedDict): c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_float_ptr = ctypes.POINTER(ctypes.c_float) +OPUS_SILENCE = b'\xf8\xff\xfe' _lib = None @@ -96,9 +104,11 @@ class 
DecoderStruct(ctypes.Structure): BAD_ARG = -1 # Encoder CTLs -APPLICATION_AUDIO = 2049 -APPLICATION_VOIP = 2048 -APPLICATION_LOWDELAY = 2051 +application_ctl: ApplicationCtl = { + 'audio': 2049, + 'lowdelay': 2051, + 'voip': 2048, +} CTL_SET_BITRATE = 4002 CTL_SET_BANDWIDTH = 4008 @@ -365,16 +375,37 @@ def get_opus_version() -> str: class Encoder(_OpusStruct): - def __init__(self, application: int = APPLICATION_AUDIO): + def __init__( + self, + *, + application: APPLICATION_CTL = 'audio', + bitrate: int = 128, + fec: bool = True, + expected_packet_loss: float = 0.15, + bandwidth: BAND_CTL = 'full', + signal_type: SIGNAL_TL = 'auto', + ) -> None: + if application not in application_ctl: + raise ValueError( + 'invalid application ctl type provided' + ) + if not 16 <= bitrate <= 512: + raise ValueError('bitrate must be between 16 and 512, both included') + if not 0 < expected_packet_loss <= 1: + raise ValueError( + 'expected_packet_loss must be between 0 and 1, including 1', + ) + _OpusStruct.get_opus_version() - self.application: int = application + self.application: int = application_ctl[application] self._state: EncoderStruct = self._create_state() - self.set_bitrate(128) - self.set_fec(True) - self.set_expected_packet_loss_percent(0.15) - self.set_bandwidth("full") - self.set_signal_type("auto") + + self.set_bitrate(bitrate) + self.set_fec(fec) + self.set_expected_packet_loss_percent(expected_packet_loss) + self.set_bandwidth(bandwidth) + self.set_signal_type(signal_type) def __del__(self) -> None: if hasattr(self, "_state"): diff --git a/discord/sinks/core.py b/discord/sinks/core.py index 90fdbf7d19..0d7b4f5483 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -37,7 +37,7 @@ from .errors import SinkException if TYPE_CHECKING: - from ..voice_client import VoiceClient + from ..voice.client import VoiceClient __all__ = ( "Filters", @@ -103,9 +103,14 @@ class RawData: .. 
versionadded:: 2.0 """ - def __init__(self, data, client): - self.data = bytearray(data) - self.client = client + if TYPE_CHECKING: + sequence: int + timestamp: int + ssrc: int + + def __init__(self, data: bytes, client: VoiceClient): + self.data: bytearray = bytearray(data) + self.client: VoiceClient = client unpacker = struct.Struct(">xxHII") self.sequence, self.timestamp, self.ssrc = unpacker.unpack_from(self.data[:12]) @@ -120,16 +125,16 @@ def __init__(self, data, client): else: cutoff = 12 - self.header = data[:cutoff] + self.header: bytes = data[:cutoff] self.data = self.data[cutoff:] - self.decrypted_data = getattr(self.client, f"_decrypt_{self.client.mode}")( + self.decrypted_data: bytes = getattr(self.client, f"_decrypt_{self.client.mode}")( self.header, self.data ) - self.decoded_data = None + self.decoded_data: bytes | None = None - self.user_id = None - self.receive_time = time.perf_counter() + self.user_id: int | None = None + self.receive_time: float = time.perf_counter() class AudioData: diff --git a/discord/state.py b/discord/state.py index 52a9cc0989..73917a6e59 100644 --- a/discord/state.py +++ b/discord/state.py @@ -1853,7 +1853,8 @@ def parse_voice_state_update(self, data) -> None: if int(data["user_id"]) == self_id: voice = self._get_voice_client(guild.id) if voice is not None: - coro = voice.on_voice_state_update(data) + payload = RawVoiceStateUpdateEvent(data=data, state=self) + coro = voice.on_voice_state_update(payload) asyncio.create_task( logging_coroutine( coro, info="Voice Protocol voice state update handler" @@ -1890,8 +1891,9 @@ def parse_voice_server_update(self, data) -> None: key_id = int(data["channel_id"]) vc = self._get_voice_client(key_id) + payload = RawVoiceServerUpdateEvent(data=data, state=self) if vc is not None: - coro = vc.on_voice_server_update(data) + coro = vc.on_voice_server_update(payload) asyncio.create_task( logging_coroutine( coro, info="Voice Protocol voice server update handler" diff --git 
a/discord/voice/client.py b/discord/voice/client.py index 5d650ae7ac..157e58a01e 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -25,5 +25,564 @@ from __future__ import annotations -# rn this is for typing, will be moved here in some point in the future -from discord.voice_client import VoiceClient +import asyncio +from collections.abc import Callable, Coroutine +import struct +from typing import TYPE_CHECKING, Any, Literal, overload + +from discord import opus +from discord.errors import ClientException +from discord.sinks.errors import RecordingException +from discord.utils import MISSING +from discord.sinks import RawData, Sink + +from ._types import VoiceProtocol +from .state import VoiceConnectionState +from .recorder import Recorder +from .source import AudioSource +from .player import AudioPlayer + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from discord import abc + from discord.client import Client + from discord.guild import Guild, VocalGuildChannel + from discord.state import ConnectionState + from discord.user import ClientUser + from discord.raw_models import ( + RawVoiceStateUpdateEvent, + RawVoiceServerUpdateEvent, + ) + from discord.types.voice import SupportedModes + from discord.opus import Encoder, APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Decoder + + from .gateway import VoiceWebSocket + + AfterCallback = Callable[[Exception | None], Any] + P = ParamSpec('P') + +has_nacl: bool + +try: + import nacl.secret + import nacl.utils + has_nacl = True +except ImportError: + has_nacl = False + + +class VoiceClient(VoiceProtocol): + """Represents a Discord voice connection. + + You do not create these, you typically get them from e.g. + :meth:`VoiceChannel.connect`. + + Attributes + ---------- + session_id: :class:`str` + The voice connection session ID. + token: :class:`str` + The voice connection token. + endpoint: :class:`str` + The endpoint we are connecting to. 
+ channel: Union[:class:`VoiceChannel`, :class:`StageChannel`] + The channel we are connected to. + + Warning + ------- + In order to use PCM based AudioSources, you must have the opus library + installed on your system and loaded through :func:`opus.load_opus`. + Otherwise, your AudioSources must be opus encoded (e.g. using :class:`FFmpegOpusAudio`) + or the library will not be able ot transmit audio. + """ + + channel: VocalGuildChannel + + def __init__(self, client: Client, channel: abc.Connectable) -> None: + if not has_nacl: + raise RuntimeError( + 'PyNaCl library is needed in order to use voice related features, ' + 'you can run "pip install py-cord[voice]" to install all voice-related ' + 'dependencies.' + ) + + super().__init__(client, channel) + state = client._connection + + self.server_id: int = MISSING + self.socket = MISSING + self.loop: asyncio.AbstractEventLoop = state.loop + self._state: ConnectionState = state + + self.sequence: int = 0 + self.timestamp: int = 0 + self._player: AudioPlayer | None = None + self._player_future: asyncio.Future[None] | None = None + self.encoder: Encoder = MISSING + self.decoder: Decoder = MISSING + self._incr_nonce: int = 0 + + self._connection: VoiceConnectionState = self.create_connection_state() + + # voice recv things + self._recorder: Recorder | None = None + + warn_nacl: bool = not has_nacl + supported_modes: tuple[SupportedModes, ...] = ( + 'aead_xchacha20_poly1305_rtpsize', + 'xsalsa20_poly1305_lite', + 'xsalsa20_poly1305_suffix', + 'xsalsa20_poly1305', + ) + + @property + def guild(self) -> Guild: + """Returns the guild the channel we're connecting to is bound to.""" + return self.channel.guild + + @property + def user(self) -> ClientUser: + """The user connected to voice (i.e. 
ourselves)""" + return self._state.user # type: ignore + + @property + def session_id(self) -> str | None: + return self._connection.session_id + + @property + def token(self) -> str | None: + return self._connection.token + + @property + def endpoint(self) -> str | None: + return self._connection.endpoint + + @property + def ssrc(self) -> int: + return self._connection.ssrc + + @property + def mode(self) -> SupportedModes: + return self._connection.mode + + @property + def secret_key(self) -> list[int]: + return self._connection.secret_key + + @property + def ws(self) -> VoiceWebSocket: + return self._connection.ws + + @property + def timeout(self) -> float: + return self._connection.timeout + + def checked_add(self, attr: str, value: int, limit: int) -> None: + val = getattr(self, attr) + if val + value > limit: + setattr(self, attr, 0) + else: + setattr(self, attr, val + value) + + def create_connection_state(self) -> VoiceConnectionState: + return VoiceConnectionState(self) + + async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: + await self._connection.voice_state_update(data) + + async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: + await self._connection.voice_server_update(data) + + async def connect( + self, + *, + reconnect: bool, + timeout: float, + self_deaf: bool = False, + self_mute: bool = False, + ) -> None: + await self._connection.connect( + reconnect=reconnect, + timeout=timeout, + self_deaf=self_deaf, + self_mute=self_mute, + resume=False, + ) + + def wait_until_connected(self, timeout: float | None = 30.0) -> bool: + self._connection.wait_for(timeout=timeout) + return self._connection.is_connected() + + @property + def latency(self) -> float: + """Latency between a HEARTBEAT and a HEARBEAT_ACK in seconds. + + This chould be referred to as the Discord Voice WebSocket latency and is + and analogue of user's voice latencies as seen in the Discord client. + + .. 
versionadded:: 1.4 + """ + ws = self.ws + return float('inf') if not ws else ws.latency + + @property + def average_latency(self) -> float: + """Average of most recent 20 HEARBEAT latencies in seconds. + + .. versionadded:: 1.4 + """ + ws = self.ws + return float('inf') if not ws else ws.average_latency + + async def disconnect(self, *, force: bool = False) -> None: + """|coro| + + Disconnects this voice client from voice. + """ + + self.stop() + await self._connection.disconnect(force=force, wait=True) + self.cleanup() + + async def move_to(self, channel: abc.Snowflake | None, *, timeout: float | None = 30.0) -> None: + """|coro| + + moves you to a different voice channel. + + Parameters + ---------- + channel: Optional[:class:`abc.Snowflake`] + The channel to move to. If this is ``None``, it is an equivalent of calling :meth:`.disconnect`. + timeout: Optional[:class:`float`] + The maximum time in seconds to wait for the channel move to be completed, defaults to 30. + If it is ``None``, then there is no timeout. + + Raises + ------ + asyncio.TimeoutError + Waiting for channel move timed out. 
+ """ + await self._connection.move_to(channel, timeout) + + def is_connected(self) -> bool: + """Whether the voice client is connected to voice.""" + return self._connection.is_connected() + + def is_playing(self) -> bool: + """Indicates if we're playing audio.""" + return self._player is not None and self._player.is_playing() + + def is_paused(self) -> bool: + """Indicates if we're playing audio, but if we're paused.""" + return self._player is not None and self._player.is_paused() + + # audio related + + def _get_voice_packet(self, data: Any) -> bytes: + header = bytearray(12) + + # formulate rtp header + header[0] = 0x80 + header[1] = 0x78 + struct.pack_into('>H', header, 2, self.sequence) + struct.pack_into('>I', header, 4, self.timestamp) + struct.pack_into('>I', header, 8, self.ssrc) + + encrypt_packet = getattr(self, f'_encrypt_{self.mode}') + return encrypt_packet(header, data) + + def _encrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + nonce = bytearray(24) + nonce[:12] = header + return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + + def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE) + return header + box.encrypt(bytes(data), nonce).ciphertext + nonce + + def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + nonce = bytearray(24) + nonce[:4] = struct.pack('>I', self._incr_nonce) + self.checked_add('_incr_nonce', 1, 4294967295) + + return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] + + def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data: Any) -> bytes: + box = nacl.secret.Aead(bytes(self.secret_key)) + nonce = bytearray(24) + nonce[:4] = struct.pack('>I', 
self._incr_nonce) + self.checked_add('_incr_nonce', 1, 4294967295) + return header + box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext + nonce[:4] + + def _decrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + + nonce = bytearray(24) + nonce[:12] = header + + return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) + + def _decrypt_xsalsa20_poly1305_suffix(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + + nonce_size = nacl.secret.SecretBox.NONCE_SIZE + nonce = data[-nonce_size:] + + return self.strip_header_ext(box.decrypt(bytes(data[:-nonce_size]), nonce)) + + def _decrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes: + # deprecated + box = nacl.secret.SecretBox(bytes(self.secret_key)) + + nonce = bytearray(24) + nonce[:4] = data[-4:] + data = data[:-4] + + return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) + + def _decrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data: Any) -> bytes: + box = nacl.secret.Aead(bytes(self.secret_key)) + + nonce = bytearray(24) + nonce[:4] = data[-4:] + data = data[:-4] + + return self.strip_header_ext( + box.decrypt(bytes(data), bytes(header), bytes(nonce)) + ) + + @staticmethod + def strip_header_ext(data: bytes) -> bytes: + if len(data) > 4 and data[0] == 0xBE and data[1] == 0xDE: + _, length = struct.unpack_from('>HH', data) + offset = 4 + length * 4 + data = data[offset:] + return data + + @overload + def play( + self, + source: AudioSource, + *, + after: AfterCallback | None = ..., + application: APPLICATION_CTL = ..., + bitrate: int = ..., + fec: bool = ..., + expected_packet_loss: float = ..., + bandwidth: BAND_CTL = ..., + signal_type: SIGNAL_CTL = ..., + wait_finish: Literal[False] = ..., + ) -> None: ... 
+ + @overload + def play( + self, + source: AudioSource, + *, + after: AfterCallback = ..., + application: APPLICATION_CTL = ..., + bitrate: int = ..., + fec: bool = ..., + expected_packet_loss: float = ..., + bandwidth: BAND_CTL = ..., + signal_type: SIGNAL_CTL = ..., + wait_finish: Literal[True], + ) -> asyncio.Future[None]: ... + + def play( + self, + source: AudioSource, + *, + after: AfterCallback | None = None, + application: APPLICATION_CTL = 'audio', + bitrate: int = 128, + fec: bool = True, + expected_packet_loss: float = 0.15, + bandwidth: BAND_CTL = 'full', + signal_type: SIGNAL_CTL = 'auto', + wait_finish: bool = False, + ) -> None | asyncio.Future[None]: + """Plays an :class:`AudioSource`. + + The finalizer, ``after`` is called after the source has been exhausted + or an error occurred. + + IF an error happens while the audio player is running, the exception is + caught and the audio player is then stopped. If no after callback is passed, + any caught exception will be displayed as if it were raised. + + Parameters + ---------- + source: :class:`AudioSource` + The audio source we're reading from. + after: Callable[[Optional[:class:`Exception`]], Any] + The finalizer that is called after the stream is exhausted. + This function must have a single parameter, ``error``, that + denotes an optional exception that was raised during playing. + application: :class:`str` + The intended application encoder application type. Must be one of + ``audio``, ``voip``, or ``lowdelay``. Defaults to ``audio``. + bitrate: :class:`int` + The encoder's bitrate. Must be between ``16`` and ``512``. Defaults + to ``128``. + fec: :class:`bool` + Configures the encoder's use of inband forward error correction (fec). + Defaults to ``True``. + expected_packet_loss: :class:`float` + How much packet loss percentage is expected from the encoder. This requires ``fec`` + to be set to ``True``. Defaults to ``0.15``. + bandwidth: :class:`str` + The encoder's bandpass. 
Must be one of ``narrow``, ``medium``, ``wide``, + ``superwide``, or ``full``. Defaults to ``full``. + signal_type: :class:`str` + The type of signal being encoded. Must be one of ``auto``, ``voice``, ``music``. + Defaults to ``auto``. + wait_finish: :class:`bool` + If ``True``, then an awaitable is returned that waits for the audio source to be + exhausted, and will return an optional exception that could have been raised. + + If ``False``, ``None`` is returned and the function does not block. + + .. versionadded:: 2.5 + + Raises + ------ + ClientException + Already playing audio, or not connected to voice. + TypeError + Source is not a :class:`AudioSource`, or after is not a callable. + OpusNotLoaded + Source is not opus encoded and opus is not loaded. + """ + + if not self.is_connected(): + raise ClientException('Not connected to voice') + if self.is_playing(): + raise ClientException('Already playing audio') + if not isinstance(source, AudioSource): + raise TypeError( + f'Source must be an AudioSource, not {source.__class__.__name__}', + ) + if not self.encoder and not source.is_opus(): + self.encoder = opus.Encoder( + application=application, + bitrate=bitrate, + fec=fec, + expected_packet_loss=expected_packet_loss, + bandwidth=bandwidth, + signal_type=signal_type, + ) + + if wait_finish: + self._player_future = future = self.loop.create_future() + after_callback = after + + def _after(exc: Exception | None) -> None: + if callable(after_callback): + after_callback(exc) + future.set_result(exc) + + after = _after + + self._player = AudioPlayer(source, self, after=after) + self._player.start() + return future if wait_finish else None # ``future`` is only bound when wait_finish is True + + def stop(self) -> None: + """Stops playing audio, if applicable.""" + if self._player: + self._player.stop() + if self._player_future and not self._player_future.done(): # a finished future cannot be resolved again + for cb, _ in self._player_future._callbacks: + self._player_future.remove_done_callback(cb) + self._player_future.set_result(None) + + self._player = None + self._player_future = None + + def unpack_audio(self, data: 
bytes) -> bytes | None: + """Takes an audio packet received from Discord and decodes it into PCM Audio data. + If there are no users talking in the channel, ``None`` will be returned. + + You must be connected to receive audio. + + .. versionadded:: 2.0 + + Parameters + ---------- + data: :class:`bytes` + Bytes received by Discord via de UDP connection used for sending and receiving voice data. + """ + + if not len(data) > 2: + return None + + if data[1] != 0x78: + # We Should Ignore Any Payload Types We Do Not Understand + # Ref RFC 3550 5.1 payload type + # At Some Point We Noted That We Should Ignore Only Types 200 - 204 inclusive. + # They Were Marked As RTCP: Provides Information About The Connection + # This Was Too Broad Of A Whitelist, It Is Unclear If This Is Too Narrow Of A Whitelist + return None + if self.paused: + return None + + raw = RawData(data, self) + + if raw.decrypted_data == opus.OPUS_SILENCE: # silenece frame + return None + + return self.decoder.decode(raw) + + def start_recording( + self, + sink: Sink, + callback: Callable[P, Coroutine[Any, Any, Any]], + sync_start: bool = False, + *callback_args: P.args, + **callback_kwargs: P.kwargs, + ): + r"""Start recording audio from the current voice channel. This function uses + a thread so the current code line will not be stopped. You must be in a voice + channel to use this, and must not be already recording. + + .. versionadded:: 2.0 + + Parameters + ---------- + sink: :class:`~discord.Sink` + A Sink which will "store" all the audio data. + callback: :ref:`coroutine ` + A function which is called after the bot has stopped recording. + sync_start: :class:`bool` + If ``True``, the recordings of subsequent users will start with silence. This + is useful for recording audio just as it was heard. + \*callback_args + Arguments that will be passed to the callback function. + \*\*callback_kwargs + Keyword arguments that will be passed to the callback function. 
+ + Raises + ------ + RecordingException + Not connected to a voice channel, or you are already recording, or you + did not provide a Sink object. + """ + + if not self.is_connected(): + raise RecordingException('Not connected to a voice channel') + if self.recording: + raise RecordingException('You are already recording') + if not isinstance(sink, Sink): + raise RecordingException(f'Expected a Sink object, got {sink.__class__.__name__}') + + self._recording_handler.empty() diff --git a/discord/voice/enums.py b/discord/voice/enums.py index 2e24f03b58..0422600911 100644 --- a/discord/voice/enums.py +++ b/discord/voice/enums.py @@ -52,3 +52,17 @@ def __eq__(self, other: object) -> bool: def __int__(self) -> int: return self.value + + +class ConnectionFlowState(Enum): + # fmt: off + disconnected = 0 + set_guild_voice_state = 1 + got_voice_state_update = 2 + got_voice_server_update = 3 + got_both_voice_updates = 4 + websocket_connected = 5 + got_websocket_ready = 6 + got_ip_discovery = 7 + connected = 8 + # fmt: on diff --git a/discord/voice/errors.py b/discord/voice/errors.py deleted file mode 100644 index be35852c22..0000000000 --- a/discord/voice/errors.py +++ /dev/null @@ -1,93 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -from aiohttp import ClientWebSocketResponse - -from discord.errors import ClientException - - -class VoiceConnectionClosed(ClientException): - """Exception that's raised when a voice websocket connection - is closed for reasons that could not be handled internally. - - Attributes - ---------- - code: :class:`int` - The close code of the websocket. - reason: :class:`str` - The reason provided for the closure. - guild_id: :class:`int` - The guild ID the client was connected to. - channel_id: :class:`int` - The channel ID the client was connected to. - """ - - __slots__ = ( - "code", - "reason", - "channel_id", - "guild_id", - ) - - def __init__( - self, - socket: ClientWebSocketResponse, - channel_id: int, - guild_id: int, - *, - reason: str | None = None, - code: int | None = None, - ) -> None: - self.code: int = code or socket.close_code or -1 - self.reason: str = reason if reason is not None else "" - self.channel_id: int = channel_id - self.guild_id: int = guild_id - super().__init__( - f"The voice connection on {self.channel_id} (guild {self.guild_id}) was closed with {self.code}", - ) - - -class VoiceGuildMismatch(ClientException): - """Exception that's raised when, while connecting to a voice channel, the data - the library has differs from the one discord sends. - - Attributes - ---------- - expected: :class:`int` - The expected guild ID. This is the one the library has. - received: :class:`int` - The received guild ID. This is the one sent by discord. 
- """ - - __slots__ = ( - "expected", - "received", - ) - - def __init__(self, expt: int, recv: int) -> None: - self.expected: int = expt - self.received: int = recv diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index 7f1fc75ec0..bc76a10c74 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -30,6 +30,7 @@ from collections.abc import Callable, Coroutine import logging import struct +import threading import time from typing import TYPE_CHECKING, Any @@ -43,6 +44,8 @@ from .enums import OpCodes if TYPE_CHECKING: + from typing_extensions import Self + from .state import VoiceConnectionState _log = logging.getLogger(__name__) @@ -106,7 +109,8 @@ def __init__( self.seq_ack: int = -1 self.session_id: str | None = None self.state: VoiceConnectionState = state - self.ssrc_map: dict[str] + self.ssrc_map: dict[str, dict[str, Any]] = {} + self.token: str | None = None if hook: self._hook = hook # type: ignore @@ -153,7 +157,7 @@ async def received_message(self, msg: Any, /): 'successfully RESUMED.', ) elif op == OpCodes.session_description: - self.state.update_session_description(data) + self.state.mode = data['mode'] elif op == OpCodes.hello: interval = data['heartbeat_interval'] / 1000.0 self._keep_alive = KeepAliveHandler( @@ -247,3 +251,39 @@ async def close(self, code: int = 1000) -> None: self._close_code = code await self.ws.close(code=self._close_code) + + async def speak(self, state: SpeakingState = SpeakingState.voice) -> None: + await self.send_as_json( + { + 'op': int(OpCodes.speaking), + 'd': { + 'speaking': int(state), + 'delay': 0, + }, + }, + ) + + @classmethod + async def from_state( + cls, + state: VoiceConnectionState, + *, + resume: bool = False, + hook: Callable[..., Coroutine[Any, Any, Any]] | None = None, + seq_ack: int = -1, + ) -> Self: + gateway = f'wss://{state.endpoint}/?v=8' + client = state.client + http = client._state.http + socket = await http.ws_connect(gateway, compress=15) + ws = cls(socket, 
loop=client.loop, hook=hook, state=state) + ws.gateway = gateway + ws.seq_ack = seq_ack + ws._max_heartbeat_timeout = 60.0 + ws.thread_id = threading.get_ident() + + if resume: + await ws.resume() + else: + await ws.identify() + return ws diff --git a/discord/voice/recorder.py b/discord/voice/recorder.py new file mode 100644 index 0000000000..e01d9f8694 --- /dev/null +++ b/discord/voice/recorder.py @@ -0,0 +1,27 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
+""" +from __future__ import annotations + +# TODO: finish this diff --git a/discord/voice/state.py b/discord/voice/state.py index 356417d448..6b6e1d5aa1 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -29,13 +29,30 @@ from collections.abc import Callable, Coroutine import logging import select +import socket import threading from typing import TYPE_CHECKING, Any +from discord import utils +from discord.backoff import ExponentialBackoff +from discord.errors import ConnectionClosed + +from .gateway import VoiceWebSocket +from .enums import ConnectionFlowState + if TYPE_CHECKING: - from .gateway import VoiceWebSocket + from discord import abc + from discord.user import ClientUser + from discord.guild import Guild + from discord.member import VoiceState + from discord.types.voice import ( + SupportedModes, + VoiceServerUpdate as VoiceServerUpdatePayload, + ) + from discord.raw_models import RawVoiceStateUpdateEvent, RawVoiceServerUpdateEvent from .client import VoiceClient +MISSING = utils.MISSING SocketReaderCallback = Callable[[bytes], Any] _log = logging.getLogger(__name__) @@ -146,5 +163,593 @@ def __init__( *, hook: Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None = None, ) -> None: - ... 
- # TODO: finish this + self.client: VoiceClient = client + self.hook = hook + + self.timeout: float = 30.0 + self.reconnect: bool = True + self.self_deaf: bool = False + self.self_mute: bool = False + self.endpoint: str | None = None + self.endpoint_ip: str | None = None + self.server_id: int | None = None + self.ip: str | None = None + self.port: int | None = None + self.voice_port: int | None = None + self.secret_key: list[int] = MISSING + self.ssrc: int = MISSING + self.mode: SupportedModes = MISSING + self.socket: socket.socket = MISSING + self.ws: VoiceWebSocket = MISSING + + self._state: ConnectionFlowState = ConnectionFlowState.disconnected + self._expecting_disconnect: bool = False + self._connected = threading.Event() + self._state_event = asyncio.Event() + self._disconnected = asyncio.Event() + self._runner: asyncio.Task[None] | None = None + self._connector: asyncio.Task[None] | None = None + self._socket_reader = SocketEventReader(self) + self._socket_reader.start() + + @property + def state(self) -> ConnectionFlowState: + return self._state + + @state.setter + def state(self, state: ConnectionFlowState) -> None: + if state is not self._state: + _log.debug('State changed from %s to %s', self._state.name, state.name) + + self._state = state + self._state_event.set() + self._state_event.clear() + + if state is ConnectionFlowState.connected: + self._connected.set() + else: + self._connected.clear() + + @property + def guild(self) -> Guild: + return self.client.guild + + @property + def user(self) -> ClientUser: + return self.client.user + + @property + def channel_id(self) -> int | None: + return self.client.channel and self.client.channel.id + + @property + def guild_id(self) -> int: + return self.guild.id + + @property + def supported_modes(self) -> tuple[SupportedModes, ...]: + return self.client.supported_modes + + @property + def self_voice_state(self) -> VoiceState | None: + return self.guild.me.voice + + @property + def token(self) -> str | None: + 
return self.ws.token + + @token.setter + def token(self, token: str | None) -> None: + self.ws.token = token + + @property + def session_id(self) -> str | None: + return self.ws.session_id + + @session_id.setter + def session_id(self, value: str | None) -> None: + self.ws.session_id = value + + def is_connected(self) -> bool: + return self.state is ConnectionFlowState.connected + + def _inside_runner(self) -> bool: + return self._runner is not None and asyncio.current_task() == self._runner + + async def voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: + channel_id = data.channel_id + + if channel_id is None: + self._disconnected.set() + + if self._expecting_disconnect: + self._expecting_disconnect = False + else: + _log.debug('We have been disconnected from voice') + await self.disconnect() + return + + self.ws.session_id = data['session_id'] + + if self.state in ( + ConnectionFlowState.set_guild_voice_state, + ConnectionFlowState.got_voice_server_update, + ): + if self.state is ConnectionFlowState.set_guild_voice_state: + self.state = ConnectionFlowState.got_voice_state_update + + if channel_id != self.client.channel.id: + # moved from channel + self._update_voice_channel(channel_id) + else: + self.state = ConnectionFlowState.got_both_voice_updates + return + + if self.state is ConnectionFlowState.connected: + self._update_voice_channel(channel_id) + + elif self.state is not ConnectionFlowState.disconnected: + if channel_id != self.client.channel.id: + _log.info('We were moved from the channel while connecting...') + + self._update_voice_channel(channel_id) + await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_state_update) + await self.connect( + reconnect=self.reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + wait=False, + ) + else: + _log.debug('Ignoring unexpected VOICE_STATEUPDATE event') + + async def 
voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: + previous_token = self.token + previous_server_id = self.server_id + previous_endpoint = self.endpoint + + self.token = data.token + self.server_id = data.guild_id + endpoint = data.endpoint + + if self.token is None or endpoint is None: + _log.warning( + 'Awaiting endpoint... This requires waiting. ' + 'If timeout occurred considering raising the timeout and reconnecting.' + ) + return + + # strip the prefix off since we add it later + self.endpoint = endpoint.removeprefix('wss://') + + if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_state_update): + self.endpoint_ip = MISSING + self._create_socket() + + if self.state is ConnectionFlowState.set_guild_voice_state: + self.state = ConnectionFlowState.got_voice_server_update + else: + self.state = ConnectionFlowState.got_both_voice_updates + + elif self.state is ConnectionFlowState.connected: + _log.debug('Voice server update, closing old voice websocket') + await self.ws.close(4014) # 4014 = main gw dropped + self.state = ConnectionFlowState.got_voice_server_update + + elif self.state is not ConnectionFlowState.disconnected: + if previous_token == self.token and previous_server_id == self.server_id and previous_endpoint == self.endpoint: + return + + _log.debug('Unexpected VOICE_SERVER_UPDATE event received, handling...') + + await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_server_update) + await self.connect( + reconnect=self.reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + wait=False, + ) + self._create_socket() + + async def connect( + self, + *, + reconnect: bool, + timeout: float, + self_deaf: bool, + self_mute: bool, + resume: bool, + wait: bool = True, + ) -> None: + if self._connector: + self._connector.cancel() + self._connector = None + + if self._runner: + 
self._runner.cancel() + self._runner = None + + self.timeout = timeout + self.reconnect = reconnect + self._connector = self.client.loop.create_task( + self._wrap_connect( + reconnect, + timeout, + self_deaf, + self_mute, + resume, + ), + name=f'voice-connector:{id(self):#x}', + ) + + if wait: + await self._connector + + async def _wrap_connect( + self, + reconnect: bool, + timeout: float, + self_deaf: bool, + self_mute: bool, + resume: bool, + ) -> None: + try: + await self._connect(reconnect, timeout, self_deaf, self_mute, resume) + except asyncio.CancelledError: + _log.debug('Cancelling voice connection') + await self.soft_disconnect() + raise + except asyncio.TimeoutError: + _log.info('Timed out while connecting to voice') + await self.disconnect() + raise + except Exception: + _log.exception('Error while connecting to voice... disconnecting') + await self.disconnect() + raise + + async def _inner_connect(self, reconnect: bool, self_deaf: bool, self_mute: bool, resume: bool) -> None: + for i in range(5): + _log.info('Starting voice handshake (connection attempt %s)', i + 1) + + await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute) + if self.state is ConnectionFlowState.disconnected: + self.state = ConnectionFlowState.set_guild_voice_state + + await self._wait_for_state(ConnectionFlowState.got_both_voice_updates) + + _log.info('Voice handshake complete. Endpoint found: %s', self.endpoint) + + try: + self.ws = await self._connect_websocket(resume) + await self._handshake_websocket() + break + except ConnectionClosed: + if reconnect: + wait = 1 + i * 2 + _log.exception('Failed to connect to voice... 
Retrying in %s seconds', wait) + await self.disconnect(cleanup=False) + await asyncio.sleep(wait) + continue + else: + await self.disconnect() + raise + + async def _connect(self, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool) -> None: + _log.info(f'Connecting to voice {self.client.channel.id}') + + await asyncio.wait_for( + self._inner_connect(reconnect=reconnect, self_deaf=self_deaf, self_mute=self_mute, resume=resume), + timeout=timeout, + ) + _log.info('Voice connection completed') + + if not self._runner: + self._runner = self.client.loop.create_task( + self._poll_ws(reconnect), + name=f'voice-ws-poller:{id(self):#x}', + ) + + async def disconnect(self, *, force: bool = True, cleanup: bool = True, wait: bool = False) -> None: + if not force and not self.is_connected(): + return + + try: + await self._voice_disconnect() + if self.ws: + await self.ws.close() + except Exception: + _log.debug('Ignoring exception while disconnecting from voice', exc_info=True) + finally: + self.state = ConnectionFlowState.disconnected + self._socket_reader.pause() + + if cleanup: + self._socket_reader.stop() + self.client.stop() + + self._connected.set() + self._connected.clear() + + if self.socket: + self.socket.close() + + self.ip = MISSING + self.port = MISSING + + if wait and not self._inside_runner(): + try: + await asyncio.wait_for(self._disconnected.wait(), timeout=self.timeout) + except asyncio.TimeoutError: # not an alias of the builtin TimeoutError before Python 3.11 + _log.debug('Timed out waiting for voice disconnect confirmation') + except asyncio.CancelledError: + pass + + if cleanup: + self.client.cleanup() + + async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None: + _log.debug('Soft disconnecting from voice') + + if self._runner: + self._runner.cancel() + self._runner = None + + try: + if self.ws: + await self.ws.close() + except Exception: + _log.debug('Ignoring exception while soft disconnecting from voice', exc_info=True) + finally: 
+ self.state = with_state + self._socket_reader.pause() + + if self.socket: + self.socket.close() + + self.ip = MISSING + self.port = MISSING + + async def move_to(self, channel: abc.Snowflake | None, timeout: float | None) -> None: + if channel is None: + await self.disconnect(wait=True) + return + + if self.client.channel and channel.id == self.client.channel.id: + return + + previous_state = self.state + await self._move_to(channel) + + last_state = self.state + + try: + await self._wait_for_state(ConnectionFlowState.connected, timeout=timeout) # wait_for() blocks on a threading.Event and returns a bool, which cannot be awaited + except asyncio.TimeoutError: + _log.warning('Timed out trying to move to channel %s in guild %s', channel.id, self.guild.id) + if self.state is last_state: + _log.debug('Reverting state %s to previous state: %s', last_state.name, previous_state.name) + self.state = previous_state + + def wait_for( + self, + state: ConnectionFlowState = ConnectionFlowState.connected, + timeout: float | None = None, + ) -> Any: + if state is ConnectionFlowState.connected: + return self._connected.wait(timeout) + return self._wait_for_state(state, timeout=timeout) + + def send_packet(self, packet: bytes) -> None: + self.socket.sendall(packet) + + def add_socket_listener(self, callback: SocketReaderCallback) -> None: + _log.debug('Registering a socket listener callback %s', callback) + self._socket_reader.register(callback) + + def remove_socket_listener(self, callback: SocketReaderCallback) -> None: + _log.debug('Unregistering a socket listener callback %s', callback) + self._socket_reader.unregister(callback) + + async def _wait_for_state( + self, + *states: ConnectionFlowState, + timeout: float | None = None, + ) -> None: + if not states: + raise ValueError + + while True: + if self.state in states: + return + + _, pending = await asyncio.wait( + [ + asyncio.ensure_future(self._state_event.wait()), + ], + timeout=timeout, + ) + if pending: + # if we're here, it means that the state event + # has timed out, so just raise the exception + raise asyncio.TimeoutError + + async 
def _voice_connect(self, *, self_deaf: bool = False, self_mute: bool = False) -> None: + channel = self.client.channel + await channel.guild.change_voice_state(channel=channel, self_deaf=self_deaf, self_mute=self_mute) + + async def _voice_disconnect(self) -> None: + _log.info( + 'Terminating voice handshake for channel %s (guild %s)', + self.client.channel.id, + self.client.guild.id, + ) + + self.state = ConnectionFlowState.disconnected + await self.client.channel.guild.change_voice_state(channel=None) + self._expecting_disconnect = True + self._disconnected.clear() + + async def _connect_websocket(self, resume: bool) -> VoiceWebSocket: + seq_ack = -1 + if self.ws is not MISSING: + seq_ack = self.ws.seq_ack + ws = await VoiceWebSocket.from_state(self, resume=resume, hook=self.hook, seq_ack=seq_ack) + self.state = ConnectionFlowState.websocket_connected + return ws + + async def _handshake_websocket(self) -> None: + while not self.ip: + await self.ws.poll_event() + + self.state = ConnectionFlowState.got_ip_discovery + while self.ws.secret_key is None: + await self.ws.poll_event() + + self.state = ConnectionFlowState.connected + + def _create_socket(self) -> None: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.setblocking(False) + self._socket_reader.resume() + + async def _poll_ws(self, reconnect: bool) -> None: + backoff = ExponentialBackoff() + + while True: + try: + await self.ws.poll_event() + except asyncio.CancelledError: + return + except (ConnectionClosed, asyncio.TimeoutError) as exc: + if isinstance(exc, ConnectionClosed): + # 1000 - normal closure - not resumable + # 4014 - externally disconnected - not resumable + # 4015 - voice server crashed - resumable + # 4021 - ratelimited, not reconnect - not resumable + # 4022 - call terminated, similar to 4014 - not resumable + + if exc.code == 1000: + if not self._expecting_disconnect: + _log.info('Disconnecting from voice manually, close code %d', exc.code) + await 
self.disconnect() + break + elif exc.code in (4014, 4022): + if self._disconnected.is_set(): + _log.info('Disconnectinf from voice by Discord, close code %d', exc.code) + await self.disconnect() + break + + _log.info('Disconnecting from voice by force... potentially reconnecting...') + successful = await self._potential_reconnect() + if not successful: + _log.info('Reconnect was unsuccessful, disconnecting from voice normally') + if self.state is not ConnectionFlowState.disconnected: + await self.disconnect() + break + else: + # we have successfully resumed so just keep polling events + continue + elif exc.code == 4021: + _log.warning( + 'We are being rate limited while attempting to connect to voice. Disconnecting...', + ) + if self.state is not ConnectionFlowState.disconnected: + await self.disconnect() + break + elif exc.code == 4015: + _log.info( + 'Disconnected from voice due to a Discord-side issue, attempting to reconnect and resume...', + ) + + try: + await self._connect( + reconnect=reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=True, + ) + except asyncio.TimeoutError: + _log.info('Could not resume the voice connection... Disconnecting...') + if self.state is not ConnectionFlowState.disconnected: + await self.disconnect() + break + except Exception: + _log.exception( + 'An exception was raised while attempting a reconnect and resume... Disconnecting...', + exc_info=True, + ) + if self.state is not ConnectionFlowState.disconnected: + await self.disconnect() + break + else: + _log.info('Successfully reconnected and resume the voice connection') + continue + else: + _log.debug( + 'Not handling close code %s (%s)', + exc.code, + exc.reason or 'No reason was provided', + ) + + if not reconnect: + await self.disconnect() + raise + + retry = backoff.delay() + _log.exception( + 'Disconnected from voice... 
Reconnecting in %.2fs', + retry, + ) + await asyncio.sleep(retry) + await self.disconnect(cleanup=False) + + try: + await self._connect( + reconnect=reconnect, + timeout=self.timeout, + self_deaf=(self.self_voice_state or self).self_deaf, + self_mute=(self.self_voice_state or self).self_mute, + resume=False, + ) + except asyncio.TimeoutError: + _log.warning('Could not connect to voice... Retrying...') + continue + + async def _potential_reconnect(self) -> bool: + try: + await self._wait_for_state( + ConnectionFlowState.got_voice_server_update, + ConnectionFlowState.got_both_voice_updates, + ConnectionFlowState.disconnected, + timeout=self.timeout, + ) + except asyncio.TimeoutError: + return False + else: + if self.state is ConnectionFlowState.disconnected: + return False + + previous_ws = self.ws + + try: + self.ws = await self._connect_websocket(False) + await self._handshake_websocket() + except (ConnectionClosed, asyncio.TimeoutError): + return False + else: + return True + finally: + await previous_ws.close() + + async def _move_to(self, channel: abc.Snowflake) -> None: + await self.client.channel.guild.change_voice_state(channel=channel) + self.state = ConnectionFlowState.set_guild_voice_state + + def _update_voice_channel(self, channel_id: int | None) -> None: + self.client.channel = channel_id and self.guild.get_channel(channel_id) # type: ignore From 24d2db0865fca5310ae029c6627b622d1c2ef384 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 14:40:04 +0000 Subject: [PATCH 06/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/opus.py | 24 +++-- discord/sinks/core.py | 6 +- discord/voice/client.py | 91 ++++++++++-------- discord/voice/gateway.py | 12 +-- discord/voice/recorder.py | 1 + discord/voice/state.py | 192 ++++++++++++++++++++++++++------------ 6 files changed, 205 insertions(+), 121 deletions(-) diff --git a/discord/opus.py b/discord/opus.py index 
0160f85ef5..214ad2b577 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -43,7 +43,7 @@ if TYPE_CHECKING: T = TypeVar("T") - APPLICATION_CTL = Literal['audio', 'voip', 'lowdelay'] + APPLICATION_CTL = Literal["audio", "voip", "lowdelay"] BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] SIGNAL_CTL = Literal["auto", "voice", "music"] @@ -81,7 +81,7 @@ class ApplicationCtl(TypedDict): c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_float_ptr = ctypes.POINTER(ctypes.c_float) -OPUS_SILENCE = b'\xf8\xff\xfe' +OPUS_SILENCE = b"\xf8\xff\xfe" _lib = None @@ -105,9 +105,9 @@ class DecoderStruct(ctypes.Structure): # Encoder CTLs application_ctl: ApplicationCtl = { - 'audio': 2049, - 'lowdelay': 2051, - 'voip': 2048, + "audio": 2049, + "lowdelay": 2051, + "voip": 2048, } CTL_SET_BITRATE = 4002 @@ -378,22 +378,20 @@ class Encoder(_OpusStruct): def __init__( self, *, - application: APPLICATION_CTL = 'audio', + application: APPLICATION_CTL = "audio", bitrate: int = 128, fec: bool = True, expected_packet_loss: float = 0.15, - bandwidth: BAND_CTL = 'full', - signal_type: SIGNAL_TL = 'auto', + bandwidth: BAND_CTL = "full", + signal_type: SIGNAL_TL = "auto", ) -> None: if application not in application_ctl: - raise ValueError( - 'invalid application ctl type provided' - ) + raise ValueError("invalid application ctl type provided") if not 16 <= bitrate <= 512: - raise ValueError('bitrate must be between 16 and 512, both included') + raise ValueError("bitrate must be between 16 and 512, both included") if not 0 < expected_packet_loss <= 1: raise ValueError( - 'expected_packet_loss must be between 0 and 1, including 1', + "expected_packet_loss must be between 0 and 1, including 1", ) _OpusStruct.get_opus_version() diff --git a/discord/sinks/core.py b/discord/sinks/core.py index 0d7b4f5483..cacffbfbfa 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -128,9 +128,9 @@ def __init__(self, data: bytes, client: 
VoiceClient): self.header: bytes = data[:cutoff] self.data = self.data[cutoff:] - self.decrypted_data: bytes = getattr(self.client, f"_decrypt_{self.client.mode}")( - self.header, self.data - ) + self.decrypted_data: bytes = getattr( + self.client, f"_decrypt_{self.client.mode}" + )(self.header, self.data) self.decoded_data: bytes | None = None self.user_id: int | None = None diff --git a/discord/voice/client.py b/discord/voice/client.py index 157e58a01e..dee777f035 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -26,21 +26,21 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine import struct +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any, Literal, overload from discord import opus from discord.errors import ClientException +from discord.sinks import RawData, Sink from discord.sinks.errors import RecordingException from discord.utils import MISSING -from discord.sinks import RawData, Sink from ._types import VoiceProtocol -from .state import VoiceConnectionState +from .player import AudioPlayer from .recorder import Recorder from .source import AudioSource -from .player import AudioPlayer +from .state import VoiceConnectionState if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -48,25 +48,26 @@ from discord import abc from discord.client import Client from discord.guild import Guild, VocalGuildChannel - from discord.state import ConnectionState - from discord.user import ClientUser + from discord.opus import APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Decoder, Encoder from discord.raw_models import ( - RawVoiceStateUpdateEvent, RawVoiceServerUpdateEvent, + RawVoiceStateUpdateEvent, ) + from discord.state import ConnectionState from discord.types.voice import SupportedModes - from discord.opus import Encoder, APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Decoder + from discord.user import ClientUser from .gateway import VoiceWebSocket AfterCallback = 
Callable[[Exception | None], Any] - P = ParamSpec('P') + P = ParamSpec("P") has_nacl: bool try: import nacl.secret import nacl.utils + has_nacl = True except ImportError: has_nacl = False @@ -102,9 +103,9 @@ class VoiceClient(VoiceProtocol): def __init__(self, client: Client, channel: abc.Connectable) -> None: if not has_nacl: raise RuntimeError( - 'PyNaCl library is needed in order to use voice related features, ' + "PyNaCl library is needed in order to use voice related features, " 'you can run "pip install py-cord[voice]" to install all voice-related ' - 'dependencies.' + "dependencies." ) super().__init__(client, channel) @@ -130,10 +131,10 @@ def __init__(self, client: Client, channel: abc.Connectable) -> None: warn_nacl: bool = not has_nacl supported_modes: tuple[SupportedModes, ...] = ( - 'aead_xchacha20_poly1305_rtpsize', - 'xsalsa20_poly1305_lite', - 'xsalsa20_poly1305_suffix', - 'xsalsa20_poly1305', + "aead_xchacha20_poly1305_rtpsize", + "xsalsa20_poly1305_lite", + "xsalsa20_poly1305_suffix", + "xsalsa20_poly1305", ) @property @@ -224,7 +225,7 @@ def latency(self) -> float: .. versionadded:: 1.4 """ ws = self.ws - return float('inf') if not ws else ws.latency + return float("inf") if not ws else ws.latency @property def average_latency(self) -> float: @@ -233,7 +234,7 @@ def average_latency(self) -> float: .. versionadded:: 1.4 """ ws = self.ws - return float('inf') if not ws else ws.average_latency + return float("inf") if not ws else ws.average_latency async def disconnect(self, *, force: bool = False) -> None: """|coro| @@ -245,7 +246,9 @@ async def disconnect(self, *, force: bool = False) -> None: await self._connection.disconnect(force=force, wait=True) self.cleanup() - async def move_to(self, channel: abc.Snowflake | None, *, timeout: float | None = 30.0) -> None: + async def move_to( + self, channel: abc.Snowflake | None, *, timeout: float | None = 30.0 + ) -> None: """|coro| moves you to a different voice channel. 
@@ -285,11 +288,11 @@ def _get_voice_packet(self, data: Any) -> bytes: # formulate rtp header header[0] = 0x80 header[1] = 0x78 - struct.pack_into('>H', header, 2, self.sequence) - struct.pack_into('>I', header, 4, self.timestamp) - struct.pack_into('>I', header, 8, self.ssrc) + struct.pack_into(">H", header, 2, self.sequence) + struct.pack_into(">I", header, 4, self.timestamp) + struct.pack_into(">I", header, 8, self.ssrc) - encrypt_packet = getattr(self, f'_encrypt_{self.mode}') + encrypt_packet = getattr(self, f"_encrypt_{self.mode}") return encrypt_packet(header, data) def _encrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes: @@ -309,17 +312,23 @@ def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes: # deprecated box = nacl.secret.SecretBox(bytes(self.secret_key)) nonce = bytearray(24) - nonce[:4] = struct.pack('>I', self._incr_nonce) - self.checked_add('_incr_nonce', 1, 4294967295) + nonce[:4] = struct.pack(">I", self._incr_nonce) + self.checked_add("_incr_nonce", 1, 4294967295) return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] - def _encrypt_aead_xcacha20_poly1305_rtpsize(self, header: bytes, data: Any) -> bytes: + def _encrypt_aead_xcacha20_poly1305_rtpsize( + self, header: bytes, data: Any + ) -> bytes: box = nacl.secret.Aead(bytes(self.secret_key)) nonce = bytearray(24) - nonce[:4] = struct.pack('>I', self._incr_nonce) - self.checked_add('_incr_nonce', 1, 4294967295) - return header + box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext + nonce[:4] + nonce[:4] = struct.pack(">I", self._incr_nonce) + self.checked_add("_incr_nonce", 1, 4294967295) + return ( + header + + box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext + + nonce[:4] + ) def _decrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes: # deprecated @@ -349,7 +358,9 @@ def _decrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes: return 
self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) - def _decrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data: Any) -> bytes: + def _decrypt_aead_xchacha20_poly1305_rtpsize( + self, header: bytes, data: Any + ) -> bytes: box = nacl.secret.Aead(bytes(self.secret_key)) nonce = bytearray(24) @@ -363,7 +374,7 @@ def _decrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data: Any) -> @staticmethod def strip_header_ext(data: bytes) -> bytes: if len(data) > 4 and data[0] == 0xBE and data[1] == 0xDE: - _, length = struct.unpack_from('>HH', data) + _, length = struct.unpack_from(">HH", data) offset = 4 + length * 4 data = data[offset:] return data @@ -403,12 +414,12 @@ def play( source: AudioSource, *, after: AfterCallback | None = None, - application: APPLICATION_CTL = 'audio', + application: APPLICATION_CTL = "audio", bitrate: int = 128, fec: bool = True, expected_packet_loss: float = 0.15, - bandwidth: BAND_CTL = 'full', - signal_type: SIGNAL_CTL = 'auto', + bandwidth: BAND_CTL = "full", + signal_type: SIGNAL_CTL = "auto", wait_finish: bool = False, ) -> None | asyncio.Future[None]: """Plays an :class:`AudioSource`. 
@@ -465,12 +476,12 @@ def play( """ if not self.is_connected(): - raise ClientException('Not connected to voice') + raise ClientException("Not connected to voice") if self.is_playing(): - raise ClientException('Already playing audio') + raise ClientException("Already playing audio") if not isinstance(source, AudioSource): raise TypeError( - f'Source must be an AudioSource, not {source.__class__.__name__}', + f"Source must be an AudioSource, not {source.__class__.__name__}", ) if not self.encoder and not source.is_opus(): self.encoder = opus.Encoder( @@ -579,10 +590,12 @@ def start_recording( """ if not self.is_connected(): - raise RecordingException('Not connected to a voice channel') + raise RecordingException("Not connected to a voice channel") if self.recording: - raise RecordingException('You are already recording') + raise RecordingException("You are already recording") if not isinstance(sink, Sink): - raise RecordingException(f'Expected a Sink object, got {sink.__class__.__name__}') + raise RecordingException( + f"Expected a Sink object, got {sink.__class__.__name__}" + ) self._recording_handler.empty() diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index 0e0e4a78de..fca150d0fa 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -162,7 +162,7 @@ async def received_message(self, msg: Any, /): "successfully RESUMED.", ) elif op == OpCodes.session_description: - self.state.mode = data['mode'] + self.state.mode = data["mode"] elif op == OpCodes.hello: interval = data["heartbeat_interval"] / 1000.0 self._keep_alive = KeepAliveHandler( @@ -268,10 +268,10 @@ async def close(self, code: int = 1000) -> None: async def speak(self, state: SpeakingState = SpeakingState.voice) -> None: await self.send_as_json( { - 'op': int(OpCodes.speaking), - 'd': { - 'speaking': int(state), - 'delay': 0, + "op": int(OpCodes.speaking), + "d": { + "speaking": int(state), + "delay": 0, }, }, ) @@ -285,7 +285,7 @@ async def from_state( hook: 
Callable[..., Coroutine[Any, Any, Any]] | None = None, seq_ack: int = -1, ) -> Self: - gateway = f'wss://{state.endpoint}/?v=8' + gateway = f"wss://{state.endpoint}/?v=8" client = state.client http = client._state.http socket = await http.ws_connect(gateway, compress=15) diff --git a/discord/voice/recorder.py b/discord/voice/recorder.py index e01d9f8694..9b5798d55f 100644 --- a/discord/voice/recorder.py +++ b/discord/voice/recorder.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations # TODO: finish this diff --git a/discord/voice/state.py b/discord/voice/state.py index 403d1abf56..d5a14827de 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -37,16 +37,17 @@ from discord.backoff import ExponentialBackoff from discord.errors import ConnectionClosed -from .gateway import VoiceWebSocket from .enums import ConnectionFlowState +from .gateway import VoiceWebSocket if TYPE_CHECKING: from discord import abc - from discord.user import ClientUser from discord.guild import Guild from discord.member import VoiceState + from discord.raw_models import RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent from discord.types.voice import SupportedModes - from discord.raw_models import RawVoiceStateUpdateEvent, RawVoiceServerUpdateEvent + from discord.user import ClientUser + from .client import VoiceClient MISSING = utils.MISSING @@ -164,7 +165,9 @@ def __init__( self, client: VoiceClient, *, - hook: Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None = None, + hook: ( + Callable[[VoiceWebSocket, dict[str, Any]], Coroutine[Any, Any, Any]] | None + ) = None, ) -> None: self.client: VoiceClient = client self.hook = hook @@ -202,7 +205,7 @@ def state(self) -> ConnectionFlowState: @state.setter def state(self, state: ConnectionFlowState) -> None: if state is not self._state: - _log.debug('State changed from %s to %s', self._state.name, 
state.name) + _log.debug("State changed from %s to %s", self._state.name, state.name) self._state = state self._state_event.set() @@ -268,11 +271,11 @@ async def voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: if self._expecting_disconnect: self._expecting_disconnect = False else: - _log.debug('We have been disconnected from voice') + _log.debug("We have been disconnected from voice") await self.disconnect() return - self.ws.session_id = data['session_id'] + self.ws.session_id = data["session_id"] if self.state in ( ConnectionFlowState.set_guild_voice_state, @@ -293,10 +296,12 @@ async def voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: elif self.state is not ConnectionFlowState.disconnected: if channel_id != self.client.channel.id: - _log.info('We were moved from the channel while connecting...') + _log.info("We were moved from the channel while connecting...") self._update_voice_channel(channel_id) - await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_state_update) + await self.soft_disconnect( + with_state=ConnectionFlowState.got_voice_state_update + ) await self.connect( reconnect=self.reconnect, timeout=self.timeout, @@ -306,7 +311,7 @@ async def voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: wait=False, ) else: - _log.debug('Ignoring unexpected VOICE_STATEUPDATE event') + _log.debug("Ignoring unexpected VOICE_STATEUPDATE event") async def voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: previous_token = self.token @@ -319,15 +324,18 @@ async def voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: if self.token is None or endpoint is None: _log.warning( - 'Awaiting endpoint... This requires waiting. ' - 'If timeout occurred considering raising the timeout and reconnecting.' + "Awaiting endpoint... This requires waiting. " + "If timeout occurred considering raising the timeout and reconnecting." 
) return # strip the prefix off since we add it later - self.endpoint = endpoint.removeprefix('wss://') + self.endpoint = endpoint.removeprefix("wss://") - if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_state_update): + if self.state in ( + ConnectionFlowState.set_guild_voice_state, + ConnectionFlowState.got_voice_state_update, + ): self.endpoint_ip = MISSING self._create_socket() @@ -337,17 +345,23 @@ async def voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: self.state = ConnectionFlowState.got_both_voice_updates elif self.state is ConnectionFlowState.connected: - _log.debug('Voice server update, closing old voice websocket') + _log.debug("Voice server update, closing old voice websocket") await self.ws.close(4014) # 4014 = main gw dropped self.state = ConnectionFlowState.got_voice_server_update elif self.state is not ConnectionFlowState.disconnected: - if previous_token == self.token and previous_server_id == self.server_id and previous_endpoint == self.endpoint: + if ( + previous_token == self.token + and previous_server_id == self.server_id + and previous_endpoint == self.endpoint + ): return - _log.debug('Unexpected VOICE_SERVER_UPDATE event received, handling...') + _log.debug("Unexpected VOICE_SERVER_UPDATE event received, handling...") - await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_server_update) + await self.soft_disconnect( + with_state=ConnectionFlowState.got_voice_server_update + ) await self.connect( reconnect=self.reconnect, timeout=self.timeout, @@ -386,7 +400,7 @@ async def connect( self_mute, resume, ), - name=f'voice-connector:{id(self):#x}', + name=f"voice-connector:{id(self):#x}", ) if wait: @@ -403,21 +417,23 @@ async def _wrap_connect( try: await self._connect(reconnect, timeout, self_deaf, self_mute, resume) except asyncio.CancelledError: - _log.debug('Cancelling voice connection') + _log.debug("Cancelling voice connection") await self.soft_disconnect() 
raise except asyncio.TimeoutError: - _log.info('Timed out while connecting to voice') + _log.info("Timed out while connecting to voice") await self.disconnect() raise except Exception: - _log.exception('Error while connecting to voice... disconnecting') + _log.exception("Error while connecting to voice... disconnecting") await self.disconnect() raise - async def _inner_connect(self, reconnect: bool, self_deaf: bool, self_mute: bool, resume: bool) -> None: + async def _inner_connect( + self, reconnect: bool, self_deaf: bool, self_mute: bool, resume: bool + ) -> None: for i in range(5): - _log.info('Starting voice handshake (connection attempt %s)', i + 1) + _log.info("Starting voice handshake (connection attempt %s)", i + 1) await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute) if self.state is ConnectionFlowState.disconnected: @@ -425,7 +441,7 @@ async def _inner_connect(self, reconnect: bool, self_deaf: bool, self_mute: bool await self._wait_for_state(ConnectionFlowState.got_both_voice_updates) - _log.info('Voice handshake complete. Endpoint found: %s', self.endpoint) + _log.info("Voice handshake complete. Endpoint found: %s", self.endpoint) try: self.ws = await self._connect_websocket(resume) @@ -434,7 +450,9 @@ async def _inner_connect(self, reconnect: bool, self_deaf: bool, self_mute: bool except ConnectionClosed: if reconnect: wait = 1 + i * 2 - _log.exception('Failed to connect to voice... Retrying in %s seconds', wait) + _log.exception( + "Failed to connect to voice... 
Retrying in %s seconds", wait + ) await self.disconnect(cleanup=False) await asyncio.sleep(wait) continue @@ -442,22 +460,36 @@ async def _inner_connect(self, reconnect: bool, self_deaf: bool, self_mute: bool await self.disconnect() raise - async def _connect(self, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool) -> None: - _log.info(f'Connecting to voice {self.client.channel.id}') + async def _connect( + self, + reconnect: bool, + timeout: float, + self_deaf: bool, + self_mute: bool, + resume: bool, + ) -> None: + _log.info(f"Connecting to voice {self.client.channel.id}") await asyncio.wait_for( - self._inner_connect(reconnect=reconnect, self_deaf=self_deaf, self_mute=self_mute, resume=resume), + self._inner_connect( + reconnect=reconnect, + self_deaf=self_deaf, + self_mute=self_mute, + resume=resume, + ), timeout=timeout, ) - _log.info('Voice connection completed') + _log.info("Voice connection completed") if not self._runner: self._runner = self.client.loop.create_task( self._poll_ws(reconnect), - name=f'voice-ws-poller:{id(self):#x}', + name=f"voice-ws-poller:{id(self):#x}", ) - async def disconnect(self, *, force: bool = True, cleanup: bool = True, wait: bool = False) -> None: + async def disconnect( + self, *, force: bool = True, cleanup: bool = True, wait: bool = False + ) -> None: if not force and not self.is_connected(): return @@ -466,7 +498,9 @@ async def disconnect(self, *, force: bool = True, cleanup: bool = True, wait: bo if self.ws: await self.ws.close() except Exception: - _log.debug('Ignoring exception while disconnecting from voice', exc_info=True) + _log.debug( + "Ignoring exception while disconnecting from voice", exc_info=True + ) finally: self.state = ConnectionFlowState.disconnected self._socket_reader.pause() @@ -486,17 +520,23 @@ async def disconnect(self, *, force: bool = True, cleanup: bool = True, wait: bo if wait and not self._inside_runner(): try: - await asyncio.wait_for(self._disconnected.wait(), 
timeout=self.timeout) + await asyncio.wait_for( + self._disconnected.wait(), timeout=self.timeout + ) except TimeoutError: - _log.debug('Timed out waiting for voice disconnect confirmation') + _log.debug("Timed out waiting for voice disconnect confirmation") except asyncio.CancelledError: pass if cleanup: self.client.cleanup() - async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None: - _log.debug('Soft disconnecting from voice') + async def soft_disconnect( + self, + *, + with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates, + ) -> None: + _log.debug("Soft disconnecting from voice") if self._runner: self._runner.cancel() @@ -506,7 +546,9 @@ async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionF if self.ws: await self.ws.close() except Exception: - _log.debug('Ignoring exception while soft disconnecting from voice', exc_info=True) + _log.debug( + "Ignoring exception while soft disconnecting from voice", exc_info=True + ) finally: self.state = with_state self._socket_reader.pause() @@ -517,7 +559,9 @@ async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionF self.ip = MISSING self.port = MISSING - async def move_to(self, channel: abc.Snowflake | None, timeout: float | None) -> None: + async def move_to( + self, channel: abc.Snowflake | None, timeout: float | None + ) -> None: if channel is None: await self.disconnect(wait=True) return @@ -533,9 +577,17 @@ async def move_to(self, channel: abc.Snowflake | None, timeout: float | None) -> try: await self.wait_for(timeout=timeout) except asyncio.TimeoutError: - _log.warning('Timed out trying to move to channel %s in guild %s', channel.id, self.guild.id) + _log.warning( + "Timed out trying to move to channel %s in guild %s", + channel.id, + self.guild.id, + ) if self.state is last_state: - _log.debug('Reverting state %s to previous state: %s', last_state.name, previous_state.name) + 
_log.debug( + "Reverting state %s to previous state: %s", + last_state.name, + previous_state.name, + ) self.state = previous_state def wait_for( @@ -551,11 +603,11 @@ def send_packet(self, packet: bytes) -> None: self.socket.sendall(packet) def add_socket_listener(self, callback: SocketReaderCallback) -> None: - _log.debug('Registering a socket listener callback %s', callback) + _log.debug("Registering a socket listener callback %s", callback) self._socket_reader.register(callback) def remove_socket_listener(self, callback: SocketReaderCallback) -> None: - _log.debug('Unregistering a socket listener callback %s', callback) + _log.debug("Unregistering a socket listener callback %s", callback) self._socket_reader.unregister(callback) async def _wait_for_state( @@ -581,13 +633,17 @@ async def _wait_for_state( # has timed out, so just raise the exception raise asyncio.TimeoutError - async def _voice_connect(self, *, self_deaf: bool = False, self_mute: bool = False) -> None: + async def _voice_connect( + self, *, self_deaf: bool = False, self_mute: bool = False + ) -> None: channel = self.client.channel - await channel.guild.change_voice_state(channel=channel, self_deaf=self_deaf, self_mute=self_mute) + await channel.guild.change_voice_state( + channel=channel, self_deaf=self_deaf, self_mute=self_mute + ) async def _voice_disconnect(self) -> None: _log.info( - 'Terminating voice handshake for channel %s (guild %s)', + "Terminating voice handshake for channel %s (guild %s)", self.client.channel.id, self.client.guild.id, ) @@ -601,7 +657,9 @@ async def _connect_websocket(self, resume: bool) -> VoiceWebSocket: seq_ack = -1 if self.ws is not MISSING: seq_ack = self.ws.seq_ack - ws = await VoiceWebSocket.from_state(self, resume=resume, hook=self.hook, seq_ack=seq_ack) + ws = await VoiceWebSocket.from_state( + self, resume=resume, hook=self.hook, seq_ack=seq_ack + ) self.state = ConnectionFlowState.websocket_connected return ws @@ -638,19 +696,29 @@ async def _poll_ws(self, 
reconnect: bool) -> None: if exc.code == 1000: if not self._expecting_disconnect: - _log.info('Disconnecting from voice manually, close code %d', exc.code) + _log.info( + "Disconnecting from voice manually, close code %d", + exc.code, + ) await self.disconnect() break elif exc.code in (4014, 4022): if self._disconnected.is_set(): - _log.info('Disconnectinf from voice by Discord, close code %d', exc.code) + _log.info( + "Disconnectinf from voice by Discord, close code %d", + exc.code, + ) await self.disconnect() break - _log.info('Disconnecting from voice by force... potentially reconnecting...') + _log.info( + "Disconnecting from voice by force... potentially reconnecting..." + ) successful = await self._potential_reconnect() if not successful: - _log.info('Reconnect was unsuccessful, disconnecting from voice normally') + _log.info( + "Reconnect was unsuccessful, disconnecting from voice normally" + ) if self.state is not ConnectionFlowState.disconnected: await self.disconnect() break @@ -659,14 +727,14 @@ async def _poll_ws(self, reconnect: bool) -> None: continue elif exc.code == 4021: _log.warning( - 'We are being rate limited while attempting to connect to voice. Disconnecting...', + "We are being rate limited while attempting to connect to voice. Disconnecting...", ) if self.state is not ConnectionFlowState.disconnected: await self.disconnect() break elif exc.code == 4015: _log.info( - 'Disconnected from voice due to a Discord-side issue, attempting to reconnect and resume...', + "Disconnected from voice due to a Discord-side issue, attempting to reconnect and resume...", ) try: @@ -678,26 +746,30 @@ async def _poll_ws(self, reconnect: bool) -> None: resume=True, ) except asyncio.TimeoutError: - _log.info('Could not resume the voice connection... Disconnecting...') + _log.info( + "Could not resume the voice connection... Disconnecting..." 
+ ) if self.state is not ConnectionFlowState.disconnected: await self.disconnect() break except Exception: _log.exception( - 'An exception was raised while attempting a reconnect and resume... Disconnecting...', + "An exception was raised while attempting a reconnect and resume... Disconnecting...", exc_info=True, ) if self.state is not ConnectionFlowState.disconnected: await self.disconnect() break else: - _log.info('Successfully reconnected and resume the voice connection') + _log.info( + "Successfully reconnected and resume the voice connection" + ) continue else: _log.debug( - 'Not handling close code %s (%s)', + "Not handling close code %s (%s)", exc.code, - exc.reason or 'No reason was provided', + exc.reason or "No reason was provided", ) if not reconnect: @@ -706,7 +778,7 @@ async def _poll_ws(self, reconnect: bool) -> None: retry = backoff.delay() _log.exception( - 'Disconnected from voice... Reconnecting in %.2fs', + "Disconnected from voice... Reconnecting in %.2fs", retry, ) await asyncio.sleep(retry) @@ -721,7 +793,7 @@ async def _poll_ws(self, reconnect: bool) -> None: resume=False, ) except asyncio.TimeoutError: - _log.warning('Could not connect to voice... Retrying...') + _log.warning("Could not connect to voice... 
Retrying...") continue async def _potential_reconnect(self) -> bool: From 63b4d853cad7f0ac4fa13ed37ab8c74995fea5e1 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Tue, 26 Aug 2025 17:23:04 +0200 Subject: [PATCH 07/40] protocols --- discord/voice/_types.py | 5 +++++ discord/voice/recorder.py | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/discord/voice/_types.py b/discord/voice/_types.py index 4acf832657..1686faba24 100644 --- a/discord/voice/_types.py +++ b/discord/voice/_types.py @@ -34,8 +34,10 @@ RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent, ) + from discord.voice.client import VoiceClient ClientT = TypeVar("ClientT", bound="Client", covariant=True) +VoiceClientT = TypeVar('VoiceClientT', bound='VoiceClient', covariant=True) class VoiceProtocol(Generic[ClientT]): @@ -157,3 +159,6 @@ def cleanup(self) -> None: """ key, _ = self.channel._get_voice_client_key() self.client._connection._remove_voice_client(key) + + +class RecorderProtocol(Generic[VoiceClientT]): diff --git a/discord/voice/recorder.py b/discord/voice/recorder.py index e01d9f8694..6323cae6a7 100644 --- a/discord/voice/recorder.py +++ b/discord/voice/recorder.py @@ -25,3 +25,14 @@ from __future__ import annotations # TODO: finish this + + +class Recorder: + """Represents a voice recorder for a voice client. + + You should not construct this but instead obtain it from :attr:`VoiceClient.recorder`. + + .. 
versionadded:: 2.7 + """ + + def __init__(self, client: VoiceClient) -> None: From fe88ee82283b8939ab00b4db83342b3f2624de79 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Tue, 26 Aug 2025 19:05:19 +0200 Subject: [PATCH 08/40] recv --- discord/opus.py | 24 ++++---- discord/sinks/core.py | 2 +- discord/voice/_types.py | 112 ++++++++++++++++++++++++++++++++++++-- discord/voice/client.py | 16 ++++-- discord/voice/gateway.py | 2 +- discord/voice/recorder.py | 43 ++++++++++++++- 6 files changed, 174 insertions(+), 25 deletions(-) diff --git a/discord/opus.py b/discord/opus.py index 214ad2b577..1288d7248f 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -42,6 +42,8 @@ from .sinks import RawData if TYPE_CHECKING: + from discord.voice.recorder import VoiceRecorderClient + T = TypeVar("T") APPLICATION_CTL = Literal["audio", "voip", "lowdelay"] BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] @@ -548,17 +550,17 @@ def decode(self, data, *, fec=False): class DecodeManager(threading.Thread, _OpusStruct): - def __init__(self, client): + def __init__(self, client: VoiceRecorderClient): super().__init__(daemon=True, name="DecodeManager") - self.client = client - self.decode_queue = [] + self.client: VoiceRecorderClient = client + self.decode_queue: list[RawData] = [] - self.decoder = {} + self.decoder: dict[int, Decoder] = {} self._end_thread = threading.Event() - def decode(self, opus_frame): + def decode(self, opus_frame: RawData): if not isinstance(opus_frame, RawData): raise TypeError("opus_frame should be a RawData object.") self.decode_queue.append(opus_frame) @@ -579,20 +581,20 @@ def run(self): data.decrypted_data ) except OpusError: - print("Error occurred while decoding opus frame.") + _log.exception("Error occurred while decoding opus frame.", exc_info=True) continue - self.client.recv_decoded_audio(data) + self.client.receive_audio(data) - def stop(self): + def stop(self) -> None: while self.decoding: 
time.sleep(0.1) self.decoder = {} gc.collect() - print("Decoder Process Killed") + _log.debug("Decoder Process Killed") self._end_thread.set() - def get_decoder(self, ssrc): + def get_decoder(self, ssrc: int) -> Decoder: d = self.decoder.get(ssrc) if d is not None: return d @@ -600,5 +602,5 @@ def get_decoder(self, ssrc): return self.decoder[ssrc] @property - def decoding(self): + def decoding(self) -> bool: return bool(self.decode_queue) diff --git a/discord/sinks/core.py b/discord/sinks/core.py index cacffbfbfa..b67701cb67 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -224,7 +224,7 @@ def __init__(self, *, filters=None): self.audio_data = {} def init(self, vc): # called under listen - self.vc: VoiceClient = vc + self.vc = vc super().init() @Filters.container diff --git a/discord/voice/_types.py b/discord/voice/_types.py index 1686faba24..e96674063c 100644 --- a/discord/voice/_types.py +++ b/discord/voice/_types.py @@ -25,19 +25,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, TypeVar +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union if TYPE_CHECKING: + from typing_extensions import ParamSpec + from discord import abc from discord.client import Client from discord.raw_models import ( RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent, ) - from discord.voice.client import VoiceClient + from discord.sinks import Sink + + P = ParamSpec('P') + R = TypeVar('R') + RecordCallback = Union[Callable[P, R], Callable[P, Awaitable[R]]] ClientT = TypeVar("ClientT", bound="Client", covariant=True) -VoiceClientT = TypeVar('VoiceClientT', bound='VoiceClient', covariant=True) +VoiceProtocolT = TypeVar('VoiceProtocolT', bound='VoiceProtocol', covariant=True) class VoiceProtocol(Generic[ClientT]): @@ -45,7 +52,7 @@ class VoiceProtocol(Generic[ClientT]): .. 
warning:: - If you are a end user, you **should not construct this manually** but instead + If you are an end user, you **should not construct this manually** but instead take it from the return type in :meth:`abc.Connectable.connect `. The parameters and methods being documented here is so third party libraries can refer to it when implementing their own VoiceProtocol types. @@ -161,4 +168,99 @@ def cleanup(self) -> None: self.client._connection._remove_voice_client(key) -class RecorderProtocol(Generic[VoiceClientT]): +class VoiceRecorderProtocol(Generic[VoiceProtocolT]): + """A class that represents a Discord voice client recorder protocol. + + .. warning:: + + If you are an end user, you **should not construct this manually** but instead + take it from a :class:`VoiceProtocol` implementation, like :attr:`VoiceClient.recorder`. + The parameters and methods being documented here is so third party libraries can refer to it + when implementing their own RecorderProtocol types. + + This is an abstract class. The library provides a concrete implementation under + :class:`VoiceRecorderClient`. + + This class allows you to implement a protocol to allow for an external + method of receiving and handling voice data. + + .. versionadded:: 2.7 + + Parameters + ---------- + client: :class:`VoiceProtocol` + The voice client (or its subclasses) that are bound to this recorder. + channel: :class:`abc.Connectable` + The voice channel that is being recorder. If not provided, defaults to + :attr:`VoiceProtocol.channel` + """ + + def __init__(self, client: VoiceProtocolT, channel: abc.Connectable | None = None) -> None: + self.client: VoiceProtocolT = client + self.channel: abc.Connectable = channel or client.channel + + def get_ssrc(self, user_id: int) -> int: + """Gets the ssrc of a user. + + Parameters + ---------- + user_id: :class:`int` + The user ID to get the ssrc from. + + Returns + ------- + :class:`int` + The ssrc for the provided user ID. 
+ """ + raise NotImplementedError('subclasses must implement this') + + def unpack(self, data: bytes) -> bytes | None: + """Takes an audio packet received from Discord and decodes it. + + Parameters + ---------- + data: :class:`bytes` + The bytes received by Discord. + + Returns + ------- + Optional[:class:`bytes`] + The unpacked bytes, or ``None`` if they could not be unpacked. + """ + raise NotImplementedError('subclasses must implement this') + + def record( + self, + sink: Sink, + callback: RecordCallback[P, R], + sync_start: bool, + *callback_args: P.args, + **callback_kwargs: P.kwargs, + ) -> None: + """Start recording audio from the current voice channel in the provided sink. + + You must be in a voice channel. + + Parameters + ---------- + sink: :class:`~discord.Sink` + The sink to record to. + callback: Callable[..., Any] + The function called after the bot has stopped recording. This can take any arguments and + can return an awaitable. + sync_start: :class:`bool` + Whether the subsequent recording users will start with silence. This is useful for recording + audio just as it was heard. + + Raises + ------ + RecordingException + Not connected to a voice channel + TypeError + You did not pass a Sink object. 
+ """ + raise NotImplementedError('subclasses must implement this') + + def stop(self) -> None: + """Stops recording.""" + raise NotImplementedError('subclasses must implement this') diff --git a/discord/voice/client.py b/discord/voice/client.py index dee777f035..c6f74af3b1 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -38,7 +38,7 @@ from ._types import VoiceProtocol from .player import AudioPlayer -from .recorder import Recorder +from .recorder import VoiceRecorderClient from .source import AudioSource from .state import VoiceConnectionState @@ -100,7 +100,13 @@ class VoiceClient(VoiceProtocol): channel: VocalGuildChannel - def __init__(self, client: Client, channel: abc.Connectable) -> None: + def __init__( + self, + client: Client, + channel: abc.Connectable, + *, + use_recorder: bool = True, + ) -> None: if not has_nacl: raise RuntimeError( "PyNaCl library is needed in order to use voice related features, " @@ -127,7 +133,9 @@ def __init__(self, client: Client, channel: abc.Connectable) -> None: self._connection: VoiceConnectionState = self.create_connection_state() # voice recv things - self._recorder: Recorder | None = None + self._recorder: VoiceRecorderClient | None = None + if use_recorder: + self._recorder = VoiceRecorderClient(self) warn_nacl: bool = not has_nacl supported_modes: tuple[SupportedModes, ...] 
= ( @@ -187,7 +195,7 @@ def checked_add(self, attr: str, value: int, limit: int) -> None: setattr(self, attr, val + value) def create_connection_state(self) -> VoiceConnectionState: - return VoiceConnectionState(self) + return VoiceConnectionState(self, hook=self._recorder) async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: await self._connection.voice_state_update(data) diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index fca150d0fa..10b5f0d3ba 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -171,7 +171,7 @@ async def received_message(self, msg: Any, /): ) self._keep_alive.start() - await self._hook(self, msg) + await utils.maybe_coroutine(self._hook, self, data) async def ready(self, data: dict[str, Any]) -> None: state = self.state diff --git a/discord/voice/recorder.py b/discord/voice/recorder.py index d72f2f6082..321e22982c 100644 --- a/discord/voice/recorder.py +++ b/discord/voice/recorder.py @@ -25,10 +25,24 @@ from __future__ import annotations -# TODO: finish this +import asyncio +import threading +from typing import TYPE_CHECKING, Any, TypeVar +from discord.opus import DecodeManager -class Recorder: +from ._types import VoiceRecorderProtocol + +if TYPE_CHECKING: + from discord.sinks import Sink + + from .client import VoiceClient + from .gateway import VoiceWebSocket + + VoiceClientT = TypeVar('VoiceClientT', bound=VoiceClient, covariant=True) + + +class VoiceRecorderClient(VoiceRecorderProtocol[VoiceClientT]): """Represents a voice recorder for a voice client. You should not construct this but instead obtain it from :attr:`VoiceClient.recorder`. @@ -36,4 +50,27 @@ class Recorder: .. 
versionadded:: 2.7 """ - def __init__(self, client: VoiceClient) -> None: + def __init__(self, client: VoiceClientT) -> None: + super().__init__(client) + + self._paused: asyncio.Event = asyncio.Event() + self._recording: asyncio.Event = asyncio.Event() + self.decoder: DecodeManager = DecodeManager(self) + self.sync_start: bool = False + self.sinks: dict[int, tuple[Sink, threading.Thread]] = {} + + def is_paused(self) -> bool: + """Whether the current recorder is paused.""" + return self._paused.is_set() + + def is_recording(self) -> bool: + """Whether the current recording is actively recording.""" + return self._recording.is_set() + + async def hook(self, ws: VoiceWebSocket, data: dict[str, Any]) -> None: + ... + + def record( + self, + sink: Sink, + ) From 726c8f989228be748ed9479778a0b0ae74a0f95a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 17:05:54 +0000 Subject: [PATCH 09/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/opus.py | 4 +++- discord/voice/_types.py | 20 +++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/discord/opus.py b/discord/opus.py index 1288d7248f..6c8720d49e 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -581,7 +581,9 @@ def run(self): data.decrypted_data ) except OpusError: - _log.exception("Error occurred while decoding opus frame.", exc_info=True) + _log.exception( + "Error occurred while decoding opus frame.", exc_info=True + ) continue self.client.receive_audio(data) diff --git a/discord/voice/_types.py b/discord/voice/_types.py index e96674063c..57d5fae237 100644 --- a/discord/voice/_types.py +++ b/discord/voice/_types.py @@ -26,7 +26,7 @@ from __future__ import annotations from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar, Union if TYPE_CHECKING: from typing_extensions import 
ParamSpec @@ -39,12 +39,12 @@ ) from discord.sinks import Sink - P = ParamSpec('P') - R = TypeVar('R') + P = ParamSpec("P") + R = TypeVar("R") RecordCallback = Union[Callable[P, R], Callable[P, Awaitable[R]]] ClientT = TypeVar("ClientT", bound="Client", covariant=True) -VoiceProtocolT = TypeVar('VoiceProtocolT', bound='VoiceProtocol', covariant=True) +VoiceProtocolT = TypeVar("VoiceProtocolT", bound="VoiceProtocol", covariant=True) class VoiceProtocol(Generic[ClientT]): @@ -195,7 +195,9 @@ class VoiceRecorderProtocol(Generic[VoiceProtocolT]): :attr:`VoiceProtocol.channel` """ - def __init__(self, client: VoiceProtocolT, channel: abc.Connectable | None = None) -> None: + def __init__( + self, client: VoiceProtocolT, channel: abc.Connectable | None = None + ) -> None: self.client: VoiceProtocolT = client self.channel: abc.Connectable = channel or client.channel @@ -212,7 +214,7 @@ def get_ssrc(self, user_id: int) -> int: :class:`int` The ssrc for the provided user ID. """ - raise NotImplementedError('subclasses must implement this') + raise NotImplementedError("subclasses must implement this") def unpack(self, data: bytes) -> bytes | None: """Takes an audio packet received from Discord and decodes it. @@ -227,7 +229,7 @@ def unpack(self, data: bytes) -> bytes | None: Optional[:class:`bytes`] The unpacked bytes, or ``None`` if they could not be unpacked. """ - raise NotImplementedError('subclasses must implement this') + raise NotImplementedError("subclasses must implement this") def record( self, @@ -259,8 +261,8 @@ def record( TypeError You did not pass a Sink object. 
""" - raise NotImplementedError('subclasses must implement this') + raise NotImplementedError("subclasses must implement this") def stop(self) -> None: """Stops recording.""" - raise NotImplementedError('subclasses must implement this') + raise NotImplementedError("subclasses must implement this") From fbab9803af5daad703fdf12d7fb7e1405f12a77d Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Wed, 27 Aug 2025 00:01:25 +0200 Subject: [PATCH 10/40] fix voice connection, now fix voice recording :sob: --- discord/abc.py | 2 +- discord/client.py | 63 +-- discord/commands/context.py | 2 +- discord/ext/commands/context.py | 2 +- discord/guild.py | 4 +- discord/player.py | 74 +-- discord/raw_models.py | 1 - discord/state.py | 2 +- discord/utils.py | 100 ++++ discord/voice/_types.py | 2 +- discord/voice/client.py | 134 +++-- discord/voice/gateway.py | 73 ++- discord/voice/recorder.py | 3 +- discord/voice/state.py | 21 +- discord/voice_client.py | 932 -------------------------------- 15 files changed, 308 insertions(+), 1107 deletions(-) delete mode 100644 discord/voice_client.py diff --git a/discord/abc.py b/discord/abc.py index 6b90871436..354ef37e83 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -55,7 +55,7 @@ from .role import Role from .scheduled_events import ScheduledEvent from .sticker import GuildSticker, StickerItem -from .voice_client import VoiceClient, VoiceProtocol +from .voice import VoiceClient, VoiceProtocol __all__ = ( "Snowflake", diff --git a/discord/client.py b/discord/client.py index 2d0f1b8770..1c936099b8 100644 --- a/discord/client.py +++ b/discord/client.py @@ -61,19 +61,19 @@ from .ui.view import View from .user import ClientUser, User from .utils import MISSING -from .voice_client import VoiceClient +from .voice import VoiceClient from .webhook import Webhook from .widget import Widget if TYPE_CHECKING: from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime from .channel import DMChannel - 
from .interaction import Interaction + from .interactions import Interaction from .member import Member from .message import Message from .poll import Poll from .ui.item import Item - from .voice_client import VoiceProtocol + from .voice import VoiceProtocol __all__ = ("Client",) @@ -467,7 +467,7 @@ def _schedule_event( return task def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - _log.debug("Dispatching event %s", event) + logging.getLogger('discord.state').debug("Dispatching event %s", event) method = f"on_{event}" listeners = self._listeners.get(event) @@ -789,7 +789,16 @@ async def start(self, token: str, *, reconnect: bool = True) -> None: await self.login(token) await self.connect(reconnect=reconnect) - def run(self, *args: Any, **kwargs: Any) -> None: + def run( + self, + token: str, + *, + reconnect: bool = True, + log_handler: logging.Handler | None = MISSING, + log_formatter: logging.Formatter = MISSING, + log_level: int = MISSING, + root_logger: bool = False, + ) -> None: """A blocking call that abstracts away the event loop initialisation from you. 
@@ -815,39 +824,25 @@ def run(self, *args: Any, **kwargs: Any) -> None: """ loop = self.loop - try: - loop.add_signal_handler(signal.SIGINT, loop.stop) - loop.add_signal_handler(signal.SIGTERM, loop.stop) - except (NotImplementedError, RuntimeError): - pass - async def runner(): - try: - await self.start(*args, **kwargs) - finally: - if not self.is_closed(): - await self.close() - - def stop_loop_on_completion(f): - loop.stop() + async with self: + await self.start(token, reconnect=reconnect) + + if log_handler is not None: + utils.setup_logging( + handler=log_handler, + formatter=log_formatter, + level=log_level, + root=root_logger, + ) - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) try: - loop.run_forever() + asyncio.run(runner()) except KeyboardInterrupt: - _log.info("Received signal to terminate bot and event loop.") - finally: - future.remove_done_callback(stop_loop_on_completion) - _log.info("Cleaning up tasks.") - _cleanup_loop(loop) - - if not future.cancelled(): - try: - return future.result() - except KeyboardInterrupt: - # I am unsure why this gets raised here but suppress it anyway - return None + # nothing to do here + # `asyncio.run` handles the loop cleanup + # and `self.start` closes all sockets and the HTTPClient instance. 
+ return # properties diff --git a/discord/commands/context.py b/discord/commands/context.py index 73a6b39a45..d8748b9098 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -48,7 +48,7 @@ from ..permissions import Permissions from ..state import ConnectionState from ..user import User - from ..voice_client import VoiceClient + from ..voice import VoiceClient from ..webhook import WebhookMessage from .core import ApplicationCommand, Option diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index afd023d351..11e2216fce 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -41,7 +41,7 @@ from discord.member import Member from discord.state import ConnectionState from discord.user import ClientUser, User - from discord.voice_client import VoiceProtocol + from discord.voice import VoiceProtocol from .bot import AutoShardedBot, Bot from .cog import Cog diff --git a/discord/guild.py b/discord/guild.py index 488f4aa074..2d2d17beeb 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -111,8 +111,8 @@ from .types.guild import GuildFeature, MFALevel from .types.member import Member as MemberPayload from .types.threads import Thread as ThreadPayload - from .types.voice import GuildVoiceState - from .voice_client import VoiceClient + from .types.voice import VoiceState as GuildVoiceState + from .voice import VoiceClient from .webhook import Webhook VocalGuildChannel = Union[VoiceChannel, StageChannel] diff --git a/discord/player.py b/discord/player.py index 65b23ed42a..ae345be9b8 100644 --- a/discord/player.py +++ b/discord/player.py @@ -41,12 +41,13 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Generic, TypeVar from .errors import ClientException +from .enums import SpeakingState from .oggparse import OggStream -from .opus import Encoder as OpusEncoder +from .opus import Encoder as OpusEncoder, OPUS_SILENCE from .utils import MISSING if TYPE_CHECKING: - from .voice_client import 
VoiceClient + from .voice import VoiceClient AT = TypeVar("AT", bound="AudioSource") @@ -732,7 +733,6 @@ def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): self._resumed: threading.Event = threading.Event() self._resumed.set() # we are not paused self._current_error: Exception | None = None - self._connected: threading.Event = client._connected self._lock: threading.Lock = threading.Lock() self._played_frames_offset: int = 0 @@ -742,49 +742,51 @@ def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): def _do_run(self) -> None: # attempt to read first audio segment from source before starting # some sources can take a few seconds and may cause problems - first_data = self.source.read() self.loops = 0 self._start = time.perf_counter() # getattr lookup speed ups - play_audio = self.client.send_audio_packet - self._speak(True) + client = self.client + play_audio = client.send_audio_packet + self._speak(SpeakingState.voice) while not self._end.is_set(): # are we paused? if not self._resumed.is_set(): + self.send_silence() # wait until we aren't self._resumed.wait() continue - # are we disconnected from voice? - if not self._connected.is_set(): - # wait until we are connected - self._connected.wait() - # reset our internal data - self._played_frames_offset += self.loops - self.loops = 0 - self._start = time.perf_counter() - - self.loops += 1 - - # Send the data read from the start of the function if it is not None - if first_data is not None: - data = first_data - first_data = None - # Else read the next bit from the source - else: - data = self.source.read() + data = self.source.read() if not data: self.stop() break + # are we disconnected from voice? 
+ if not client.is_connected(): + _log.debug('Not connected, waiting for %ss...', client.timeout) + # wait until we are connected, but not forever + connected = client.wait_until_connected(client.timeout) + if self._end.is_set() or not connected: + _log.debug('Aborting playback') + return + _log.debug('Reconnected, resuming playback') + self._speak(SpeakingState.voice) + # reset our internal data + self.loops = 0 + self._start = time.perf_counter() + play_audio(data, encode=not self.source.is_opus()) + self.loops += 1 next_time = self._start + self.DELAY * self.loops delay = max(0, self.DELAY + (next_time - time.perf_counter())) time.sleep(delay) + if client.is_connected(): + self.send_silence() + def run(self) -> None: try: self._do_run() @@ -792,8 +794,8 @@ def run(self) -> None: self._current_error = exc self.stop() finally: - self.source.cleanup() self._call_after() + self.source.cleanup() def _call_after(self) -> None: error = self._current_error @@ -802,24 +804,21 @@ def _call_after(self) -> None: try: self.after(error) except Exception as exc: - _log.exception("Calling the after function failed.") exc.__context__ = error - traceback.print_exception(type(exc), exc, exc.__traceback__) + _log.exception("Calling the after function failed.", exc_info=exc) elif error: msg = f"Exception in voice thread {self.name}" _log.exception(msg, exc_info=error) - print(msg, file=sys.stderr) - traceback.print_exception(type(error), error, error.__traceback__) def stop(self) -> None: self._end.set() self._resumed.set() - self._speak(False) + self._speak(SpeakingState.none) def pause(self, *, update_speaking: bool = True) -> None: self._resumed.clear() if update_speaking: - self._speak(False) + self._speak(SpeakingState.none) def resume(self, *, update_speaking: bool = True) -> None: self._played_frames_offset += self.loops @@ -827,7 +826,7 @@ def resume(self, *, update_speaking: bool = True) -> None: self._start = time.perf_counter() self._resumed.set() if update_speaking: - 
self._speak(True) + self._speak(SpeakingState.voice) def is_playing(self) -> bool: return self._resumed.is_set() and not self._end.is_set() @@ -841,10 +840,10 @@ def _set_source(self, source: AudioSource) -> None: self.source = source self.resume(update_speaking=False) - def _speak(self, speaking: bool) -> None: + def _speak(self, state: SpeakingState) -> None: try: asyncio.run_coroutine_threadsafe( - self.client.ws.speak(speaking), self.client.loop + self.client.ws.speak(state), self.client.loop ) except Exception as e: _log.info("Speaking call in player failed: %s", e) @@ -852,3 +851,10 @@ def _speak(self, speaking: bool) -> None: def played_frames(self) -> int: """Gets the number of 20ms frames played since the start of the audio file.""" return self._played_frames_offset + self.loops + + def send_silence(self, count: int = 5) -> None: + try: + for n in range(count): + self.client.send_audio_packet(OPUS_SILENCE, encode=False) + except Exception: + pass diff --git a/discord/raw_models.py b/discord/raw_models.py index 0771159f28..581ae7052e 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -972,7 +972,6 @@ class RawVoiceStateUpdateEvent(_PayloadLike): "suppress", "requested_to_speak_at", "afk", - "channel", "guild_id", "channel_id", "_state", diff --git a/discord/state.py b/discord/state.py index 73917a6e59..0a72b9f0bf 100644 --- a/discord/state.py +++ b/discord/state.py @@ -88,7 +88,7 @@ from .types.poll import Poll as PollPayload from .types.sticker import GuildSticker as GuildStickerPayload from .types.user import User as UserPayload - from .voice_client import VoiceClient + from .voice import VoiceClient T = TypeVar("T") CS = TypeVar("CS", bound="ConnectionState") diff --git a/discord/utils.py b/discord/utils.py index e7b4ffde0d..bab51d3687 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -33,6 +33,8 @@ import importlib.resources import itertools import json +import logging +import os import re import sys import types @@ -1408,3 
+1410,101 @@ def filter_params(params, **kwargs): params[new_param] = params.pop(old_param) return params + + +def is_docker() -> bool: + path = '/proc/self/cgroup' + return os.path.exists('/.dockerenv') or (os.path.isfile(path) and any('docker' in line for line in open(path))) + + +def stream_supports_colour(stream: Any) -> bool: + is_a_tty = hasattr(stream, 'isatty') and stream.isatty() + + # Pycharm and Vscode support colour in their inbuilt editors + if 'PYCHARM_HOSTED' in os.environ or os.environ.get('TERM_PROGRAM') == 'vscode': + return is_a_tty + + if sys.platform != 'win32': + # Docker does not consistently have a tty attached to it + return is_a_tty or is_docker() + + # ANSICON checks for things like ConEmu + # WT_SESSION checks if this is Windows Terminal + return is_a_tty and ('ANSICON' in os.environ or 'WT_SESSION' in os.environ) + + +class _ColourFormatter(logging.Formatter): + # ANSI codes are a bit weird to decipher if you're unfamiliar with them, so here's a refresher + # It starts off with a format like \x1b[XXXm where XXX is a semicolon separated list of commands + # The important ones here relate to colour. + # 30-37 are black, red, green, yellow, blue, magenta, cyan and white in that order + # 40-47 are the same except for the background + # 90-97 are the same but "bright" foreground + # 100-107 are the same as the bright ones but for the background. + # 1 means bold, 2 means dim, 0 means reset, and 4 means underline. 
+ + LEVEL_COLOURS = [ + (logging.DEBUG, '\x1b[40;1m'), + (logging.INFO, '\x1b[34;1m'), + (logging.WARNING, '\x1b[33;1m'), + (logging.ERROR, '\x1b[31m'), + (logging.CRITICAL, '\x1b[41m'), + ] + + FORMATS = { + level: logging.Formatter( + f'\x1b[30;1m%(asctime)s\x1b[0m {colour}%(levelname)-8s\x1b[0m \x1b[35m%(name)s\x1b[0m %(message)s', + '%Y-%m-%d %H:%M:%S', + ) + for level, colour in LEVEL_COLOURS + } + + def format(self, record): + formatter = self.FORMATS.get(record.levelno) + if formatter is None: + formatter = self.FORMATS[logging.DEBUG] + + # Override the traceback to always print in red + if record.exc_info: + text = formatter.formatException(record.exc_info) + record.exc_text = f'\x1b[31m{text}\x1b[0m' + + output = formatter.format(record) + + # Remove the cache layer + record.exc_text = None + return output + + +def setup_logging( + *, + handler: logging.Handler = MISSING, + formatter: logging.Formatter = MISSING, + level: int = MISSING, + root: bool = True, +) -> None: + """A helper method to automatically setup the library's default logging. 
+ """ + + if level is MISSING: + level = logging.INFO + + if handler is MISSING: + handler = logging.StreamHandler() + + if formatter is MISSING: + if isinstance(handler, logging.StreamHandler) and stream_supports_colour(handler.stream): + formatter = _ColourFormatter() + else: + dt_fmt = '%Y-%m-%d %H:%M:%S' + formatter = logging.Formatter('[{asctime}] [{levelname:<8}] {name}: {message}', dt_fmt, style='{') + + if root: + logger = logging.getLogger() + else: + lib, _, _ = __name__.partition('.') + logger = logging.getLogger(lib) + + handler.setFormatter(formatter) + logger.setLevel(level) + logger.addHandler(handler) diff --git a/discord/voice/_types.py b/discord/voice/_types.py index e96674063c..aaadecdadc 100644 --- a/discord/voice/_types.py +++ b/discord/voice/_types.py @@ -237,7 +237,7 @@ def record( *callback_args: P.args, **callback_kwargs: P.kwargs, ) -> None: - """Start recording audio from the current voice channel in the provided sink. + r"""Start recording audio from the current voice channel in the provided sink. You must be in a voice channel. 
diff --git a/discord/voice/client.py b/discord/voice/client.py index c6f74af3b1..9581d11687 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -26,20 +26,19 @@ from __future__ import annotations import asyncio +import datetime +import logging import struct -from collections.abc import Callable, Coroutine +from collections.abc import Callable from typing import TYPE_CHECKING, Any, Literal, overload from discord import opus from discord.errors import ClientException -from discord.sinks import RawData, Sink -from discord.sinks.errors import RecordingException from discord.utils import MISSING +from discord.player import AudioSource, AudioPlayer from ._types import VoiceProtocol -from .player import AudioPlayer -from .recorder import VoiceRecorderClient -from .source import AudioSource +#from .recorder import VoiceRecorderClient from .state import VoiceConnectionState if TYPE_CHECKING: @@ -62,6 +61,8 @@ AfterCallback = Callable[[Exception | None], Any] P = ParamSpec("P") +_log = logging.getLogger(__name__) + has_nacl: bool try: @@ -133,9 +134,9 @@ def __init__( self._connection: VoiceConnectionState = self.create_connection_state() # voice recv things - self._recorder: VoiceRecorderClient | None = None - if use_recorder: - self._recorder = VoiceRecorderClient(self) + #self._recorder: VoiceRecorderClient | None = None + #if use_recorder: + # self._recorder = VoiceRecorderClient(self) warn_nacl: bool = not has_nacl supported_modes: tuple[SupportedModes, ...] 
= ( @@ -195,7 +196,7 @@ def checked_add(self, attr: str, value: int, limit: int) -> None: setattr(self, attr, val + value) def create_connection_state(self) -> VoiceConnectionState: - return VoiceConnectionState(self, hook=self._recorder) + return VoiceConnectionState(self) async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: await self._connection.voice_state_update(data) @@ -325,7 +326,7 @@ def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes: return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] - def _encrypt_aead_xcacha20_poly1305_rtpsize( + def _encrypt_aead_xchacha20_poly1305_rtpsize( self, header: bytes, data: Any ) -> bytes: box = nacl.secret.Aead(bytes(self.secret_key)) @@ -501,6 +502,7 @@ def play( signal_type=signal_type, ) + future = None if wait_finish: self._player_future = future = self.loop.create_future() after_callback = after @@ -528,82 +530,70 @@ def stop(self) -> None: self._player = None self._player_future = None - def unpack_audio(self, data: bytes) -> bytes | None: - """Takes an audio packet received from Discord and decodes it into PCM Audio data. - If there are no users talking in the channel, ``None`` will be returned. + def pause(self) -> None: + """Pauses the audio playing.""" + if self._player: + self._player.pause() - You must be connected to receive audio. + def resume(self) -> None: + """Resumes the audio playing.""" + if self._player: + self._player.resume() - .. versionadded:: 2.0 + @property + def source(self) -> AudioSource | None: + """The audio source being player, if playing. - Parameters - ---------- - data: :class:`bytes` - Bytes received by Discord via de UDP connection used for sending and receiving voice data. + This property can also be used to change the audio source currently being played. 
""" + return self._player and self._player.source - if not len(data) > 2: - return None + @source.setter + def source(self, value: AudioSource) -> None: + if not isinstance(value, AudioSource): + raise TypeError(f'expected AudioSource, not {value.__class__.__name__}') - if data[1] != 0x78: - # We Should Ignore Any Payload Types We Do Not Understand - # Ref RFC 3550 5.1 payload type - # At Some Point We Noted That We Should Ignore Only Types 200 - 204 inclusive. - # They Were Marked As RTCP: Provides Information About The Connection - # This Was Too Broad Of A Whitelist, It Is Unclear If This Is Too Narrow Of A Whitelist - return None - if self.paused: - return None + if self._player is None: + raise ValueError('the client is not playing anything') - raw = RawData(data, self) + self._player._set_source(value) - if raw.decrypted_data == opus.OPUS_SILENCE: # silenece frame - return None + def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: + """Sends an audio packet composed of the ``data``. - return self.decoder.decode(raw) - - def start_recording( - self, - sink: Sink, - callback: Callable[P, Coroutine[Any, Any, Any]], - sync_start: bool = False, - *callback_args: P.args, - **callback_kwargs: P.kwargs, - ): - r"""Start recording audio from the current voice channel. This function uses - a thread so the current code line will not be stopped. You must be in a voice - channel to use this, and must not be already recording. - - .. versionadded:: 2.0 + You must be connected to play audio. Parameters ---------- - sink: :class:`~discord.Sink` - A Sink which will "store" all the audio data. - callback: :ref:`coroutine ` - A function which is called after the bot has stopped recording. - sync_start: :class:`bool` - If ``True``, the recordings of subsequent users will start with silence. This - is useful for recording audio just as it was heard. - \*callback_args - Arguments that will be passed to the callback function. 
- \*\*callback_kwargs - Keyword arguments that will be passed to the callback function. + data: :class:`bytes` + The :term:`py:bytes-like object` denoting PCM or Opus voice data. + encode: :class:`bool` + Indicates if ``data`` should be encoded into Opus. Raises ------ - RecordingException - Not connected to a voice channel, or you are already recording, or you - did not provide a Sink object. + ClientException + You are not connected. + opus.OpusError + Encoding the data failed. """ - if not self.is_connected(): - raise RecordingException("Not connected to a voice channel") - if self.recording: - raise RecordingException("You are already recording") - if not isinstance(sink, Sink): - raise RecordingException( - f"Expected a Sink object, got {sink.__class__.__name__}" - ) + self.checked_add('sequence', 1, 65535) + if encode: + encoded = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME) + else: + encoded = data + + packet = self._get_voice_packet(encoded) + try: + self._connection.send_packet(packet) + except OSError: + _log.debug('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) - self._recording_handler.empty() + self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295) + + def elapsed(self) -> datetime.timedelta: + """Returns the elapsed time of the playing audio.""" + if self._player: + return datetime.timedelta(milliseconds=self._player.played_frames() * 20) + return datetime.timedelta() diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index 10b5f0d3ba..30ceebeced 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -67,10 +67,10 @@ def __init__( **kwargs, name=name, daemon=daemon, + ws=ws, + interval=interval, ) - self.ws: VoiceWebSocket = ws - self.interval: float | None = interval self.msg: str = "Keeping shard ID %s voice websocket alive with timestamp %s." self.block_msg: str = ( "Shard ID %s voice heartbeat blocked for more than %s seconds." 
@@ -112,14 +112,28 @@ def __init__( self._close_code: int | None = None self.secret_key: list[int] | None = None self.seq_ack: int = -1 - self.session_id: str | None = None self.state: VoiceConnectionState = state self.ssrc_map: dict[str, dict[str, Any]] = {} - self.token: str | None = None if hook: self._hook = hook # type: ignore + @property + def token(self) -> str | None: + return self.state.token + + @token.setter + def token(self, value: str | None) -> None: + self.state.token = value + + @property + def session_id(self) -> str | None: + return self.state.session_id + + @session_id.setter + def session_id(self, value: str | None) -> None: + self.state.session_id = value + async def _hook(self, *args: Any) -> Any: pass @@ -144,8 +158,8 @@ async def resume(self) -> None: async def received_message(self, msg: Any, /): _log.debug("Voice websocket frame received: %s", msg) op = msg["op"] - data = msg.get("data", {}) # this key should ALWAYS be given, but guard anyways - self.seq_ack = data.get("seq", self.seq_ack) # keep the seq_ack updated + data = msg.get("d", {}) # this key should ALWAYS be given, but guard anyways + self.seq_ack = msg.get("seq", self.seq_ack) # keep the seq_ack updated if op == OpCodes.ready: await self.ready(data) @@ -163,6 +177,7 @@ async def received_message(self, msg: Any, /): ) elif op == OpCodes.session_description: self.state.mode = data["mode"] + await self.load_secret_key(data) elif op == OpCodes.hello: interval = data["heartbeat_interval"] / 1000.0 self._keep_alive = KeepAliveHandler( @@ -186,14 +201,41 @@ async def ready(self, data: dict[str, Any]) -> None: await self.loop.sock_connect( state.socket, - (state.endpoint_id, state.voice_port), + (state.endpoint_ip, state.voice_port), + ) + + _log.debug( + 'Connected socket to %s (port %s)', + state.endpoint_ip, + state.voice_port, ) state.ip, state.port = await self.get_ip() + modes = [mode for mode in data['modes'] if mode in self.state.supported_modes] + _log.debug('Received 
available voice connection modes: %s', modes) + + mode = modes[0] + await self.select_protocol(state.ip, state.port, mode) + _log.debug('Selected voice protocol %s for this connection', mode) + + async def select_protocol(self, ip: str, port: int, mode: str) -> None: + payload = { + 'op': int(OpCodes.select_protocol), + 'd': { + 'protocol': 'udp', + 'data': { + 'address': ip, + 'port': port, + 'mode': mode, + }, + }, + } + await self.send_as_json(payload) + async def get_ip(self) -> tuple[str, int]: state = self.state - packet = bytearray(75) + packet = bytearray(74) struct.pack_into(">H", packet, 0, 1) # 1 = Send struct.pack_into(">H", packet, 2, 70) # 70 = Length struct.pack_into(">I", packet, 4, state.ssrc) @@ -206,7 +248,7 @@ async def get_ip(self) -> tuple[str, int]: fut: asyncio.Future[bytes] = self.loop.create_future() def get_ip_packet(data: bytes) -> None: - if data[0] == 0x02 and len(data) == 74: + if data[1] == 0x02 and len(data) == 74: self.loop.call_soon_threadsafe(fut.set_result, data) fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet)) @@ -300,3 +342,16 @@ async def from_state( else: await ws.identify() return ws + + async def identify(self) -> None: + state = self.state + payload = { + 'op': int(OpCodes.identify), + 'd': { + 'server_id': str(state.server_id), + 'user_id': str(state.user.id), + 'session_id': self.session_id, + 'token': self.token, + }, + } + await self.send_as_json(payload) diff --git a/discord/voice/recorder.py b/discord/voice/recorder.py index 321e22982c..0838f0979e 100644 --- a/discord/voice/recorder.py +++ b/discord/voice/recorder.py @@ -73,4 +73,5 @@ async def hook(self, ws: VoiceWebSocket, data: dict[str, Any]) -> None: def record( self, sink: Sink, - ) + ) -> int: + ... 
diff --git a/discord/voice/state.py b/discord/voice/state.py index d5a14827de..f1b797ab34 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -187,6 +187,8 @@ def __init__( self.mode: SupportedModes = MISSING self.socket: socket.socket = MISSING self.ws: VoiceWebSocket = MISSING + self.session_id: str | None = None + self.token: str | None = None self._state: ConnectionFlowState = ConnectionFlowState.disconnected self._expecting_disconnect: bool = False @@ -240,22 +242,6 @@ def supported_modes(self) -> tuple[SupportedModes, ...]: def self_voice_state(self) -> VoiceState | None: return self.guild.me.voice - @property - def token(self) -> str | None: - return self.ws.token - - @token.setter - def token(self, token: str | None) -> None: - self.ws.token = token - - @property - def session_id(self) -> str | None: - return self.ws.session_id - - @session_id.setter - def session_id(self, value: str | None) -> None: - self.ws.session_id = value - def is_connected(self) -> bool: return self.state is ConnectionFlowState.connected @@ -275,7 +261,7 @@ async def voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: await self.disconnect() return - self.ws.session_id = data["session_id"] + self.session_id = data["session_id"] if self.state in ( ConnectionFlowState.set_guild_voice_state, @@ -493,6 +479,7 @@ async def disconnect( if not force and not self.is_connected(): return + _log.debug('Attempting a voice disconnect for channel %s (guild %s)', self.channel_id, self.guild_id) try: await self._voice_disconnect() if self.ws: diff --git a/discord/voice_client.py b/discord/voice_client.py deleted file mode 100644 index 01d158287e..0000000000 --- a/discord/voice_client.py +++ /dev/null @@ -1,932 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the 
"Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. - -Some documentation to refer to: - -- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. -- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. -- We pull the session_id from VOICE_STATE_UPDATE. -- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. -- Then we initiate the voice web socket (vWS) pointing to the endpoint. -- We send opcode 0 with the user_id, server_id, session_id and token using the vWS. -- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval. -- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. -- Then we send our IP and port via vWS with opcode 1. -- When that's all done, we receive opcode 4 from the vWS. -- Finally we can transmit data to endpoint:port. -""" - -from __future__ import annotations - -import asyncio -import datetime -import logging -import select -import socket -import struct -import threading -import time -from typing import TYPE_CHECKING, Any, Callable, Literal, overload - -from . 
import opus, utils -from .backoff import ExponentialBackoff -from .errors import ClientException, ConnectionClosed -from .gateway import * -from .player import AudioPlayer, AudioSource -from .sinks import RawData, RecordingException, Sink -from .utils import MISSING - -if TYPE_CHECKING: - from . import abc - from .client import Client - from .guild import Guild - from .opus import Encoder - from .state import ConnectionState - from .types.voice import GuildVoiceState as GuildVoiceStatePayload - from .types.voice import SupportedModes - from .types.voice import VoiceServerUpdate as VoiceServerUpdatePayload - from .user import ClientUser - - -has_nacl: bool - -try: - import nacl.secret # type: ignore - - has_nacl = True -except ImportError: - has_nacl = False - -__all__ = ( - "VoiceProtocol", - "VoiceClient", -) - - -_log = logging.getLogger(__name__) - - -class VoiceClient(VoiceProtocol): - """Represents a Discord voice connection. - - You do not create these, you typically get them from - e.g. :meth:`VoiceChannel.connect`. - - Attributes - ---------- - session_id: :class:`str` - The voice connection session ID. - token: :class:`str` - The voice connection token. - endpoint: :class:`str` - The endpoint we are connecting to. - channel: :class:`abc.Connectable` - The voice channel connected to. - loop: :class:`asyncio.AbstractEventLoop` - The event loop that the voice client is running on. - - Warning - ------- - In order to use PCM based AudioSources, you must have the opus library - installed on your system and loaded through :func:`opus.load_opus`. - Otherwise, your AudioSources must be opus encoded (e.g. using :class:`FFmpegOpusAudio`) - or the library will not be able to transmit audio. 
- """ - - endpoint_ip: str - voice_port: int - secret_key: list[int] - ssrc: int - - def __init__(self, client: Client, channel: abc.Connectable): - if not has_nacl: - raise RuntimeError("PyNaCl library needed in order to use voice") - - super().__init__(client, channel) - state = client._connection - self.token: str = MISSING - self.socket = MISSING - self.loop: asyncio.AbstractEventLoop = state.loop - self._state: ConnectionState = state - # this will be used in the AudioPlayer thread - self._connected: threading.Event = threading.Event() - - self._handshaking: bool = False - self._potentially_reconnecting: bool = False - self._voice_state_complete: asyncio.Event = asyncio.Event() - self._voice_server_complete: asyncio.Event = asyncio.Event() - - self.mode: str = MISSING - self._connections: int = 0 - self.sequence: int = 0 - self.timestamp: int = 0 - self.timeout: float = 0 - self._runner: asyncio.Task = MISSING - self._player: AudioPlayer | None = None - self.encoder: Encoder = MISSING - self.decoder = None - self._lite_nonce: int = 0 - self.ws: DiscordVoiceWebSocket = MISSING - - self.paused = False - self.recording = False - self.user_timestamps = {} - self.sink = None - self.starting_time = None - self.stopping_time = None - - warn_nacl = not has_nacl - supported_modes: tuple[SupportedModes, ...] = ( - "xsalsa20_poly1305_lite", - "xsalsa20_poly1305_suffix", - "xsalsa20_poly1305", - "aead_xchacha20_poly1305_rtpsize", - ) - - @property - def guild(self) -> Guild | None: - """The guild we're connected to, if applicable.""" - return getattr(self.channel, "guild", None) - - @property - def user(self) -> ClientUser: - """The user connected to voice (i.e. 
ourselves).""" - return self._state.user - - def checked_add(self, attr, value, limit): - val = getattr(self, attr) - if val + value > limit: - setattr(self, attr, 0) - else: - setattr(self, attr, val + value) - - # connection related - - async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: - self.session_id = data["session_id"] - channel_id = data["channel_id"] - - if not self._handshaking or self._potentially_reconnecting: - # If we're done handshaking then we just need to update ourselves - # If we're potentially reconnecting due to a 4014, then we need to differentiate - # a channel move and an actual force disconnect - if channel_id is None: - # We're being disconnected so cleanup - await self.disconnect() - else: - guild = self.guild - self.channel = channel_id and guild and guild.get_channel(int(channel_id)) # type: ignore - else: - self._voice_state_complete.set() - - async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: - if self._voice_server_complete.is_set(): - _log.info("Ignoring extraneous voice server update.") - return - - self.token = data.get("token") - self.server_id = int(data["guild_id"]) - endpoint = data.get("endpoint") - - if endpoint is None or self.token is None: - _log.warning( - "Awaiting endpoint... This requires waiting. " - "If timeout occurred considering raising the timeout and reconnecting." 
- ) - return - - self.endpoint = endpoint.removeprefix("wss://") - # This gets set later - self.endpoint_ip = MISSING - - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.socket.setblocking(False) - - if not self._handshaking: - # If we're not handshaking then we need to terminate our previous connection in the websocket - await self.ws.close(4000) - return - - self._voice_server_complete.set() - - async def voice_connect(self) -> None: - await self.guild.change_voice_state(channel=self.channel) - - async def voice_disconnect(self) -> None: - _log.info( - "The voice handshake is being terminated for Channel ID %s (Guild ID %s)", - self.channel.id, - self.guild.id, - ) - await self.channel.guild.change_voice_state(channel=None) - - def prepare_handshake(self) -> None: - self._voice_state_complete.clear() - self._voice_server_complete.clear() - self._handshaking = True - _log.info( - "Starting voice handshake... (connection attempt %d)", self._connections + 1 - ) - self._connections += 1 - - def finish_handshake(self) -> None: - _log.info("Voice handshake complete. Endpoint found %s", self.endpoint) - self._handshaking = False - self._voice_server_complete.clear() - self._voice_state_complete.clear() - - async def connect_websocket(self) -> DiscordVoiceWebSocket: - ws = await DiscordVoiceWebSocket.from_client(self) - self._connected.clear() - while ws.secret_key is None: - await ws.poll_event() - self._connected.set() - return ws - - async def connect(self, *, reconnect: bool, timeout: float) -> None: - _log.info("Connecting to voice...") - self.timeout = timeout - - for i in range(5): - self.prepare_handshake() - - # This has to be created before we start the flow. 
- futures = [ - self._voice_state_complete.wait(), - self._voice_server_complete.wait(), - ] - - # Start the connection flow - await self.voice_connect() - - try: - await utils.sane_wait_for(futures, timeout=timeout) - except asyncio.TimeoutError: - await self.disconnect(force=True) - raise - - self.finish_handshake() - - try: - self.ws = await self.connect_websocket() - break - except (ConnectionClosed, asyncio.TimeoutError): - if reconnect: - _log.exception("Failed to connect to voice... Retrying...") - await asyncio.sleep(1 + i * 2.0) - await self.voice_disconnect() - continue - else: - raise - - if self._runner is MISSING: - self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) - - async def potential_reconnect(self) -> bool: - # Attempt to stop the player thread from playing early - self._connected.clear() - self.prepare_handshake() - self._potentially_reconnecting = True - try: - # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected - await asyncio.wait_for( - self._voice_server_complete.wait(), timeout=self.timeout - ) - except asyncio.TimeoutError: - self._potentially_reconnecting = False - await self.disconnect(force=True) - return False - - self.finish_handshake() - self._potentially_reconnecting = False - try: - self.ws = await self.connect_websocket() - except (ConnectionClosed, asyncio.TimeoutError): - return False - else: - return True - - @property - def latency(self) -> float: - """Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. - - This could be referred to as the Discord Voice WebSocket latency and is - an analogue of user's voice latencies as seen in the Discord client. - - .. versionadded:: 1.4 - """ - ws = self.ws - return float("inf") if not ws else ws.latency - - @property - def average_latency(self) -> float: - """Average of most recent 20 HEARTBEAT latencies in seconds. - - .. 
versionadded:: 1.4 - """ - ws = self.ws - return float("inf") if not ws else ws.average_latency - - async def poll_voice_ws(self, reconnect: bool) -> None: - backoff = ExponentialBackoff() - while True: - try: - await self.ws.poll_event() - except (ConnectionClosed, asyncio.TimeoutError) as exc: - if isinstance(exc, ConnectionClosed): - # The following close codes are undocumented, so I will document them here. - # 1000 - normal closure (obviously) - # 4014 - voice channel has been deleted. - # 4015 - voice server has crashed, we should resume - if exc.code == 1000: - _log.info( - "Disconnecting from voice normally, close code %d.", - exc.code, - ) - await self.disconnect() - break - if exc.code == 4014: - _log.info( - "Disconnected from voice by force... potentially" - " reconnecting." - ) - successful = await self.potential_reconnect() - if successful: - continue - - _log.info( - "Reconnect was unsuccessful, disconnecting from voice" - " normally..." - ) - await self.disconnect() - break - if exc.code == 4015: - _log.info("Disconnected from voice, trying to resume...") - - try: - await self.ws.resume() - except asyncio.TimeoutError: - _log.info( - "Could not resume the voice connection... Disconnection..." - ) - if self._connected.is_set(): - await self.disconnect(force=True) - else: - _log.info("Successfully resumed voice connection") - continue - - if not reconnect: - await self.disconnect() - raise - - retry = backoff.delay() - _log.exception( - "Disconnected from voice... Reconnecting in %.2fs.", retry - ) - self._connected.clear() - await asyncio.sleep(retry) - await self.voice_disconnect() - try: - await self.connect(reconnect=True, timeout=self.timeout) - except asyncio.TimeoutError: - # at this point we've retried 5 times... let's continue the loop. - _log.warning("Could not connect to voice... Retrying...") - continue - - async def disconnect(self, *, force: bool = False) -> None: - """|coro| - - Disconnects this voice client from voice. 
- """ - if not force and not self.is_connected(): - return - - self.stop() - self._connected.clear() - - try: - if self.ws: - await self.ws.close() - - await self.voice_disconnect() - finally: - self.cleanup() - if self.socket: - self.socket.close() - - async def move_to(self, channel: abc.Connectable) -> None: - """|coro| - - Moves you to a different voice channel. - - Parameters - ---------- - channel: :class:`abc.Connectable` - The channel to move to. Must be a voice channel. - """ - await self.channel.guild.change_voice_state(channel=channel) - - def is_connected(self) -> bool: - """Indicates if the voice client is connected to voice.""" - return self._connected.is_set() - - # audio related - - def _get_voice_packet(self, data): - header = bytearray(12) - - # Formulate rtp header - header[0] = 0x80 - header[1] = 0x78 - struct.pack_into(">H", header, 2, self.sequence) - struct.pack_into(">I", header, 4, self.timestamp) - struct.pack_into(">I", header, 8, self.ssrc) - - encrypt_packet = getattr(self, f"_encrypt_{self.mode}") - return encrypt_packet(header, data) - - def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes: - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - nonce = bytearray(24) - nonce[:12] = header - - return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext - - def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data) -> bytes: - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE) - - return header + box.encrypt(bytes(data), nonce).ciphertext + nonce - - def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes: - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - nonce = bytearray(24) - - nonce[:4] = struct.pack(">I", self._lite_nonce) - self.checked_add("_lite_nonce", 1, 4294967295) - - return header + box.encrypt(bytes(data), 
bytes(nonce)).ciphertext + nonce[:4] - - def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes: - # Required as of Nov 18 2024 - box = nacl.secret.Aead(bytes(self.secret_key)) - nonce = bytearray(24) - - nonce[:4] = struct.pack(">I", self._lite_nonce) - self.checked_add("_lite_nonce", 1, 4294967295) - - return ( - header - + box.encrypt(bytes(data), bytes(header), bytes(nonce)).ciphertext - + nonce[:4] - ) - - def _decrypt_xsalsa20_poly1305(self, header, data): - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - - nonce = bytearray(24) - nonce[:12] = header - - return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) - - def _decrypt_xsalsa20_poly1305_suffix(self, header, data): - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - - nonce_size = nacl.secret.SecretBox.NONCE_SIZE - nonce = data[-nonce_size:] - - return self.strip_header_ext(box.decrypt(bytes(data[:-nonce_size]), nonce)) - - def _decrypt_xsalsa20_poly1305_lite(self, header, data): - # Deprecated, remove in 2.7 - box = nacl.secret.SecretBox(bytes(self.secret_key)) - - nonce = bytearray(24) - nonce[:4] = data[-4:] - data = data[:-4] - - return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) - - def _decrypt_aead_xchacha20_poly1305_rtpsize(self, header, data): - # Required as of Nov 18 2024 - box = nacl.secret.Aead(bytes(self.secret_key)) - - nonce = bytearray(24) - nonce[:4] = data[-4:] - data = data[:-4] - - return self.strip_header_ext( - box.decrypt(bytes(data), bytes(header), bytes(nonce)) - ) - - @staticmethod - def strip_header_ext(data): - if len(data) > 4 and data[0] == 0xBE and data[1] == 0xDE: - _, length = struct.unpack_from(">HH", data) - offset = 4 + length * 4 - data = data[offset:] - return data - - def get_ssrc(self, user_id): - return {info["user_id"]: ssrc for ssrc, info in self.ws.ssrc_map.items()}[ - user_id - ] - - @overload - def play( - self, - source: 
AudioSource, - *, - after: Callable[[Exception | None], Any] | None = None, - wait_finish: Literal[False] = False, - ) -> None: ... - - @overload - def play( - self, - source: AudioSource, - *, - after: Callable[[Exception | None], Any] | None = None, - wait_finish: Literal[True], - ) -> asyncio.Future: ... - - def play( - self, - source: AudioSource, - *, - after: Callable[[Exception | None], Any] | None = None, - wait_finish: bool = False, - ) -> None | asyncio.Future: - """Plays an :class:`AudioSource`. - - The finalizer, ``after`` is called after the source has been exhausted - or an error occurred. - - If an error happens while the audio player is running, the exception is - caught and the audio player is then stopped. If no after callback is - passed, any caught exception will be displayed as if it were raised. - - Parameters - ---------- - source: :class:`AudioSource` - The audio source we're reading from. - after: Callable[[Optional[:class:`Exception`]], Any] - The finalizer that is called after the stream is exhausted. - This function must have a single parameter, ``error``, that - denotes an optional exception that was raised during playing. - wait_finish: bool - If True, an awaitable will be returned, which can be used to wait for - audio to stop playing. This awaitable will return an exception if raised, - or None when no exception is raised. - - If False, None is returned and the function does not block. - - .. versionadded:: v2.5 - - Raises - ------ - ClientException - Already playing audio or not connected. - TypeError - Source is not a :class:`AudioSource` or after is not a callable. - OpusNotLoaded - Source is not opus encoded and opus is not loaded. 
- """ - - if not self.is_connected(): - raise ClientException("Not connected to voice.") - - if self.is_playing(): - raise ClientException("Already playing audio.") - - if not isinstance(source, AudioSource): - raise TypeError( - f"source must be an AudioSource not {source.__class__.__name__}" - ) - - if not self.encoder and not source.is_opus(): - self.encoder = opus.Encoder() - - future = None - if wait_finish: - future = asyncio.Future() - after_callback = after - - def _after(exc: Exception | None): - if callable(after_callback): - after_callback(exc) - - future.set_result(exc) - - after = _after - - self._player = AudioPlayer(source, self, after=after) - self._player.start() - return future - - def unpack_audio(self, data): - """Takes an audio packet received from Discord and decodes it into pcm audio data. - If there are no users talking in the channel, `None` will be returned. - - You must be connected to receive audio. - - .. versionadded:: 2.0 - - Parameters - ---------- - data: :class:`bytes` - Bytes received by Discord via the UDP connection used for sending and receiving voice data. - """ - if data[1] != 0x78: - # We Should Ignore Any Payload Types We Do Not Understand - # Ref RFC 3550 5.1 payload type - # At Some Point We Noted That We Should Ignore Only Types 200 - 204 inclusive. - # They Were Marked As RTCP: Provides Information About The Connection - # This Was Too Broad Of A Whitelist, It Is Unclear If This Is Too Narrow Of A Whitelist - return - if self.paused: - return - - data = RawData(data, self) - - if data.decrypted_data == b"\xf8\xff\xfe": # Frame of silence - return - - self.decoder.decode(data) - - def start_recording(self, sink, callback, *args, sync_start: bool = False): - """The bot will begin recording audio from the current voice channel it is in. - This function uses a thread so the current code line will not be stopped. - Must be in a voice channel to use. - Must not be already recording. - - .. 
versionadded:: 2.0 - - Parameters - ---------- - sink: :class:`.Sink` - A Sink which will "store" all the audio data. - callback: :ref:`coroutine ` - A function which is called after the bot has stopped recording. - *args: - Args which will be passed to the callback function. - sync_start: :class:`bool` - If True, the recordings of subsequent users will start with silence. - This is useful for recording audio just as it was heard. - - Raises - ------ - RecordingException - Not connected to a voice channel. - RecordingException - Already recording. - RecordingException - Must provide a Sink object. - """ - if not self.is_connected(): - raise RecordingException("Not connected to voice channel.") - if self.recording: - raise RecordingException("Already recording.") - if not isinstance(sink, Sink): - raise RecordingException("Must provide a Sink object.") - - self.empty_socket() - - self.decoder = opus.DecodeManager(self) - self.decoder.start() - self.recording = True - self.sync_start = sync_start - self.sink = sink - sink.init(self) - - t = threading.Thread( - target=self.recv_audio, - args=( - sink, - callback, - *args, - ), - ) - t.start() - - def stop_recording(self): - """Stops the recording. - Must be already recording. - - .. versionadded:: 2.0 - - Raises - ------ - RecordingException - Not currently recording. - """ - if not self.recording: - raise RecordingException("Not currently recording audio.") - self.decoder.stop() - self.recording = False - self.paused = False - - def toggle_pause(self): - """Pauses or unpauses the recording. - Must be already recording. - - .. versionadded:: 2.0 - - Raises - ------ - RecordingException - Not currently recording. 
- """ - if not self.recording: - raise RecordingException("Not currently recording audio.") - self.paused = not self.paused - - def empty_socket(self): - while True: - ready, _, _ = select.select([self.socket], [], [], 0.0) - if not ready: - break - for s in ready: - s.recv(4096) - - def recv_audio(self, sink, callback, *args): - # Gets data from _recv_audio and sorts - # it by user, handles pcm files and - # silence that should be added. - - self.user_timestamps: dict[int, tuple[int, float]] = {} - self.starting_time = time.perf_counter() - self.first_packet_timestamp: float - while self.recording: - ready, _, err = select.select([self.socket], [], [self.socket], 0.01) - if not ready: - if err: - print(f"Socket error: {err}") - continue - - try: - data = self.socket.recv(4096) - except OSError: - self.stop_recording() - continue - - self.unpack_audio(data) - - self.stopping_time = time.perf_counter() - self.sink.cleanup() - callback = asyncio.run_coroutine_threadsafe(callback(sink, *args), self.loop) - result = callback.result() - - if result is not None: - print(result) - - def recv_decoded_audio(self, data: RawData): - # Add silence when they were not being recorded. 
- if data.ssrc not in self.user_timestamps: # First packet from user - if ( - not self.user_timestamps or not self.sync_start - ): # First packet from anyone - self.first_packet_timestamp = data.receive_time - silence = 0 - - else: # Previously received a packet from someone else - silence = ( - (data.receive_time - self.first_packet_timestamp) * 48000 - ) - 960 - - else: # Previously received a packet from user - dRT = ( - data.receive_time - self.user_timestamps[data.ssrc][1] - ) * 48000 # delta receive time - dT = data.timestamp - self.user_timestamps[data.ssrc][0] # delta timestamp - diff = abs(100 - dT * 100 / dRT) - if ( - diff > 60 and dT != 960 - ): # If the difference in change is more than 60% threshold - silence = dRT - 960 - else: - silence = dT - 960 - - self.user_timestamps.update({data.ssrc: (data.timestamp, data.receive_time)}) - - data.decoded_data = ( - struct.pack(" bool: - """Indicates if we're currently playing audio.""" - return self._player is not None and self._player.is_playing() - - def is_paused(self) -> bool: - """Indicates if we're playing audio, but if we're paused.""" - return self._player is not None and self._player.is_paused() - - def stop(self) -> None: - """Stops playing audio.""" - if self._player: - self._player.stop() - self._player = None - - def pause(self) -> None: - """Pauses the audio playing.""" - if self._player: - self._player.pause() - - def resume(self) -> None: - """Resumes the audio playing.""" - if self._player: - self._player.resume() - - @property - def source(self) -> AudioSource | None: - """The audio source being played, if playing. - - This property can also be used to change the audio source currently being played. 
- """ - return self._player.source if self._player else None - - @source.setter - def source(self, value: AudioSource) -> None: - if not isinstance(value, AudioSource): - raise TypeError(f"expected AudioSource not {value.__class__.__name__}.") - - if self._player is None: - raise ValueError("Not playing anything.") - - self._player._set_source(value) - - def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: - """Sends an audio packet composed of the data. - - You must be connected to play audio. - - Parameters - ---------- - data: :class:`bytes` - The :term:`py:bytes-like object` denoting PCM or Opus voice data. - encode: :class:`bool` - Indicates if ``data`` should be encoded into Opus. - - Raises - ------ - ClientException - You are not connected. - opus.OpusError - Encoding the data failed. - """ - - self.checked_add("sequence", 1, 65535) - if encode: - if not self.encoder: - self.encoder = opus.Encoder() - encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME) - else: - encoded_data = data - packet = self._get_voice_packet(encoded_data) - try: - self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) - except BlockingIOError: - _log.warning( - "A packet has been dropped (seq: %s, timestamp: %s)", - self.sequence, - self.timestamp, - ) - - self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295) - - def elapsed(self) -> datetime.timedelta: - """Returns the elapsed time of the playing audio.""" - if self._player: - return datetime.timedelta(milliseconds=self._player.played_frames() * 20) - return datetime.timedelta() From 870d135fe1532715cac089d88c4ed6c5e4c5b418 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 22:02:14 +0000 Subject: [PATCH 11/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/client.py | 5 ++--- discord/player.py | 12 +++++------ discord/utils.py | 45 
++++++++++++++++++++++----------------- discord/voice/client.py | 23 ++++++++++++-------- discord/voice/gateway.py | 34 ++++++++++++++--------------- discord/voice/recorder.py | 8 +++---- discord/voice/state.py | 6 +++++- 7 files changed, 72 insertions(+), 61 deletions(-) diff --git a/discord/client.py b/discord/client.py index 1c936099b8..2129c10578 100644 --- a/discord/client.py +++ b/discord/client.py @@ -27,7 +27,6 @@ import asyncio import logging -import signal import sys import traceback from types import TracebackType @@ -467,7 +466,7 @@ def _schedule_event( return task def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - logging.getLogger('discord.state').debug("Dispatching event %s", event) + logging.getLogger("discord.state").debug("Dispatching event %s", event) method = f"on_{event}" listeners = self._listeners.get(event) @@ -822,7 +821,7 @@ def run( is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ - loop = self.loop + self.loop async def runner(): async with self: diff --git a/discord/player.py b/discord/player.py index ae345be9b8..3bf74a0a8f 100644 --- a/discord/player.py +++ b/discord/player.py @@ -36,14 +36,14 @@ import sys import threading import time -import traceback from math import floor from typing import IO, TYPE_CHECKING, Any, Callable, Generic, TypeVar -from .errors import ClientException from .enums import SpeakingState +from .errors import ClientException from .oggparse import OggStream -from .opus import Encoder as OpusEncoder, OPUS_SILENCE +from .opus import OPUS_SILENCE +from .opus import Encoder as OpusEncoder from .utils import MISSING if TYPE_CHECKING: @@ -766,13 +766,13 @@ def _do_run(self) -> None: # are we disconnected from voice? 
if not client.is_connected(): - _log.debug('Not connected, waiting for %ss...', client.timeout) + _log.debug("Not connected, waiting for %ss...", client.timeout) # wait until we are connected, but not forever connected = client.wait_until_connected(client.timeout) if self._end.is_set() or not connected: - _log.debug('Aborting playback') + _log.debug("Aborting playback") return - _log.debug('Reconnected, resuming playback') + _log.debug("Reconnected, resuming playback") self._speak(SpeakingState.voice) # reset our internal data self.loops = 0 diff --git a/discord/utils.py b/discord/utils.py index bab51d3687..6c8b4bfaea 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -1413,24 +1413,26 @@ def filter_params(params, **kwargs): def is_docker() -> bool: - path = '/proc/self/cgroup' - return os.path.exists('/.dockerenv') or (os.path.isfile(path) and any('docker' in line for line in open(path))) + path = "/proc/self/cgroup" + return os.path.exists("/.dockerenv") or ( + os.path.isfile(path) and any("docker" in line for line in open(path)) + ) def stream_supports_colour(stream: Any) -> bool: - is_a_tty = hasattr(stream, 'isatty') and stream.isatty() + is_a_tty = hasattr(stream, "isatty") and stream.isatty() # Pycharm and Vscode support colour in their inbuilt editors - if 'PYCHARM_HOSTED' in os.environ or os.environ.get('TERM_PROGRAM') == 'vscode': + if "PYCHARM_HOSTED" in os.environ or os.environ.get("TERM_PROGRAM") == "vscode": return is_a_tty - if sys.platform != 'win32': + if sys.platform != "win32": # Docker does not consistently have a tty attached to it return is_a_tty or is_docker() # ANSICON checks for things like ConEmu # WT_SESSION checks if this is Windows Terminal - return is_a_tty and ('ANSICON' in os.environ or 'WT_SESSION' in os.environ) + return is_a_tty and ("ANSICON" in os.environ or "WT_SESSION" in os.environ) class _ColourFormatter(logging.Formatter): @@ -1444,17 +1446,17 @@ class _ColourFormatter(logging.Formatter): # 1 means bold, 2 means dim, 0 
means reset, and 4 means underline. LEVEL_COLOURS = [ - (logging.DEBUG, '\x1b[40;1m'), - (logging.INFO, '\x1b[34;1m'), - (logging.WARNING, '\x1b[33;1m'), - (logging.ERROR, '\x1b[31m'), - (logging.CRITICAL, '\x1b[41m'), + (logging.DEBUG, "\x1b[40;1m"), + (logging.INFO, "\x1b[34;1m"), + (logging.WARNING, "\x1b[33;1m"), + (logging.ERROR, "\x1b[31m"), + (logging.CRITICAL, "\x1b[41m"), ] FORMATS = { level: logging.Formatter( - f'\x1b[30;1m%(asctime)s\x1b[0m {colour}%(levelname)-8s\x1b[0m \x1b[35m%(name)s\x1b[0m %(message)s', - '%Y-%m-%d %H:%M:%S', + f"\x1b[30;1m%(asctime)s\x1b[0m {colour}%(levelname)-8s\x1b[0m \x1b[35m%(name)s\x1b[0m %(message)s", + "%Y-%m-%d %H:%M:%S", ) for level, colour in LEVEL_COLOURS } @@ -1467,7 +1469,7 @@ def format(self, record): # Override the traceback to always print in red if record.exc_info: text = formatter.formatException(record.exc_info) - record.exc_text = f'\x1b[31m{text}\x1b[0m' + record.exc_text = f"\x1b[31m{text}\x1b[0m" output = formatter.format(record) @@ -1483,8 +1485,7 @@ def setup_logging( level: int = MISSING, root: bool = True, ) -> None: - """A helper method to automatically setup the library's default logging. 
- """ + """A helper method to automatically setup the library's default logging.""" if level is MISSING: level = logging.INFO @@ -1493,16 +1494,20 @@ def setup_logging( handler = logging.StreamHandler() if formatter is MISSING: - if isinstance(handler, logging.StreamHandler) and stream_supports_colour(handler.stream): + if isinstance(handler, logging.StreamHandler) and stream_supports_colour( + handler.stream + ): formatter = _ColourFormatter() else: - dt_fmt = '%Y-%m-%d %H:%M:%S' - formatter = logging.Formatter('[{asctime}] [{levelname:<8}] {name}: {message}', dt_fmt, style='{') + dt_fmt = "%Y-%m-%d %H:%M:%S" + formatter = logging.Formatter( + "[{asctime}] [{levelname:<8}] {name}: {message}", dt_fmt, style="{" + ) if root: logger = logging.getLogger() else: - lib, _, _ = __name__.partition('.') + lib, _, _ = __name__.partition(".") logger = logging.getLogger(lib) handler.setFormatter(formatter) diff --git a/discord/voice/client.py b/discord/voice/client.py index 9581d11687..c52abaae13 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -34,11 +34,12 @@ from discord import opus from discord.errors import ClientException +from discord.player import AudioPlayer, AudioSource from discord.utils import MISSING -from discord.player import AudioSource, AudioPlayer from ._types import VoiceProtocol -#from .recorder import VoiceRecorderClient + +# from .recorder import VoiceRecorderClient from .state import VoiceConnectionState if TYPE_CHECKING: @@ -134,8 +135,8 @@ def __init__( self._connection: VoiceConnectionState = self.create_connection_state() # voice recv things - #self._recorder: VoiceRecorderClient | None = None - #if use_recorder: + # self._recorder: VoiceRecorderClient | None = None + # if use_recorder: # self._recorder = VoiceRecorderClient(self) warn_nacl: bool = not has_nacl @@ -551,10 +552,10 @@ def source(self) -> AudioSource | None: @source.setter def source(self, value: AudioSource) -> None: if not isinstance(value, AudioSource): - raise 
TypeError(f'expected AudioSource, not {value.__class__.__name__}') + raise TypeError(f"expected AudioSource, not {value.__class__.__name__}") if self._player is None: - raise ValueError('the client is not playing anything') + raise ValueError("the client is not playing anything") self._player._set_source(value) @@ -578,7 +579,7 @@ def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: Encoding the data failed. """ - self.checked_add('sequence', 1, 65535) + self.checked_add("sequence", 1, 65535) if encode: encoded = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME) else: @@ -588,9 +589,13 @@ def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: try: self._connection.send_packet(packet) except OSError: - _log.debug('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) + _log.debug( + "A packet has been dropped (seq: %s, timestamp: %s)", + self.sequence, + self.timestamp, + ) - self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295) + self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295) def elapsed(self) -> datetime.timedelta: """Returns the elapsed time of the playing audio.""" diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index 30ceebeced..a6958722a7 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -205,29 +205,29 @@ async def ready(self, data: dict[str, Any]) -> None: ) _log.debug( - 'Connected socket to %s (port %s)', + "Connected socket to %s (port %s)", state.endpoint_ip, state.voice_port, ) state.ip, state.port = await self.get_ip() - modes = [mode for mode in data['modes'] if mode in self.state.supported_modes] - _log.debug('Received available voice connection modes: %s', modes) + modes = [mode for mode in data["modes"] if mode in self.state.supported_modes] + _log.debug("Received available voice connection modes: %s", modes) mode = modes[0] await self.select_protocol(state.ip, state.port, mode) - 
_log.debug('Selected voice protocol %s for this connection', mode) + _log.debug("Selected voice protocol %s for this connection", mode) async def select_protocol(self, ip: str, port: int, mode: str) -> None: payload = { - 'op': int(OpCodes.select_protocol), - 'd': { - 'protocol': 'udp', - 'data': { - 'address': ip, - 'port': port, - 'mode': mode, + "op": int(OpCodes.select_protocol), + "d": { + "protocol": "udp", + "data": { + "address": ip, + "port": port, + "mode": mode, }, }, } @@ -346,12 +346,12 @@ async def from_state( async def identify(self) -> None: state = self.state payload = { - 'op': int(OpCodes.identify), - 'd': { - 'server_id': str(state.server_id), - 'user_id': str(state.user.id), - 'session_id': self.session_id, - 'token': self.token, + "op": int(OpCodes.identify), + "d": { + "server_id": str(state.server_id), + "user_id": str(state.user.id), + "session_id": self.session_id, + "token": self.token, }, } await self.send_as_json(payload) diff --git a/discord/voice/recorder.py b/discord/voice/recorder.py index 0838f0979e..5d59e0a953 100644 --- a/discord/voice/recorder.py +++ b/discord/voice/recorder.py @@ -39,7 +39,7 @@ from .client import VoiceClient from .gateway import VoiceWebSocket - VoiceClientT = TypeVar('VoiceClientT', bound=VoiceClient, covariant=True) + VoiceClientT = TypeVar("VoiceClientT", bound=VoiceClient, covariant=True) class VoiceRecorderClient(VoiceRecorderProtocol[VoiceClientT]): @@ -67,11 +67,9 @@ def is_recording(self) -> bool: """Whether the current recording is actively recording.""" return self._recording.is_set() - async def hook(self, ws: VoiceWebSocket, data: dict[str, Any]) -> None: - ... + async def hook(self, ws: VoiceWebSocket, data: dict[str, Any]) -> None: ... def record( self, sink: Sink, - ) -> int: - ... + ) -> int: ... 
diff --git a/discord/voice/state.py b/discord/voice/state.py index f1b797ab34..6884e47eba 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -479,7 +479,11 @@ async def disconnect( if not force and not self.is_connected(): return - _log.debug('Attempting a voice disconnect for channel %s (guild %s)', self.channel_id, self.guild_id) + _log.debug( + "Attempting a voice disconnect for channel %s (guild %s)", + self.channel_id, + self.guild_id, + ) try: await self._voice_disconnect() if self.ws: From 73d5c6ddc9624d11e274e7bc0b9790ed2c209f8b Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Wed, 27 Aug 2025 00:03:29 +0200 Subject: [PATCH 12/40] actually i did this for logging --- discord/client.py | 14 ------- discord/utils.py | 98 ----------------------------------------------- 2 files changed, 112 deletions(-) diff --git a/discord/client.py b/discord/client.py index 1c936099b8..f9716ee247 100644 --- a/discord/client.py +++ b/discord/client.py @@ -794,10 +794,6 @@ def run( token: str, *, reconnect: bool = True, - log_handler: logging.Handler | None = MISSING, - log_formatter: logging.Formatter = MISSING, - log_level: int = MISSING, - root_logger: bool = False, ) -> None: """A blocking call that abstracts away the event loop initialisation from you. @@ -822,20 +818,10 @@ def run( is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. 
""" - loop = self.loop async def runner(): async with self: await self.start(token, reconnect=reconnect) - - if log_handler is not None: - utils.setup_logging( - handler=log_handler, - formatter=log_formatter, - level=log_level, - root=root_logger, - ) - try: asyncio.run(runner()) except KeyboardInterrupt: diff --git a/discord/utils.py b/discord/utils.py index bab51d3687..5c0f17b613 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -1410,101 +1410,3 @@ def filter_params(params, **kwargs): params[new_param] = params.pop(old_param) return params - - -def is_docker() -> bool: - path = '/proc/self/cgroup' - return os.path.exists('/.dockerenv') or (os.path.isfile(path) and any('docker' in line for line in open(path))) - - -def stream_supports_colour(stream: Any) -> bool: - is_a_tty = hasattr(stream, 'isatty') and stream.isatty() - - # Pycharm and Vscode support colour in their inbuilt editors - if 'PYCHARM_HOSTED' in os.environ or os.environ.get('TERM_PROGRAM') == 'vscode': - return is_a_tty - - if sys.platform != 'win32': - # Docker does not consistently have a tty attached to it - return is_a_tty or is_docker() - - # ANSICON checks for things like ConEmu - # WT_SESSION checks if this is Windows Terminal - return is_a_tty and ('ANSICON' in os.environ or 'WT_SESSION' in os.environ) - - -class _ColourFormatter(logging.Formatter): - # ANSI codes are a bit weird to decipher if you're unfamiliar with them, so here's a refresher - # It starts off with a format like \x1b[XXXm where XXX is a semicolon separated list of commands - # The important ones here relate to colour. - # 30-37 are black, red, green, yellow, blue, magenta, cyan and white in that order - # 40-47 are the same except for the background - # 90-97 are the same but "bright" foreground - # 100-107 are the same as the bright ones but for the background. - # 1 means bold, 2 means dim, 0 means reset, and 4 means underline. 
- - LEVEL_COLOURS = [ - (logging.DEBUG, '\x1b[40;1m'), - (logging.INFO, '\x1b[34;1m'), - (logging.WARNING, '\x1b[33;1m'), - (logging.ERROR, '\x1b[31m'), - (logging.CRITICAL, '\x1b[41m'), - ] - - FORMATS = { - level: logging.Formatter( - f'\x1b[30;1m%(asctime)s\x1b[0m {colour}%(levelname)-8s\x1b[0m \x1b[35m%(name)s\x1b[0m %(message)s', - '%Y-%m-%d %H:%M:%S', - ) - for level, colour in LEVEL_COLOURS - } - - def format(self, record): - formatter = self.FORMATS.get(record.levelno) - if formatter is None: - formatter = self.FORMATS[logging.DEBUG] - - # Override the traceback to always print in red - if record.exc_info: - text = formatter.formatException(record.exc_info) - record.exc_text = f'\x1b[31m{text}\x1b[0m' - - output = formatter.format(record) - - # Remove the cache layer - record.exc_text = None - return output - - -def setup_logging( - *, - handler: logging.Handler = MISSING, - formatter: logging.Formatter = MISSING, - level: int = MISSING, - root: bool = True, -) -> None: - """A helper method to automatically setup the library's default logging. 
- """ - - if level is MISSING: - level = logging.INFO - - if handler is MISSING: - handler = logging.StreamHandler() - - if formatter is MISSING: - if isinstance(handler, logging.StreamHandler) and stream_supports_colour(handler.stream): - formatter = _ColourFormatter() - else: - dt_fmt = '%Y-%m-%d %H:%M:%S' - formatter = logging.Formatter('[{asctime}] [{levelname:<8}] {name}: {message}', dt_fmt, style='{') - - if root: - logger = logging.getLogger() - else: - lib, _, _ = __name__.partition('.') - logger = logging.getLogger(lib) - - handler.setFormatter(formatter) - logger.setLevel(level) - logger.addHandler(handler) From 8f891a320c21636e403ec44727836cbc621e0682 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 22:04:36 +0000 Subject: [PATCH 13/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/client.py | 1 + discord/utils.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/discord/client.py b/discord/client.py index 626cd31b16..1859ff2775 100644 --- a/discord/client.py +++ b/discord/client.py @@ -821,6 +821,7 @@ def run( async def runner(): async with self: await self.start(token, reconnect=reconnect) + try: asyncio.run(runner()) except KeyboardInterrupt: diff --git a/discord/utils.py b/discord/utils.py index 5c0f17b613..e7b4ffde0d 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -33,8 +33,6 @@ import importlib.resources import itertools import json -import logging -import os import re import sys import types From eaa3d77237c0f1eac3a0c31ae5bcdc0d2b641418 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Wed, 27 Aug 2025 00:12:49 +0200 Subject: [PATCH 14/40] use utils._get_as_snowflake and remove old code --- discord/raw_models.py | 117 ++++++++++-------------------------------- 1 file changed, 27 insertions(+), 90 deletions(-) diff --git a/discord/raw_models.py b/discord/raw_models.py index 
581ae7052e..0333c313b7 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -127,10 +127,7 @@ def __init__(self, data: MessageDeleteEvent) -> None: self.message_id: int = int(data["id"]) self.channel_id: int = int(data["channel_id"]) self.cached_message: Message | None = None - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') self.data: MessageDeleteEvent = data @@ -159,11 +156,7 @@ def __init__(self, data: BulkMessageDeleteEvent) -> None: self.message_ids: set[int] = {int(x) for x in data.get("ids", [])} self.channel_id: int = int(data["channel_id"]) self.cached_messages: list[Message] = [] - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') self.data: BulkMessageDeleteEvent = data @@ -197,11 +190,7 @@ def __init__(self, data: MessageUpdateEvent) -> None: self.channel_id: int = int(data["channel_id"]) self.data: MessageUpdateEvent = data self.cached_message: Message | None = None - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') class RawReactionActionEvent(_RawReprMixin): @@ -274,11 +263,7 @@ def __init__( self.burst_colours: list = data.get("burst_colors", []) self.burst_colors: list = self.burst_colours self.type: ReactionType = try_enum(ReactionType, data.get("type", 0)) - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') self.data: ReactionActionEvent = data @@ -304,11 +289,7 @@ class RawReactionClearEvent(_RawReprMixin): def __init__(self, data: ReactionClearEvent) -> None: self.message_id: int = 
int(data["message_id"]) self.channel_id: int = int(data["channel_id"]) - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') self.data: ReactionClearEvent = data @@ -361,11 +342,7 @@ def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None: self.burst_colours: list = data.get("burst_colors", []) self.burst_colors: list = self.burst_colours self.type: ReactionType = try_enum(ReactionType, data.get("type", 0)) - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') self.data: ReactionClearEmojiEvent = data @@ -393,11 +370,7 @@ class RawIntegrationDeleteEvent(_RawReprMixin): def __init__(self, data: IntegrationDeleteEvent) -> None: self.integration_id: int = int(data["id"]) self.guild_id: int = int(data["guild_id"]) - - try: - self.application_id: int | None = int(data["application_id"]) - except KeyError: - self.application_id: int | None = None + self.application_id: int | None = utils._get_as_snowflake(data, 'application_id') self.data: IntegrationDeleteEvent = data @@ -460,10 +433,10 @@ class RawThreadDeleteEvent(_RawReprMixin): __slots__ = ("thread_id", "thread_type", "guild_id", "parent_id", "thread", "data") def __init__(self, data: ThreadDeleteEvent) -> None: - self.thread_id: int = int(data["id"]) - self.thread_type: ChannelType = try_enum(ChannelType, int(data["type"])) - self.guild_id: int = int(data["guild_id"]) - self.parent_id: int = int(data["parent_id"]) + self.thread_id: int = int(data["id"]) # type: ignore + self.thread_type: ChannelType = try_enum(ChannelType, int(data["type"])) # type: ignore + self.guild_id: int = int(data["guild_id"]) # type: ignore + self.parent_id: int = int(data["parent_id"]) # type: ignore self.thread: Thread | None = None self.data: 
ThreadDeleteEvent = data @@ -490,11 +463,7 @@ class RawVoiceChannelStatusUpdateEvent(_RawReprMixin): def __init__(self, data: VoiceChannelStatusUpdateEvent) -> None: self.id: int = int(data["id"]) self.guild_id: int = int(data["guild_id"]) - - try: - self.status: str | None = data["status"] - except KeyError: - self.status: str | None = None + self.status: str | None = data.get('status') self.data: VoiceChannelStatusUpdateEvent = data @@ -530,11 +499,7 @@ def __init__(self, data: TypingEvent) -> None: data.get("timestamp"), tz=datetime.timezone.utc ) self.member: Member | None = None - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') self.data: TypingEvent = data @@ -589,8 +554,8 @@ class RawScheduledEventSubscription(_RawReprMixin): __slots__ = ("event_id", "guild", "user_id", "event_type", "data") def __init__(self, data: ScheduledEventSubscription, event_type: str): - self.event_id: int = int(data["guild_scheduled_event_id"]) - self.user_id: int = int(data["user_id"]) + self.event_id: int = int(data["guild_scheduled_event_id"]) # type: ignore + self.user_id: int = int(data["user_id"]) # type: ignore self.guild: Guild | None = None self.event_type: str = event_type self.data: ScheduledEventSubscription = data @@ -676,42 +641,18 @@ def __init__(self, state: ConnectionState, data: AutoModActionExecution) -> None self.guild: Guild | None = state._get_guild(self.guild_id) self.user_id: int = int(data["user_id"]) self.content: str | None = data.get("content", None) - self.matched_keyword: str = data["matched_keyword"] + self.matched_keyword: str = data["matched_keyword"] # type: ignore self.matched_content: str | None = data.get("matched_content", None) - - if self.guild: - self.member: Member | None = self.guild.get_member(self.user_id) - else: - self.member: Member | None = None - - try: - # I don't see why this would be optional, 
but it's documented - # as such, so we should treat it that way - self.channel_id: int | None = int(data["channel_id"]) - self.channel: MessageableChannel | None = self.guild.get_channel_or_thread( - self.channel_id - ) - except KeyError: - self.channel_id: int | None = None - self.channel: MessageableChannel | None = None - - try: - self.message_id: int | None = int(data["message_id"]) - self.message: Message | None = state._get_message(self.message_id) - except KeyError: - self.message_id: int | None = None - self.message: Message | None = None - - try: - self.alert_system_message_id: int | None = int( - data["alert_system_message_id"] - ) - self.alert_system_message: Message | None = state._get_message( - self.alert_system_message_id - ) - except KeyError: - self.alert_system_message_id: int | None = None - self.alert_system_message: Message | None = None + self.channel_id: int | None = utils._get_as_snowflake(data, 'channel_id') + self.channel: MessageableChannel | None = ( + self.channel_id and self.guild + and self.guild.get_channel_or_thread(self.channel_id) + ) # type: ignore + self.member: Member | None = self.guild and self.guild.get_member(self.user_id) + self.message_id: int | None = utils._get_as_snowflake(data, 'message_id') + self.message: Message | None = state._get_message(self.message_id) + self.alert_system_message_id: int | None = utils._get_as_snowflake(data, 'alert_system_message_id') + self.alert_system_message: Message | None = state._get_message(self.alert_system_message_id) self.data: AutoModActionExecution = data def __repr__(self) -> str: @@ -848,11 +789,7 @@ def __init__(self, data: MessagePollVoteEvent, added: bool) -> None: self.answer_id: int = int(data["answer_id"]) self.data: MessagePollVoteEvent = data self.added: bool = added - - try: - self.guild_id: int | None = int(data["guild_id"]) - except KeyError: - self.guild_id: int | None = None + self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') # this is for 
backwards compatibility because VoiceProtocol.on_voice_..._update From def162d19977ecf9e6791b610727bff344aa6554 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 22:13:34 +0000 Subject: [PATCH 15/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/raw_models.py | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/discord/raw_models.py b/discord/raw_models.py index 0333c313b7..da2f9d8085 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -127,7 +127,7 @@ def __init__(self, data: MessageDeleteEvent) -> None: self.message_id: int = int(data["id"]) self.channel_id: int = int(data["channel_id"]) self.cached_message: Message | None = None - self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: MessageDeleteEvent = data @@ -156,7 +156,7 @@ def __init__(self, data: BulkMessageDeleteEvent) -> None: self.message_ids: set[int] = {int(x) for x in data.get("ids", [])} self.channel_id: int = int(data["channel_id"]) self.cached_messages: list[Message] = [] - self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: BulkMessageDeleteEvent = data @@ -190,7 +190,7 @@ def __init__(self, data: MessageUpdateEvent) -> None: self.channel_id: int = int(data["channel_id"]) self.data: MessageUpdateEvent = data self.cached_message: Message | None = None - self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") class RawReactionActionEvent(_RawReprMixin): @@ -263,7 +263,7 @@ def __init__( self.burst_colours: list = data.get("burst_colors", []) self.burst_colors: list = self.burst_colours self.type: ReactionType = try_enum(ReactionType, 
data.get("type", 0)) - self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: ReactionActionEvent = data @@ -289,7 +289,7 @@ class RawReactionClearEvent(_RawReprMixin): def __init__(self, data: ReactionClearEvent) -> None: self.message_id: int = int(data["message_id"]) self.channel_id: int = int(data["channel_id"]) - self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: ReactionClearEvent = data @@ -342,7 +342,7 @@ def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None: self.burst_colours: list = data.get("burst_colors", []) self.burst_colors: list = self.burst_colours self.type: ReactionType = try_enum(ReactionType, data.get("type", 0)) - self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: ReactionClearEmojiEvent = data @@ -370,7 +370,9 @@ class RawIntegrationDeleteEvent(_RawReprMixin): def __init__(self, data: IntegrationDeleteEvent) -> None: self.integration_id: int = int(data["id"]) self.guild_id: int = int(data["guild_id"]) - self.application_id: int | None = utils._get_as_snowflake(data, 'application_id') + self.application_id: int | None = utils._get_as_snowflake( + data, "application_id" + ) self.data: IntegrationDeleteEvent = data @@ -463,7 +465,7 @@ class RawVoiceChannelStatusUpdateEvent(_RawReprMixin): def __init__(self, data: VoiceChannelStatusUpdateEvent) -> None: self.id: int = int(data["id"]) self.guild_id: int = int(data["guild_id"]) - self.status: str | None = data.get('status') + self.status: str | None = data.get("status") self.data: VoiceChannelStatusUpdateEvent = data @@ -499,7 +501,7 @@ def __init__(self, data: TypingEvent) -> None: data.get("timestamp"), tz=datetime.timezone.utc ) self.member: Member | None = None - 
self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") self.data: TypingEvent = data @@ -555,7 +557,7 @@ class RawScheduledEventSubscription(_RawReprMixin): def __init__(self, data: ScheduledEventSubscription, event_type: str): self.event_id: int = int(data["guild_scheduled_event_id"]) # type: ignore - self.user_id: int = int(data["user_id"]) # type: ignore + self.user_id: int = int(data["user_id"]) # type: ignore self.guild: Guild | None = None self.event_type: str = event_type self.data: ScheduledEventSubscription = data @@ -643,16 +645,21 @@ def __init__(self, state: ConnectionState, data: AutoModActionExecution) -> None self.content: str | None = data.get("content", None) self.matched_keyword: str = data["matched_keyword"] # type: ignore self.matched_content: str | None = data.get("matched_content", None) - self.channel_id: int | None = utils._get_as_snowflake(data, 'channel_id') + self.channel_id: int | None = utils._get_as_snowflake(data, "channel_id") self.channel: MessageableChannel | None = ( - self.channel_id and self.guild + self.channel_id + and self.guild and self.guild.get_channel_or_thread(self.channel_id) ) # type: ignore self.member: Member | None = self.guild and self.guild.get_member(self.user_id) - self.message_id: int | None = utils._get_as_snowflake(data, 'message_id') + self.message_id: int | None = utils._get_as_snowflake(data, "message_id") self.message: Message | None = state._get_message(self.message_id) - self.alert_system_message_id: int | None = utils._get_as_snowflake(data, 'alert_system_message_id') - self.alert_system_message: Message | None = state._get_message(self.alert_system_message_id) + self.alert_system_message_id: int | None = utils._get_as_snowflake( + data, "alert_system_message_id" + ) + self.alert_system_message: Message | None = state._get_message( + self.alert_system_message_id + ) self.data: AutoModActionExecution = data def 
__repr__(self) -> str: @@ -789,7 +796,7 @@ def __init__(self, data: MessagePollVoteEvent, added: bool) -> None: self.answer_id: int = int(data["answer_id"]) self.data: MessagePollVoteEvent = data self.added: bool = added - self.guild_id: int | None = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: int | None = utils._get_as_snowflake(data, "guild_id") # this is for backwards compatibility because VoiceProtocol.on_voice_..._update From c81adf514a73b35832f1c0ac6467c5f91b95efb3 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Wed, 27 Aug 2025 00:15:15 +0200 Subject: [PATCH 16/40] remove fmt:off references --- discord/voice/enums.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/discord/voice/enums.py b/discord/voice/enums.py index 7995453a0b..ca61852156 100644 --- a/discord/voice/enums.py +++ b/discord/voice/enums.py @@ -29,7 +29,6 @@ class OpCodes(Enum): - # fmt: off identify = 0 select_protocol = 1 ready = 2 @@ -42,7 +41,6 @@ class OpCodes(Enum): resumed = 9 client_connect = 10 client_disconnect = 11 - # fmt: on def __eq__(self, other: object) -> bool: if isinstance(other, int): @@ -56,7 +54,6 @@ def __int__(self) -> int: class ConnectionFlowState(Enum): - # fmt: off disconnected = 0 set_guild_voice_state = 1 got_voice_state_update = 2 @@ -66,4 +63,3 @@ class ConnectionFlowState(Enum): got_websocket_ready = 6 got_ip_discovery = 7 connected = 8 - # fmt: on From 88175b65440845fb8c8ab8358d30a295fb84f255 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 22:15:51 +0000 Subject: [PATCH 17/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/voice/enums.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/discord/voice/enums.py b/discord/voice/enums.py index ca61852156..2c164e2f7a 100644 --- a/discord/voice/enums.py +++ b/discord/voice/enums.py @@ -29,18 +29,18 @@ 
class OpCodes(Enum): - identify = 0 - select_protocol = 1 - ready = 2 - heartbeat = 3 + identify = 0 + select_protocol = 1 + ready = 2 + heartbeat = 3 session_description = 4 - speaking = 5 - heartbeat_ack = 6 - resume = 7 - hello = 8 - resumed = 9 - client_connect = 10 - client_disconnect = 11 + speaking = 5 + heartbeat_ack = 6 + resume = 7 + hello = 8 + resumed = 9 + client_connect = 10 + client_disconnect = 11 def __eq__(self, other: object) -> bool: if isinstance(other, int): @@ -54,12 +54,12 @@ def __int__(self) -> int: class ConnectionFlowState(Enum): - disconnected = 0 - set_guild_voice_state = 1 - got_voice_state_update = 2 + disconnected = 0 + set_guild_voice_state = 1 + got_voice_state_update = 2 got_voice_server_update = 3 - got_both_voice_updates = 4 - websocket_connected = 5 - got_websocket_ready = 6 - got_ip_discovery = 7 - connected = 8 + got_both_voice_updates = 4 + websocket_connected = 5 + got_websocket_ready = 6 + got_ip_discovery = 7 + connected = 8 From 753e11b9dea98b14ad9491f90ab807fb8bd97cb4 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Wed, 27 Aug 2025 00:17:44 +0200 Subject: [PATCH 18/40] add voice to the packages --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b9fd493eba..6204f05b7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ packages = [ "discord.ext.pages", "discord.ext.bridge", "discord.bin", + "discord.voice", ] [tool.setuptools.dynamic] From 51858b91897f2e6132126a1a8471d82994558c5d Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Wed, 27 Aug 2025 00:18:46 +0200 Subject: [PATCH 19/40] keep GuildVoiceState alias to VoiceState --- discord/types/voice.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/discord/types/voice.py b/discord/types/voice.py index 3840782a2f..307f98cee3 100644 --- a/discord/types/voice.py +++ b/discord/types/voice.py @@ -56,6 +56,9 @@ class 
VoiceState(TypedDict): guild_id: NotRequired[Snowflake] +GuildVoiceState = VoiceState + + class VoiceRegion(TypedDict): id: str name: str From b414b5f890bd16ff4a8f165b34b8627e06acf1dc Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Wed, 27 Aug 2025 00:21:44 +0200 Subject: [PATCH 20/40] fix type annotation --- discord/opus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/opus.py b/discord/opus.py index 6c8720d49e..5f5efee8dd 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -385,7 +385,7 @@ def __init__( fec: bool = True, expected_packet_loss: float = 0.15, bandwidth: BAND_CTL = "full", - signal_type: SIGNAL_TL = "auto", + signal_type: SIGNAL_CTL = "auto", ) -> None: if application not in application_ctl: raise ValueError("invalid application ctl type provided") From 9030478fdbe9f5b887befd752b246b47f3c0ec54 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Thu, 28 Aug 2025 00:53:16 +0200 Subject: [PATCH 21/40] voice recv things --- discord/sinks/core.py | 304 +++++++++++++++++++++++++++++++++++++- discord/sinks/enums.py | 32 ++++ discord/voice/_types.py | 106 +------------ discord/voice/client.py | 8 +- discord/voice/flags.py | 88 +++++++++++ discord/voice/gateway.py | 7 +- discord/voice/recorder.py | 75 ---------- discord/voice/state.py | 132 ++++++++++++++++- 8 files changed, 555 insertions(+), 197 deletions(-) create mode 100644 discord/sinks/enums.py create mode 100644 discord/voice/flags.py delete mode 100644 discord/voice/recorder.py diff --git a/discord/sinks/core.py b/discord/sinks/core.py index b67701cb67..d82bc374ce 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -25,20 +25,32 @@ from __future__ import annotations +import asyncio +from collections.abc import Callable, Coroutine +from functools import partial import io import os import struct import sys import threading import time -from typing import TYPE_CHECKING +from typing 
import TYPE_CHECKING, Any, TypeVar, overload -from ..types import snowflake +from discord.utils import MISSING + +from .enums import FilteringMode from .errors import SinkException if TYPE_CHECKING: + from typing_extensions import ParamSpec + + from discord import abc + from discord.types import snowflake from ..voice.client import VoiceClient + R = TypeVar('R') + P = ParamSpec('P') + __all__ = ( "Filters", "Sink", @@ -60,6 +72,90 @@ } +class Filter: + """Represents a filter for a :class:`~.Sink`. + + This has to be inherited in order to provide a filter to a sink. + + .. versionadded:: 2.7 + """ + + @overload + async def filter(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> bool: ... + + @overload + def filter(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> bool: ... + + def filter(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> bool | Coroutine[Any, Any, bool]: + """|maybecoro| + + Represents the filter callback. + + This is called automatically everytime a voice packet is received to check whether it should be stored in + ``sink``. + + Parameters + ---------- + sink: :class:`~.Sink` + The sink the packet was received from, if the filter check goes through. + user: :class:`~discord.abc.Snowflake` + The user that the packet was received from. + ssrc: :class:`int` + The user's ssrc. + packet: :class:`~.RawData` + The raw data packet. + + Returns + ------- + :class:`bool` + Whether the filter was successful. + """ + raise NotImplementedError('subclasses must implement this') + + def cleanup(self) -> None: + """A function called when the filter is ready for cleanup.""" + pass + + +class Handler: + """Represents a handler for a :class:`~.Sink`. + + This has to be inherited in order to provide a handler to a sink. + + .. versionadded:: 2.7 + """ + + @overload + async def handle(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> Any: ... 
+ + @overload + def handle(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> Any: ... + + def handle(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> Any | Coroutine[Any, Any, Any]: + """|maybecoro| + + Represents the handler callback. + + This is called automatically everytime a voice packet which has successfully passed the filters is received. + + Parameters + ---------- + sink: :class:`~.Sink` + The sink the packet was received from, if the filter check goes through. + user: :class:`~discord.abc.Snowflake` + The user that the packet is from. + ssrc: :class:`int` + The user's ssrc. + packet: :class:`~.RawData` + The raw data packet. + """ + raise NotImplementedError('subclasses must implement this') + + def cleanup(self) -> None: + """A function called when the handler is ready for cleanup.""" + pass + + class Filters: """Filters for :class:`~.Sink` @@ -249,3 +345,207 @@ def get_all_audio(self): def get_user_audio(self, user: snowflake.Snowflake): """Gets the audio file(s) of one specific user.""" return os.path.realpath(self.audio_data.pop(user)) + + +class Sink: + """Represents a sink for voice recording. + + This is used as a way of "storing" the recordings. + + This class is abstracted, and must be subclassed in order to apply functionalities to + it. + + Parameters + ---------- + filters: List[:class:`~.Filter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.FilteringMode` + How the filters should work. If set to :attr:`~.FilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.FilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.Handler`] + The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. 
+ """ + + __listeners__: dict[str, list[Callable[..., Any]]] = {} + + def __init_subclass__(cls) -> None: + listeners: dict[str, list[Callable[..., Any]]] = {} + + for base in reversed(cls.__mro__): + for elem, value in base.__dict__.items(): + if elem in listeners: + del listeners[elem] + + if isinstance(value, staticmethod): + value = value.__func__ + elif isinstance(value, classmethod): + value = partial(value.__func__, cls) + + if not hasattr(value, '__listener__'): + continue + + event_name = getattr(value, '__listener_name__', elem).removeprefix('on_') + + try: + listeners[event_name].append(value) + except KeyError: + listeners[event_name] = [value] + + cls.__listeners__ = listeners + + def __init__( + self, + *, + filters: list[Filter] = MISSING, + filtering_mode: FilteringMode = FilteringMode.all, + handlers: list[Handler] = MISSING, + ) -> None: + self.filtering_mode: FilteringMode = filtering_mode + self._filters: list[Filter] = filters or [] + self._handlers: list[Handler] = handlers or [] + self.__dispatch_set: set[asyncio.Task[Any]] = set() + + def dispatch(self, event: str, *args: Any, **kwargs: Any) -> Any: + event = event.removeprefix('on_') + + listeners = self.__listeners__.get(event, []) + + for listener in listeners: + task = asyncio.create_task( + listener(*args, **kwargs), + name=f'dispatch-{event}:{id(listener):#x}', + ) + self.__dispatch_set.add(task) + task.add_done_callback(self.__dispatch_set.remove) + + def cleanup(self) -> None: + """Cleans all the data in this sink. + + This should be called when you won't be performing any more operations in this sink. + """ + + for task in list(self.__dispatch_set): + if task.done(): + continue + task.set_result(None) + + for filter in self._filters: + filter.cleanup() + + for handler in self._handlers: + handler.cleanup() + + def __del__(self) -> None: + self.cleanup() + + def add_filter(self, filter: Filter, /) -> None: + """Adds a filter to this sink. 
+ + Parameters + ---------- + filter: :class:`~.Filter` + The filter to add. + + Raises + ------ + TypeError + You did not provide a Filter object. + """ + + if not isinstance(filter, Filter): + raise TypeError(f'expected a Filter object, not {filter.__class__.__name__}') + self._filters.append(filter) + + def remove_filter(self, filter: Filter, /) -> None: + """Removes a filter from this sink. + + Parameters + ---------- + filter: :class:`~.Filter` + The filter to remove. + """ + + try: + self._filters.remove(filter) + except ValueError: + pass + + def add_handler(self, handler: Handler, /) -> None: + """Adds a handler to this sink. + + Parameters + ---------- + handler: :class:`~.Handler` + The handler to add. + + Raises + ------ + TypeError + You did not provide a Handler object. + """ + + if not isinstance(handler, Handler): + raise TypeError(f'expected a Handler object, not {handler.__class__.__name__}') + self._handlers.append(handler) + + def remove_handler(self, handler: Handler, /) -> None: + """Removes a handler from this sink. + + Parameters + ---------- + handler: :class:`~.Handler` + The handler to remove. + """ + + try: + self._handlers.remove(handler) + except ValueError: + pass + + @staticmethod + def listener(event: str = MISSING) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: + """Registers a function to be an event listener for this sink. + + The events must be a :ref:`coroutine `, if not, :exc:`TypeError` is raised; and + also must be inside a sink class. + + Example + ------- + + .. code-block:: python3 + + class MySink(Sink): + @Sink.listener() + async def on_member_speaking_state_update(member, ssrc, state): + pass + + Parameters + ---------- + event: :class:`str` + The event name to listen to. If not provided, defaults to the function name. + + Raises + ------ + TypeError + The coroutine passed is not actually a coroutine, or the listener is not in a sink class. 
+ """ + + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: + parts = func.__qualname__.split('.') + + if not parts or not len(parts) > 1: + raise TypeError('event listeners must be declared in a Sink class') + + if parts[-1] != func.__name__: + raise NameError('qualified name and function name mismatch, this should not happen') + + if not asyncio.iscoroutinefunction(func): + raise TypeError('event listeners must be coroutine functions') + + func.__listener__ = True + if event is not MISSING: + func.__listener_name__ = event + return func + return decorator diff --git a/discord/sinks/enums.py b/discord/sinks/enums.py new file mode 100644 index 0000000000..6f552bf787 --- /dev/null +++ b/discord/sinks/enums.py @@ -0,0 +1,32 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
+""" +from __future__ import annotations + +from discord.enums import Enum + + +class FilteringMode(Enum): + all = 0 + any = 1 diff --git a/discord/voice/_types.py b/discord/voice/_types.py index 3222b5afe6..b2e252bd39 100644 --- a/discord/voice/_types.py +++ b/discord/voice/_types.py @@ -25,8 +25,7 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -37,14 +36,11 @@ RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent, ) - from discord.sinks import Sink P = ParamSpec("P") R = TypeVar("R") - RecordCallback = Union[Callable[P, R], Callable[P, Awaitable[R]]] ClientT = TypeVar("ClientT", bound="Client", covariant=True) -VoiceProtocolT = TypeVar("VoiceProtocolT", bound="VoiceProtocol", covariant=True) class VoiceProtocol(Generic[ClientT]): @@ -166,103 +162,3 @@ def cleanup(self) -> None: """ key, _ = self.channel._get_voice_client_key() self.client._connection._remove_voice_client(key) - - -class VoiceRecorderProtocol(Generic[VoiceProtocolT]): - """A class that represents a Discord voice client recorder protocol. - - .. warning:: - - If you are an end user, you **should not construct this manually** but instead - take it from a :class:`VoiceProtocol` implementation, like :attr:`VoiceClient.recorder`. - The parameters and methods being documented here is so third party libraries can refer to it - when implementing their own RecorderProtocol types. - - This is an abstract class. The library provides a concrete implementation under - :class:`VoiceRecorderClient`. - - This class allows you to implement a protocol to allow for an external - method of receiving and handling voice data. - - .. versionadded:: 2.7 - - Parameters - ---------- - client: :class:`VoiceProtocol` - The voice client (or its subclasses) that are bound to this recorder. 
- channel: :class:`abc.Connectable` - The voice channel that is being recorder. If not provided, defaults to - :attr:`VoiceProtocol.channel` - """ - - def __init__( - self, client: VoiceProtocolT, channel: abc.Connectable | None = None - ) -> None: - self.client: VoiceProtocolT = client - self.channel: abc.Connectable = channel or client.channel - - def get_ssrc(self, user_id: int) -> int: - """Gets the ssrc of a user. - - Parameters - ---------- - user_id: :class:`int` - The user ID to get the ssrc from. - - Returns - ------- - :class:`int` - The ssrc for the provided user ID. - """ - raise NotImplementedError("subclasses must implement this") - - def unpack(self, data: bytes) -> bytes | None: - """Takes an audio packet received from Discord and decodes it. - - Parameters - ---------- - data: :class:`bytes` - The bytes received by Discord. - - Returns - ------- - Optional[:class:`bytes`] - The unpacked bytes, or ``None`` if they could not be unpacked. - """ - raise NotImplementedError("subclasses must implement this") - - def record( - self, - sink: Sink, - callback: RecordCallback[P, R], - sync_start: bool, - *callback_args: P.args, - **callback_kwargs: P.kwargs, - ) -> None: - r"""Start recording audio from the current voice channel in the provided sink. - - You must be in a voice channel. - - Parameters - ---------- - sink: :class:`~discord.Sink` - The sink to record to. - callback: Callable[..., Any] - The function called after the bot has stopped recording. This can take any arguments and - can return an awaitable. - sync_start: :class:`bool` - Whether the subsequent recording users will start with silence. This is useful for recording - audio just as it was heard. - - Raises - ------ - RecordingException - Not connected to a voice channel - TypeError - You did not pass a Sink object. 
- """ - raise NotImplementedError("subclasses must implement this") - - def stop(self) -> None: - """Stops recording.""" - raise NotImplementedError("subclasses must implement this") diff --git a/discord/voice/client.py b/discord/voice/client.py index c52abaae13..5a43c6db9d 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -53,6 +53,7 @@ RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent, ) + from discord.sinks import Sink from discord.state import ConnectionState from discord.types.voice import SupportedModes from discord.user import ClientUser @@ -106,8 +107,6 @@ def __init__( self, client: Client, channel: abc.Connectable, - *, - use_recorder: bool = True, ) -> None: if not has_nacl: raise RuntimeError( @@ -134,11 +133,6 @@ def __init__( self._connection: VoiceConnectionState = self.create_connection_state() - # voice recv things - # self._recorder: VoiceRecorderClient | None = None - # if use_recorder: - # self._recorder = VoiceRecorderClient(self) - warn_nacl: bool = not has_nacl supported_modes: tuple[SupportedModes, ...] = ( "aead_xchacha20_poly1305_rtpsize", diff --git a/discord/voice/flags.py b/discord/voice/flags.py new file mode 100644 index 0000000000..45c75fd21e --- /dev/null +++ b/discord/voice/flags.py @@ -0,0 +1,88 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from discord.flags import BaseFlags, fill_with_flags, flag_value + + +@fill_with_flags() +class SpeakingFlags(BaseFlags): + r"""Wraps up a Discord user speaking state flag value. + + .. container:: operations + + .. describe:: x == y + + Checks if two flags are equal. + .. describe:: x != y + + Checks if two flags are not equal. + .. describe:: x + y + + Adds two flags together. Equivalent to ``x | y``. + .. describe:: x - y + + Substract two flags from each other. + .. describe:: x | y + + Returns the union of two flags. Equivalent to ``x + y``. + .. describe:: x & y + + Returns the intersection of two flags. + .. describe:: ~x + + Returns the inverse of a flag. + .. describe:: hash(x) + + Returns the flag's hash. + .. describe:: iter(x) + + Returns an iterator of ``(name, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + + .. versionadded:: 2.7 + + Attributes + ---------- + value: :class:`int` + The raw value. This value is a bit array field of a 53-bit integer + representing the currently available flags. You should query + flags via the properties rather than using this raw value. 
+ """ + + @flag_value + def voice(self): + """:class:`bool`: Normal transmission of voice audio""" + return 1 << 0 + + @flag_value + def soundshare(self): + """:class:`bool`: Transmission of context audio for video, no speaking indicator""" + return 1 << 1 + + @flag_value + def priority(self): + """:class:`bool`: Priority speaker, lowering audio of other speakers""" + return 1 << 2 diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index a6958722a7..aab003a527 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -53,6 +53,9 @@ class KeepAliveHandler(KeepAliveHandlerBase): + if TYPE_CHECKING: + ws: VoiceWebSocket + def __init__( self, *args: Any, @@ -116,7 +119,7 @@ def __init__( self.ssrc_map: dict[str, dict[str, Any]] = {} if hook: - self._hook = hook # type: ignore + self._hook = hook or state.ws_hook # type: ignore @property def token(self) -> str | None: @@ -186,7 +189,7 @@ async def received_message(self, msg: Any, /): ) self._keep_alive.start() - await utils.maybe_coroutine(self._hook, self, data) + await utils.maybe_coroutine(self._hook, self, msg) async def ready(self, data: dict[str, Any]) -> None: state = self.state diff --git a/discord/voice/recorder.py b/discord/voice/recorder.py deleted file mode 100644 index 5d59e0a953..0000000000 --- a/discord/voice/recorder.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all 
copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from __future__ import annotations - -import asyncio -import threading -from typing import TYPE_CHECKING, Any, TypeVar - -from discord.opus import DecodeManager - -from ._types import VoiceRecorderProtocol - -if TYPE_CHECKING: - from discord.sinks import Sink - - from .client import VoiceClient - from .gateway import VoiceWebSocket - - VoiceClientT = TypeVar("VoiceClientT", bound=VoiceClient, covariant=True) - - -class VoiceRecorderClient(VoiceRecorderProtocol[VoiceClientT]): - """Represents a voice recorder for a voice client. - - You should not construct this but instead obtain it from :attr:`VoiceClient.recorder`. - - .. versionadded:: 2.7 - """ - - def __init__(self, client: VoiceClientT) -> None: - super().__init__(client) - - self._paused: asyncio.Event = asyncio.Event() - self._recording: asyncio.Event = asyncio.Event() - self.decoder: DecodeManager = DecodeManager(self) - self.sync_start: bool = False - self.sinks: dict[int, tuple[Sink, threading.Thread]] = {} - - def is_paused(self) -> bool: - """Whether the current recorder is paused.""" - return self._paused.is_set() - - def is_recording(self) -> bool: - """Whether the current recording is actively recording.""" - return self._recording.is_set() - - async def hook(self, ws: VoiceWebSocket, data: dict[str, Any]) -> None: ... - - def record( - self, - sink: Sink, - ) -> int: ... 
diff --git a/discord/voice/state.py b/discord/voice/state.py index 6884e47eba..acb16dbcac 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -31,14 +31,16 @@ import socket import threading from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict from discord import utils from discord.backoff import ExponentialBackoff from discord.errors import ConnectionClosed +from discord.object import Object -from .enums import ConnectionFlowState +from .enums import ConnectionFlowState, OpCodes from .gateway import VoiceWebSocket +from .flags import SpeakingFlags if TYPE_CHECKING: from discord import abc @@ -55,20 +57,29 @@ _log = logging.getLogger(__name__) -class SocketEventReader(threading.Thread): +class SocketReader(threading.Thread): def __init__( - self, state: VoiceConnectionState, *, start_paused: bool = True + self, + state: VoiceConnectionState, + name: str, + *, + start_paused: bool = True, ) -> None: super().__init__( daemon=True, - name=f"voice-socket-reader:{id(self):#x}", + name=name, ) + self.state: VoiceConnectionState = state self.start_paused: bool = start_paused self._callbacks: list[SocketReaderCallback] = [] self._running: threading.Event = threading.Event() self._end: threading.Event = threading.Event() self._idle_paused: bool = True + self._started: threading.Event = threading.Event() + + def is_running(self) -> bool: + return self._started.is_set() def register(self, callback: SocketReaderCallback) -> None: self._callbacks.append(callback) @@ -102,10 +113,12 @@ def resume(self, *, force: bool = False) -> None: self._running.set() def stop(self) -> None: + self._started.clear() self._end.set() self._running.set() def run(self) -> None: + self._started.set() self._end.clear() self._running.set() @@ -115,12 +128,78 @@ def run(self) -> None: try: self._do_run() except Exception: - _log.exception("Error while starting socket event reader at %s", self) + 
_log.exception( + 'An error ocurred while running the socket reader %s', + self.name, + ) finally: self.stop() self._running.clear() self._callbacks.clear() + def _do_run(self) -> None: + raise NotImplementedError + + +class SocketVoiceRecvReader(SocketReader): + def __init__( + self, state: VoiceConnectionState, *, start_paused: bool = True, + ) -> None: + super().__init__( + state, + f'voice-recv-socket-reader:{id(self):#x}', + start_paused=start_paused, + ) + + def _do_run(self) -> None: + while not self._end.is_set(): + if not self._running.is_set(): + self._running.wait() + continue + + try: + readable, _, _ = select.select([self.state.socket], [], [], 30) + except (ValueError, TypeError, OSError) as e: + _log.debug( + "Select error handling socket in reader, this should be safe to ignore: %s: %s", + e.__class__.__name__, + e, + ) + continue + + if not readable: + continue + + try: + data = self.state.socket.recv(4096) + except OSError: + _log.debug( + 'Error reading from socket in %s, this should be safe to ignore', + self, + exc_info=True, + ) + else: + for cb in self._callbacks: + try: + cb(data) + except Exception: + _log.exception( + 'Error while calling %s in %s', + cb, + self, + ) + + +class SocketEventReader(SocketReader): + def __init__( + self, state: VoiceConnectionState, *, start_paused: bool = True + ) -> None: + super().__init__( + state, + f'voice-socket-event-reader:{id(self):#x}', + start_paused=start_paused, + ) + def _do_run(self) -> None: while not self._end.is_set(): if not self._running.is_set(): @@ -160,6 +239,11 @@ def _do_run(self) -> None: ) +class SSRC(TypedDict): + user_id: int + speaking: SpeakingFlags + + class VoiceConnectionState: def __init__( self, @@ -199,6 +283,42 @@ def __init__( self._connector: asyncio.Task[None] | None = None self._socket_reader = SocketEventReader(self) self._socket_reader.start() + self._voice_recv_socket = SocketVoiceRecvReader(self) + self.user_ssrc_map: dict[int, SSRC] = {} + + def 
start_record_socket(self) -> None: + if self._voice_recv_socket.is_running(): + return + self._voice_recv_socket.start() + + def stop_record_socket(self) -> None: + if self._voice_recv_socket.is_running(): + self._voice_recv_socket.stop() + + def get_user_by_ssrc(self, ssrc: int) -> abc.Snowflake | None: + data = self.user_ssrc_map.get(ssrc) + if data is None: + return None + + user = int(data['user_id']) + return self.guild.get_member(user) or self.client._state.get_user(user) or Object(id=user) + + def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: + op = msg['op'] + data = msg.get('d', {}) + + if op == OpCodes.speaking: + ssrc = data['ssrc'] + user = int(data['user_id']) + speaking = data['speaking'] + + if ssrc in self.user_ssrc_map: + self.user_ssrc_map[ssrc]['speaking'].value = speaking + else: + self.user_ssrc_map[ssrc] = { + 'user_id': user, + 'speaking': SpeakingFlags._from_value(speaking), + } @property def state(self) -> ConnectionFlowState: From 0b7c1476e36e1a63b76e9fe99bd1c138766236e8 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Sat, 30 Aug 2025 02:06:14 +0200 Subject: [PATCH 22/40] refactor voice recv logic --- discord/client.py | 2 +- discord/opus.py | 2 - discord/sinks/__init__.py | 1 + discord/sinks/core.py | 498 ++++++++++++++++++++++---------------- discord/sinks/enums.py | 6 +- discord/sinks/errors.py | 22 ++ discord/sinks/m4a.py | 255 ++++++++++++++----- discord/sinks/mka.py | 239 ++++++++++++++---- discord/sinks/mkv.py | 238 ++++++++++++++---- discord/sinks/mp3.py | 239 ++++++++++++++---- discord/sinks/mp4.py | 256 +++++++++++++++----- discord/sinks/ogg.py | 239 ++++++++++++++---- discord/sinks/pcm.py | 143 ++++++++++- discord/sinks/wave.py | 172 ++++++++++--- discord/voice/client.py | 95 +++++++- discord/voice/flags.py | 88 ------- discord/voice/state.py | 342 ++++++++++++++++++++++---- 17 files changed, 2142 insertions(+), 695 deletions(-) delete mode 100644 discord/voice/flags.py 
diff --git a/discord/client.py b/discord/client.py index 1859ff2775..5a822b8007 100644 --- a/discord/client.py +++ b/discord/client.py @@ -466,7 +466,7 @@ def _schedule_event( return task def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - logging.getLogger("discord.state").debug("Dispatching event %s", event) + _log.debug("Dispatching event %s", event) method = f"on_{event}" listeners = self._listeners.get(event) diff --git a/discord/opus.py b/discord/opus.py index 5f5efee8dd..14bfdace2c 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -42,8 +42,6 @@ from .sinks import RawData if TYPE_CHECKING: - from discord.voice.recorder import VoiceRecorderClient - T = TypeVar("T") APPLICATION_CTL = Literal["audio", "voip", "lowdelay"] BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] diff --git a/discord/sinks/__init__.py b/discord/sinks/__init__.py index 6db5209af0..6d38ae74f0 100644 --- a/discord/sinks/__init__.py +++ b/discord/sinks/__init__.py @@ -10,6 +10,7 @@ from .core import * from .errors import * +from .enums import * from .m4a import * from .mka import * from .mkv import * diff --git a/discord/sinks/core.py b/discord/sinks/core.py index d82bc374ce..eb044ea600 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -26,36 +26,34 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine +from collections.abc import Callable, Coroutine, Iterable from functools import partial -import io -import os +import logging import struct import sys -import threading import time -from typing import TYPE_CHECKING, Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from discord import utils +from discord.enums import SpeakingState from discord.utils import MISSING -from .enums import FilteringMode -from .errors import SinkException +from .enums import SinkFilteringMode if TYPE_CHECKING: from typing_extensions import ParamSpec from discord import abc - 
from discord.types import snowflake from ..voice.client import VoiceClient R = TypeVar('R') P = ParamSpec('P') __all__ = ( - "Filters", "Sink", - "AudioData", "RawData", + 'SinkFilter', + 'SinkHandler', ) @@ -65,14 +63,11 @@ CREATE_NO_WINDOW = 0x08000000 -default_filters = { - "time": 0, - "users": [], - "max_size": 0, -} +S = TypeVar('S', bound='Sink') +_log = logging.getLogger(__name__) -class Filter: +class SinkFilter(Generic[S]): """Represents a filter for a :class:`~.Sink`. This has to be inherited in order to provide a filter to a sink. @@ -81,18 +76,17 @@ class Filter: """ @overload - async def filter(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> bool: ... + async def filter_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> bool: ... @overload - def filter(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> bool: ... + def filter_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> bool: ... - def filter(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> bool | Coroutine[Any, Any, bool]: + def filter_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> bool | Coroutine[Any, Any, bool]: """|maybecoro| - Represents the filter callback. + This is called automatically everytime a voice packet is received. - This is called automatically everytime a voice packet is received to check whether it should be stored in - ``sink``. + Depending on what bool-like this returns, it will dispatch some events in the parent ``sink``. Parameters ---------- @@ -100,8 +94,6 @@ def filter(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> The sink the packet was received from, if the filter check goes through. user: :class:`~discord.abc.Snowflake` The user that the packet was received from. - ssrc: :class:`int` - The user's ssrc. packet: :class:`~.RawData` The raw data packet. 
@@ -112,12 +104,43 @@ def filter(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> """ raise NotImplementedError('subclasses must implement this') + @overload + async def filter_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> bool: ... + + @overload + def filter_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> bool: ... + + def filter_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> bool | Coroutine[Any, Any, bool]: + """|maybecoro| + + This is called automatically everytime a speaking state is updated. + + Depending on what bool-like this returns, it will dispatch some events in the parent ``sink``. + + Parameters + ---------- + sink: :class:`~.Sink` + The sink the packet was received from, if the filter check goes through. + user: :class:`~discord.abc.Snowflake` + The user that the packet was received from. + before: :class:`~discord.SpeakingState` + The speaking state before the update. + after: :class:`~discord.SpeakingState` + The speaking state after the update. + + Returns + ------- + :class:`bool` + Whether the filter was successful. + """ + raise NotImplementedError('subclasses must implement this') + def cleanup(self) -> None: """A function called when the filter is ready for cleanup.""" pass -class Handler: +class SinkHandler(Generic[S]): """Represents a handler for a :class:`~.Sink`. This has to be inherited in order to provide a handler to a sink. @@ -126,16 +149,14 @@ class Handler: """ @overload - async def handle(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> Any: ... + async def handle_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> Any: ... @overload - def handle(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> Any: ... + def handle_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> Any: ... 
- def handle(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> Any | Coroutine[Any, Any, Any]: + def handle_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> Any | Coroutine[Any, Any, Any]: """|maybecoro| - Represents the handler callback. - This is called automatically everytime a voice packet which has successfully passed the filters is received. Parameters @@ -144,53 +165,38 @@ def handle(self, sink: Sink, user: abc.Snowflake, ssrc: int, packet: RawData) -> The sink the packet was received from, if the filter check goes through. user: :class:`~discord.abc.Snowflake` The user that the packet is from. - ssrc: :class:`int` - The user's ssrc. packet: :class:`~.RawData` The raw data packet. """ - raise NotImplementedError('subclasses must implement this') - - def cleanup(self) -> None: - """A function called when the handler is ready for cleanup.""" pass + @overload + async def handle_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> Any: ... -class Filters: - """Filters for :class:`~.Sink` - - .. versionadded:: 2.0 - - Parameters - ---------- - container - Container of all Filters. - """ - - def __init__(self, **kwargs): - self.filtered_users = kwargs.get("users", default_filters["users"]) - self.seconds = kwargs.get("time", default_filters["time"]) - self.max_size = kwargs.get("max_size", default_filters["max_size"]) - self.finished = False + @overload + def handle_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> Any: ... 
- @staticmethod - def container(func): # Contains all filters - def _filter(self, data, user): - if not self.filtered_users or user in self.filtered_users: - return func(self, data, user) + def handle_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> Any | Coroutine[Any, Any, Any]: + """|maybecoro| - return _filter + This is called automatically everytime a speaking state update is received which has successfully passed the filters. - def init(self): - if self.seconds != 0: - thread = threading.Thread(target=self.wait_and_stop) - thread.start() + Parameters + ---------- + sink: :class:`~.Sink` + The sink the packet was received from, if the filter check goes through. + user: :class:`~discord.abc.Snowflake` + The user that the packet was received from. + before: :class:`~discord.SpeakingState` + The speaking state before the update. + after: :class:`~discord.SpeakingState` + The speaking state after the update. + """ + pass - def wait_and_stop(self): - time.sleep(self.seconds) - if self.finished: - return - self.vc.stop_recording() + def cleanup(self) -> None: + """A function called when the handler is ready for cleanup.""" + pass class RawData: @@ -227,147 +233,98 @@ def __init__(self, data: bytes, client: VoiceClient): self.decrypted_data: bytes = getattr( self.client, f"_decrypt_{self.client.mode}" )(self.header, self.data) - self.decoded_data: bytes | None = None + self.decoded_data: bytes = MISSING self.user_id: int | None = None self.receive_time: float = time.perf_counter() -class AudioData: - """Handles data that's been completely decrypted and decoded and is ready to be saved to file. - - .. versionadded:: 2.0 - """ - - def __init__(self, file): - self.file = file - self.finished = False - - def write(self, data): - """Writes audio data. - - Raises - ------ - ClientException - The AudioData is already finished writing. 
- """ - if self.finished: - raise SinkException("The AudioData is already finished writing.") - try: - self.file.write(data) - except ValueError: - pass - - def cleanup(self): - """Finishes and cleans up the audio data. - - Raises - ------ - ClientException - The AudioData is already finished writing. - """ - if self.finished: - raise SinkException("The AudioData is already finished writing.") - self.file.seek(0) - self.finished = True - - def on_format(self, encoding): - """Called when audio data is formatted. - - Raises - ------ - ClientException - The AudioData is still writing. - """ - if not self.finished: - raise SinkException("The AudioData is still writing.") - +class Sink: + r"""Represents a sink for voice recording. -class Sink(Filters): - """A sink "stores" recorded audio data. + This is used as a way of "storing" the recordings. - Can be subclassed for extra customizablilty. + This class is abstracted, and must be subclassed in order to apply functionalities to + it. - .. warning:: - It is recommended you use - the officially provided sink classes, - such as :class:`~discord.sinks.WaveSink`. + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. 
- just replace the following like so: :: + Events + ------ - vc.start_recording( - MySubClassedSink(), - finished_callback, - ctx.channel, - ) + These section outlines all the available sink events. - .. versionadded:: 2.0 + .. function:: on_voice_packet_receive(user, data) + Called when a voice packet is received from a member. - Raises - ------ - ClientException - An invalid encoding type was specified. - ClientException - Audio may only be formatted after recording is finished. - """ + This is called **after** the filters went through. - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) - self.vc: VoiceClient | None = None - self.audio_data = {} + :param user: The user the packet is from. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param data: The RawData of the packet. + :type data: :class:`~.RawData` - def init(self, vc): # called under listen - self.vc = vc - super().init() + .. function:: on_unfiltered_voice_packet_receive(user, data) + Called when a voice packet is received from a member. - @Filters.container - def write(self, data, user): - if user not in self.audio_data: - file = io.BytesIO() - self.audio_data.update({user: AudioData(file)}) + Unlike ``on_voice_packet_receive``, this is called **before any filters** are called. - file = self.audio_data[user] - file.write(data) + :param user: The user the packet is from. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param data: The RawData of the packet. + :type data: :class:`~.RawData` - def cleanup(self): - self.finished = True - for file in self.audio_data.values(): - file.cleanup() - self.format_audio(file) + .. function:: on_speaking_state_update(user, before, after) + Called when a member's voice state changes. 
- def get_all_audio(self): - """Gets all audio files.""" - return [x.file for x in self.audio_data.values()] + This is called **after** the filters went through. - def get_user_audio(self, user: snowflake.Snowflake): - """Gets the audio file(s) of one specific user.""" - return os.path.realpath(self.audio_data.pop(user)) + :param user: The user which speaking state has changed. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param before: The user's state before it was updated. + :type before: :class:`~discord.SpeakingFlags` + :param after: The user's state after it was updated. + :type after: :class:`~discord.SpeakingFlags` + .. function:: on_unfiltered_speaking_state_update(user, before, after) + Called when a voice packet is received from a member. -class Sink: - """Represents a sink for voice recording. + Unlike ``on_speaking_state_update``, this is called **before any filters** are called. - This is used as a way of "storing" the recordings. + :param user: The user which speaking state has changed. This can sometimes be a :class:`~discord.Object` object. + :type user: :class:`~discord.abc.Snowflake` + :param before: The user's state before it was updated. + :type before: :class:`~discord.SpeakingFlags` + :param after: The user's state after it was updated. + :type after: :class:`~discord.SpeakingFlags` - This class is abstracted, and must be subclassed in order to apply functionalities to - it. + .. function:: on_error(event, exception, \*args, \*\*kwargs) + Called when an error ocurrs in any of the events above. The default implementation logs the exception + to stdout. - Parameters - ---------- - filters: List[:class:`~.Filter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.FilteringMode` - How the filters should work. 
If set to :attr:`~.FilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.FilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.Handler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. + :param event: The event in which the error ocurred. + :type event: :class:`str` + :param exception: The exception that ocurred. + :type exception: :class:`Exception` + :param \*args: The arguments that were passed to the event. + :param \*\*kwargs: The key-word arguments that were passed to the event. """ + if TYPE_CHECKING: + __filtering_mode: SinkFilteringMode + _filter_strat: Callable[..., bool] + client: VoiceClient + __listeners__: dict[str, list[Callable[..., Any]]] = {} def __init_subclass__(cls) -> None: @@ -398,28 +355,115 @@ def __init_subclass__(cls) -> None: def __init__( self, *, - filters: list[Filter] = MISSING, - filtering_mode: FilteringMode = FilteringMode.all, - handlers: list[Handler] = MISSING, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, ) -> None: - self.filtering_mode: FilteringMode = filtering_mode - self._filters: list[Filter] = filters or [] - self._handlers: list[Handler] = handlers or [] + self._paused: bool = False + self.filtering_mode = filtering_mode + self._filters: list[SinkFilter] = filters or [] + self._handlers: list[SinkHandler] = handlers or [] self.__dispatch_set: set[asyncio.Task[Any]] = set() + self._listeners: dict[str, list[Callable[[Iterable[object]], bool]]] = self.__listeners__ + + @property + def filtering_mode(self) -> SinkFilteringMode: + return self.__filtering_mode + + @filtering_mode.setter + def filtering_mode(self, value: 
SinkFilteringMode) -> None: + if value is SinkFilteringMode.all: + self._filter_strat = all + elif value is SinkFilteringMode.any: + self._filter_strat = any + else: + raise TypeError(f'expected a FilteringMode enum member, got {value.__class__.__name__}') + + self.__filtering_mode = value def dispatch(self, event: str, *args: Any, **kwargs: Any) -> Any: - event = event.removeprefix('on_') + _log.debug('Dispatching sink %s event %s', self.__class__.__name__, event) + method = f'on_{event}' listeners = self.__listeners__.get(event, []) + for coro in listeners: + self._schedule_event(coro, method, *args, **kwargs) + + try: + coro = getattr(self, method) + except AttributeError: + pass + else: + self._schedule_event(coro, method, *args, **kwargs) + + async def _run_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any + ) -> None: + try: + await coro(*args, **kwargs) + except asyncio.CancelledError: + pass + except Exception as exc: + try: + await self.on_error(event_name, exc, *args, **kwargs) + except asyncio.CancelledError: + pass - for listener in listeners: + def _call_voice_packet_handlers(self, user: abc.Snowflake, packet: RawData) -> None: + for handler in self._handlers: task = asyncio.create_task( - listener(*args, **kwargs), - name=f'dispatch-{event}:{id(listener):#x}', + utils.maybe_coroutine( + handler.handle_packet, + self, + user, + packet, + ) ) self.__dispatch_set.add(task) task.add_done_callback(self.__dispatch_set.remove) + def _call_speaking_state_handlers(self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> None: + for handler in self._handlers: + task = asyncio.create_task( + utils.maybe_coroutine( + handler.handle_speaking_state, + self, + user, + before, + after, + ), + ) + self.__dispatch_set.add(task) + task.add_done_callback(self.__dispatch_set.remove) + + def _schedule_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: 
Any, + **kwargs: Any, + ) -> asyncio.Task: + wrapped = self._run_event(coro, event_name, *args, **kwargs) + + task = asyncio.create_task(wrapped, name=f'sinks: {event_name}') + self.__dispatch_set.add(task) + task.add_done_callback(self.__dispatch_set.discard) + return task + + def __repr__(self) -> str: + return f'<{self.__class__.__name__} id={id(self):#x}>' + + def stop(self) -> None: + """Stops this sink's recording. + + This is the place where :meth:`.cleanup` should be called. + """ + self.cleanup() + def cleanup(self) -> None: """Cleans all the data in this sink. @@ -440,12 +484,12 @@ def cleanup(self) -> None: def __del__(self) -> None: self.cleanup() - def add_filter(self, filter: Filter, /) -> None: + def add_filter(self, filter: SinkFilter, /) -> None: """Adds a filter to this sink. Parameters ---------- - filter: :class:`~.Filter` + filter: :class:`~.SinkFilter` The filter to add. Raises @@ -454,16 +498,16 @@ def add_filter(self, filter: Filter, /) -> None: You did not provide a Filter object. """ - if not isinstance(filter, Filter): + if not isinstance(filter, SinkFilter): raise TypeError(f'expected a Filter object, not {filter.__class__.__name__}') self._filters.append(filter) - def remove_filter(self, filter: Filter, /) -> None: + def remove_filter(self, filter: SinkFilter, /) -> None: """Removes a filter from this sink. Parameters ---------- - filter: :class:`~.Filter` + filter: :class:`~.SinkFilter` The filter to remove. """ @@ -472,12 +516,12 @@ def remove_filter(self, filter: Filter, /) -> None: except ValueError: pass - def add_handler(self, handler: Handler, /) -> None: + def add_handler(self, handler: SinkHandler, /) -> None: """Adds a handler to this sink. Parameters ---------- - handler: :class:`~.Handler` + handler: :class:`~.SinkHandler` The handler to add. Raises @@ -486,16 +530,16 @@ def add_handler(self, handler: Handler, /) -> None: You did not provide a Handler object. 
""" - if not isinstance(handler, Handler): + if not isinstance(handler, SinkHandler): raise TypeError(f'expected a Handler object, not {handler.__class__.__name__}') self._handlers.append(handler) - def remove_handler(self, handler: Handler, /) -> None: + def remove_handler(self, handler: SinkHandler, /) -> None: """Removes a handler from this sink. Parameters ---------- - handler: :class:`~.Handler` + handler: :class:`~.SinkHandler` The handler to remove. """ @@ -549,3 +593,53 @@ def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutin func.__listener_name__ = event return func return decorator + + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + pass + + async def on_unfiltered_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + pass + + async def on_speaking_state_update(self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> None: + pass + + async def on_unfiltered_speaking_state_update(self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> None: + pass + + async def on_error(self, event: str, exception: Exception, *args: Any, **kwargs: Any) -> None: + _log.exception( + 'An error ocurred in sink %s while dispatching the event %s', + self, + event, + exc_info=exception, + ) + + def is_recording(self) -> bool: + """Whether this sink is currently available to record, and doing so.""" + state = self.client._connection + return state.is_recording() and self in state.sinks + + def is_paused(self) -> bool: + """Whether this sink is currently paused from recording.""" + return self._paused + + def pause(self) -> None: + """Pauses the recording of this sink. + + No filter or handlers will be called when a sink is paused, and no + event will be dispatched. + + Pending events _could still be called_ even when a sink is paused, + so make sure you pause a sink when there are not current packets being + handled. 
+ + You can resume the recording of this sink with :meth:`.resume`. + """ + self._paused = True + + def resume(self) -> None: + """Resumes the recording of this sink. + + You can pause the recording of this sink with :meth:`.pause`. + """ + self._paused = False diff --git a/discord/sinks/enums.py b/discord/sinks/enums.py index 6f552bf787..9dfb95e4cd 100644 --- a/discord/sinks/enums.py +++ b/discord/sinks/enums.py @@ -26,7 +26,11 @@ from discord.enums import Enum +__all__ = ( + 'SinkFilteringMode', +) -class FilteringMode(Enum): + +class SinkFilteringMode(Enum): all = 0 any = 1 diff --git a/discord/sinks/errors.py b/discord/sinks/errors.py index 5f036efff5..51e00db73f 100644 --- a/discord/sinks/errors.py +++ b/discord/sinks/errors.py @@ -87,3 +87,25 @@ class MKASinkError(SinkException): .. versionadded:: 2.0 """ + + +class MaxProcessesCountReached(SinkException): + """Exception thrown when you try to create an audio converter process and the maximum + process count threshold is exceeded. + + .. versionadded:: 2.7 + """ + + +class FFmpegNotFound(SinkException): + """Exception thrown when the provided FFmpeg executable path was not found. + + .. versionadded:: 2.7 + """ + + +class NoUserAdio(SinkException): + """Exception thrown when you try to format the audio of a user not saved in a sink. + + .. versionadded:: 2.7 + """ diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 1cff9da538..66376228f7 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -21,83 +21,228 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" +from __future__ import annotations +from collections import deque import io +import logging import os import subprocess import time +from typing import TYPE_CHECKING, Literal, overload -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import M4ASinkError +from discord import utils +from discord.file import File +from discord.utils import MISSING + +from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .enums import SinkFilteringMode +from .errors import FFmpegNotFound, M4ASinkError, MaxProcessesCountReached, NoUserAdio + +if TYPE_CHECKING: + from discord import abc + +_log = logging.getLogger(__name__) + +__all__ = ( + 'M4AConverterHandler', + 'M4ASink', +) + + +class M4AConverterHandler(SinkHandler['M4ASink']): + def handle_packet(self, sink: M4ASink, user: abc.Snowflake, packet: RawData) -> None: + data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) + data.write(packet.decoded_data) class M4ASink(Sink): """A special sink for .m4a files. - .. versionadded:: 2.0 - """ + This is essentially a :class:`~.Sink` with a :class:`~.M4AConverterHandler` handler + passed as a default. - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + .. versionadded:: 2.0 - self.encoding = "m4a" - self.vc = None - self.audio_data = {} + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. 
Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. + max_audio_processes_count: :class:`int` + The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then + when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. + """ - def format_audio(self, audio): - """Formats the recorded audio. + def __init__( + self, + *, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, + max_audio_processes_count: int = 10, + ) -> None: + self.__audio_data: dict[int, io.BytesIO] = {} + self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque(maxlen=max_audio_processes_count) + handlers = handlers or [] + handlers.append(M4AConverterHandler()) + + super().__init__( + filters=filters, + filtering_mode=filtering_mode, + handlers=handlers, + ) + + def get_user_audio(self, user_id: int) -> io.BytesIO | None: + """Gets a user's saved audio data, or ``None``.""" + return self.__audio_data.get(user_id) + + def _create_audio_packet_for(self, uid: int) -> io.BytesIO: + data = self.__audio_data[uid] = io.BytesIO() + return data + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[True], + ) -> File: ... + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[False] = ..., + ) -> io.BytesIO: ... + + def format_user_audio( + self, + user_id: int, + *, + executable: str = 'ffmpeg', + as_file: bool = False, + ) -> io.BytesIO | File: + """Formats a user's saved audio data. + + This should be called after the bot has stopped recording. + + If this is called during recording, there could be missing audio + packets. + + After this, the user's audio data will be resetted to 0 bytes and + seeked to 0. 
+ + Parameters + ---------- + user_id: :class:`int` + The user ID of which format the audio data into a file. + executable: :class:`str` + The FFmpeg executable path to use for this formatting. It defaults + to ``ffmpeg``. + as_file: :class:`bool` + Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. + + Returns + ------- + Union[:class:`io.BytesIO`, :class:`~discord.File`] + The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` + object with the buffer set as the audio bytes. Raises - ------ + ------- + NoUserAudio + You tried to format the audio of a user that was not stored in this sink. + FFmpegNotFound + The provided FFmpeg executable was not found. + MaxProcessesCountReached + You tried to go over the maximum processes count threshold. M4ASinkError - Audio may only be formatted after recording is finished. - M4ASinkError - Formatting the audio failed. + Any error raised while formatting, wrapped around M4ASinkError. """ - if self.vc.recording: - raise M4ASinkError( - "Audio may only be formatted after recording is finished." - ) - m4a_file = f"{time.time()}.tmp" + + if len(self.__process_queue) >= 10: + raise MaxProcessesCountReached + + try: + data = self.__audio_data.pop(user_id) + except KeyError: + _log.info('There is no audio data for %s, ignoring.', user_id) + raise NoUserAdio + + temp_path = f'{user_id}-{time.time()}-recording.m4a.tmp' args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "ipod", - m4a_file, + executable, + '-f', + 's16le', + '-ar', + '48000', + '-loglevel', + 'error', + '-ac', + '2', + '-i', + '-', + '-f', + 'ipod', + temp_path, ] - if os.path.exists(m4a_file): - os.remove( - m4a_file - ) # process will get stuck asking whether to overwrite, if file already exists. 
+ + if os.path.exists(temp_path): + found = utils.find(lambda d: d[0] == temp_path, self.__process_queue) + if found: + _, old_process = found + old_process.kill() + _log.info('Killing old process (%s) to write in %s', old_process, temp_path) + + os.remove(temp_path) # process would get stuck asking whether to overwrite, if file already exists. + try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE - ) - except FileNotFoundError: - raise M4ASinkError("ffmpeg was not found.") from None + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE) + self.__process_queue.append((temp_path, process)) + except FileNotFoundError as exc: + raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise M4ASinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc + raise M4ASinkError(f'Audio formatting for user {user_id} failed') from exc - process.communicate(audio.file.read()) + process.communicate(data.read()) - with open(m4a_file, "rb") as f: - audio.file = io.BytesIO(f.read()) - audio.file.seek(0) - os.remove(m4a_file) + with open(temp_path, 'rb') as file: + buffer = io.BytesIO(file.read()) + buffer.seek(0) - audio.on_format(self.encoding) + try: + self.__process_queue.remove((temp_path, process)) + except ValueError: + pass + + if as_file: + return File(buffer, filename=f'{user_id}-{time.time()}-recording.m4a') + return buffer + + def _clean_process(self, path: str, process: subprocess.Popen) -> None: + _log.debug('Cleaning process %s for sink %s (with temporary file at %s)', process, self, path) + process.kill() + if os.path.exists(path): + os.remove(path) + + def cleanup(self) -> None: + for path, process in self.__process_queue: + self._clean_process(path, process) + self.__process_queue.clear() + + for _, buffer in self.__audio_data.items(): + if not buffer.closed: + buffer.close() + + self.__audio_data.clear() + super().cleanup() diff --git 
a/discord/sinks/mka.py b/discord/sinks/mka.py index c2bbefb923..8dccea1d7c 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -22,75 +22,212 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from collections import deque import io +import logging import subprocess +import time +from typing import TYPE_CHECKING, Literal, overload + +from discord.file import File +from discord.utils import MISSING + +from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .enums import SinkFilteringMode +from .errors import FFmpegNotFound, MKASinkError, MaxProcessesCountReached, NoUserAdio + +if TYPE_CHECKING: + from discord import abc -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import MKASinkError +_log = logging.getLogger(__name__) + +__all__ = ( + 'MKAConverterHandler', + 'MKASink', +) + + +class MKAConverterHandler(SinkHandler['MKASink']): + def handle_packet(self, sink: MKASink, user: abc.Snowflake, packet: RawData) -> None: + data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) + data.write(packet.decoded_data) class MKASink(Sink): """A special sink for .mka files. + This is essentially a :class:`~.Sink` with a :class:`~.MKAConverterHandler` handler + passed as a default. + .. versionadded:: 2.0 + + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. 
Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. + max_audio_processes_count: :class:`int` + The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then + when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + def __init__( + self, + *, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, + max_audio_processes_count: int = 10, + ) -> None: + self.__audio_data: dict[int, io.BytesIO] = {} + self.__process_queue: deque[subprocess.Popen] = deque(maxlen=max_audio_processes_count) + handlers = handlers or [] + handlers.append(MKAConverterHandler()) + + super().__init__( + filters=filters, + filtering_mode=filtering_mode, + handlers=handlers, + ) + + def get_user_audio(self, user_id: int) -> io.BytesIO | None: + """Gets a user's saved audio data, or ``None``.""" + return self.__audio_data.get(user_id) + + def _create_audio_packet_for(self, uid: int) -> io.BytesIO: + data = self.__audio_data[uid] = io.BytesIO() + return data + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[True], + ) -> File: ... + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[False] = ..., + ) -> io.BytesIO: ... + + def format_user_audio( + self, + user_id: int, + *, + executable: str = 'ffmpeg', + as_file: bool = False, + ) -> io.BytesIO | File: + """Formats a user's saved audio data. + + This should be called after the bot has stopped recording. 
- self.encoding = "mka" - self.vc = None - self.audio_data = {} + If this is called during recording, there could be missing audio + packets. - def format_audio(self, audio): - """Formats the recorded audio. + After this, the user's audio data will be resetted to 0 bytes and + seeked to 0. + + Parameters + ---------- + user_id: :class:`int` + The user ID of which format the audio data into a file. + executable: :class:`str` + The FFmpeg executable path to use for this formatting. It defaults + to ``ffmpeg``. + as_file: :class:`bool` + Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. + + Returns + ------- + Union[:class:`io.BytesIO`, :class:`~discord.File`] + The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` + object with the buffer set as the audio bytes. Raises - ------ + ------- + NoUserAudio + You tried to format the audio of a user that was not stored in this sink. + FFmpegNotFound + The provided FFmpeg executable was not found. + MaxProcessesCountReached + You tried to go over the maximum processes count threshold. MKASinkError - Audio may only be formatted after recording is finished. - MKASinkError - Formatting the audio failed. + Any error raised while formatting, wrapped around MKASinkError. """ - if self.vc.recording: - raise MKASinkError( - "Audio may only be formatted after recording is finished." 
- ) + + if len(self.__process_queue) >= 10: + raise MaxProcessesCountReached + + try: + data = self.__audio_data.pop(user_id) + except KeyError: + _log.info('There is no audio data for %s, ignoring.', user_id) + raise NoUserAdio + args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "matroska", - "pipe:1", + executable, + '-f', + 's16le', + '-ar', + '48000', + '-loglevel', + 'error', + '-ac', + '2', + '-i', + '-', + '-f', + 'matroska', + 'pipe:1', ] + try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - ) - except FileNotFoundError: - raise MKASinkError("ffmpeg was not found.") from None + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + self.__process_queue.append(process) + except FileNotFoundError as exc: + raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise MKASinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - out = process.communicate(audio.file.read())[0] - out = io.BytesIO(out) - out.seek(0) - audio.file = out - audio.on_format(self.encoding) + raise MKASinkError(f'Audio formatting for user {user_id} failed') from exc + + out = process.communicate(data.read())[0] + buffer = io.BytesIO(out) + buffer.seek(0) + + try: + self.__process_queue.remove(process) + except ValueError: + pass + + if as_file: + return File(buffer, filename=f'{user_id}-{time.time()}-recording.mka') + return buffer + + def _clean_process(self, process: subprocess.Popen) -> None: + _log.debug('Cleaning process %s for sink %s', process, self) + process.kill() + + def cleanup(self) -> None: + for process in self.__process_queue: + self._clean_process(process) + self.__process_queue.clear() + + for _, buffer in self.__audio_data.items(): + if not buffer.closed: + buffer.close() + + self.__audio_data.clear() + 
super().cleanup() diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index 93f4cc7444..283ac76344 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -22,74 +22,212 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from collections import deque import io +import logging import subprocess +import time +from typing import TYPE_CHECKING, Literal, overload + +from discord.file import File +from discord.utils import MISSING + +from .core import SinkHandler, Sink, SinkFilter, RawData +from .enums import SinkFilteringMode +from .errors import FFmpegNotFound, MKVSinkError, MaxProcessesCountReached, NoUserAdio + +if TYPE_CHECKING: + from discord import abc -from .core import Filters, Sink, default_filters -from .errors import MKVSinkError +_log = logging.getLogger(__name__) + +__all__ = ( + 'MKVConverterHandler', + 'MKVSink', +) + + +class MKVConverterHandler(SinkHandler['MKVSink']): + def handle_packet(self, sink: MKVSink, user: abc.Snowflake, packet: RawData) -> None: + data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) + data.write(packet.decoded_data) class MKVSink(Sink): """A special sink for .mkv files. + This is essentially a :class:`~.Sink` with a :class:`~.MKVConverterHandler` handler + passed as a default. + .. versionadded:: 2.0 + + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. 
Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. + max_audio_processes_count: :class:`int` + The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then + when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + def __init__( + self, + *, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, + max_audio_processes_count: int = 10, + ) -> None: + self.__audio_data: dict[int, io.BytesIO] = {} + self.__process_queue: deque[subprocess.Popen] = deque(maxlen=max_audio_processes_count) + handlers = handlers or [] + handlers.append(MKVConverterHandler()) + + super().__init__( + filters=filters, + filtering_mode=filtering_mode, + handlers=handlers, + ) + + def get_user_audio(self, user_id: int) -> io.BytesIO | None: + """Gets a user's saved audio data, or ``None``.""" + return self.__audio_data.get(user_id) + + def _create_audio_packet_for(self, uid: int) -> io.BytesIO: + data = self.__audio_data[uid] = io.BytesIO() + return data + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[True], + ) -> File: ... + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[False] = ..., + ) -> io.BytesIO: ... + + def format_user_audio( + self, + user_id: int, + *, + executable: str = 'ffmpeg', + as_file: bool = False, + ) -> io.BytesIO | File: + """Formats a user's saved audio data. + + This should be called after the bot has stopped recording. 
- self.encoding = "mkv" - self.vc = None - self.audio_data = {} + If this is called during recording, there could be missing audio + packets. - def format_audio(self, audio): - """Formats the recorded audio. + After this, the user's audio data will be resetted to 0 bytes and + seeked to 0. + + Parameters + ---------- + user_id: :class:`int` + The user ID of which format the audio data into a file. + executable: :class:`str` + The FFmpeg executable path to use for this formatting. It defaults + to ``ffmpeg``. + as_file: :class:`bool` + Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. + + Returns + ------- + Union[:class:`io.BytesIO`, :class:`~discord.File`] + The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` + object with the buffer set as the audio bytes. Raises - ------ + ------- + NoUserAudio + You tried to format the audio of a user that was not stored in this sink. + FFmpegNotFound + The provided FFmpeg executable was not found. + MaxProcessesCountReached + You tried to go over the maximum processes count threshold. MKVSinkError - Audio may only be formatted after recording is finished. - MKVSinkError - Formatting the audio failed. + Any error raised while formatting, wrapped around MKVSinkError. """ - if self.vc.recording: - raise MKVSinkError( - "Audio may only be formatted after recording is finished." 
- ) + + if len(self.__process_queue) >= 10: + raise MaxProcessesCountReached + + try: + data = self.__audio_data.pop(user_id) + except KeyError: + _log.info('There is no audio data for %s, ignoring.', user_id) + raise NoUserAdio + args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "matroska", - "pipe:1", + executable, + '-f', + 's16le', + '-ar', + '48000', + '-loglevel', + 'error', + '-ac', + '2', + '-i', + '-', + '-f', + 'matroska', + 'pipe:1', ] + try: - process = subprocess.Popen( - args, # creationflags=CREATE_NO_WINDOW, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - ) - except FileNotFoundError: - raise MKVSinkError("ffmpeg was not found.") from None + process = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + self.__process_queue.append(process) + except FileNotFoundError as exc: + raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise MKVSinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - out = process.communicate(audio.file.read())[0] - out = io.BytesIO(out) - out.seek(0) - audio.file = out - audio.on_format(self.encoding) + raise MKVSinkError(f'Audio formatting for user {user_id} failed') from exc + + out = process.communicate(data.read())[0] + buffer = io.BytesIO(out) + buffer.seek(0) + + try: + self.__process_queue.remove(process) + except ValueError: + pass + + if as_file: + return File(buffer, filename=f'{user_id}-{time.time()}-recording.mkv') + return buffer + + def _clean_process(self, process: subprocess.Popen) -> None: + _log.debug('Cleaning process %s for sink %s', process, self) + process.kill() + + def cleanup(self) -> None: + for process in self.__process_queue: + self._clean_process(process) + self.__process_queue.clear() + + for _, buffer in self.__audio_data.items(): + if not buffer.closed: + buffer.close() + + self.__audio_data.clear() + super().cleanup() diff --git 
a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 74386a2738..21de10fc95 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -22,75 +22,212 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from collections import deque import io +import logging import subprocess +import time +from typing import TYPE_CHECKING, Literal, overload + +from discord.file import File +from discord.utils import MISSING + +from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .enums import SinkFilteringMode +from .errors import FFmpegNotFound, MP3SinkError, MaxProcessesCountReached, NoUserAdio + +if TYPE_CHECKING: + from discord import abc -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import MP3SinkError +_log = logging.getLogger(__name__) + +__all__ = ( + 'MP3ConverterHandler', + 'MP3Sink', +) + + +class MP3ConverterHandler(SinkHandler['MP3Sink']): + def handle_packet(self, sink: MP3Sink, user: abc.Snowflake, packet: RawData) -> None: + data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) + data.write(packet.decoded_data) class MP3Sink(Sink): """A special sink for .mp3 files. + This is essentially a :class:`~.Sink` with a :class:`~.MP3ConverterHandler` handler + passed as a default. + .. versionadded:: 2.0 + + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. 
Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. + max_audio_processes_count: :class:`int` + The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then + when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + def __init__( + self, + *, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, + max_audio_processes_count: int = 10, + ) -> None: + self.__audio_data: dict[int, io.BytesIO] = {} + self.__process_queue: deque[subprocess.Popen] = deque(maxlen=max_audio_processes_count) + handlers = handlers or [] + handlers.append(MP3ConverterHandler()) + + super().__init__( + filters=filters, + filtering_mode=filtering_mode, + handlers=handlers, + ) + + def get_user_audio(self, user_id: int) -> io.BytesIO | None: + """Gets a user's saved audio data, or ``None``.""" + return self.__audio_data.get(user_id) + + def _create_audio_packet_for(self, uid: int) -> io.BytesIO: + data = self.__audio_data[uid] = io.BytesIO() + return data + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[True], + ) -> File: ... + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[False] = ..., + ) -> io.BytesIO: ... + + def format_user_audio( + self, + user_id: int, + *, + executable: str = 'ffmpeg', + as_file: bool = False, + ) -> io.BytesIO | File: + """Formats a user's saved audio data. + + This should be called after the bot has stopped recording. 
- self.encoding = "mp3" - self.vc = None - self.audio_data = {} + If this is called during recording, there could be missing audio + packets. - def format_audio(self, audio): - """Formats the recorded audio. + After this, the user's audio data will be resetted to 0 bytes and + seeked to 0. + + Parameters + ---------- + user_id: :class:`int` + The user ID of which format the audio data into a file. + executable: :class:`str` + The FFmpeg executable path to use for this formatting. It defaults + to ``ffmpeg``. + as_file: :class:`bool` + Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. + + Returns + ------- + Union[:class:`io.BytesIO`, :class:`~discord.File`] + The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` + object with the buffer set as the audio bytes. Raises - ------ + ------- + NoUserAudio + You tried to format the audio of a user that was not stored in this sink. + FFmpegNotFound + The provided FFmpeg executable was not found. + MaxProcessesCountReached + You tried to go over the maximum processes count threshold. MP3SinkError - Audio may only be formatted after recording is finished. - MP3SinkError - Formatting the audio failed. + Any error raised while formatting, wrapped around MP3SinkError. """ - if self.vc.recording: - raise MP3SinkError( - "Audio may only be formatted after recording is finished." 
- ) + + if len(self.__process_queue) >= 10: + raise MaxProcessesCountReached + + try: + data = self.__audio_data.pop(user_id) + except KeyError: + _log.info('There is no audio data for %s, ignoring.', user_id) + raise NoUserAdio + args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "mp3", - "pipe:1", + executable, + '-f', + 's16le', + '-ar', + '48000', + '-loglevel', + 'error', + '-ac', + '2', + '-i', + '-', + '-f', + 'mp3', + 'pipe:1', ] + try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - ) - except FileNotFoundError: - raise MP3SinkError("ffmpeg was not found.") from None + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + self.__process_queue.append(process) + except FileNotFoundError as exc: + raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise MP3SinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - out = process.communicate(audio.file.read())[0] - out = io.BytesIO(out) - out.seek(0) - audio.file = out - audio.on_format(self.encoding) + raise MP3SinkError(f'Audio formatting for user {user_id} failed') from exc + + out = process.communicate(data.read())[0] + buffer = io.BytesIO(out) + buffer.seek(0) + + try: + self.__process_queue.remove(process) + except ValueError: + pass + + if as_file: + return File(buffer, filename=f'{user_id}-{time.time()}-recording.mp3') + return buffer + + def _clean_process(self, process: subprocess.Popen) -> None: + _log.debug('Cleaning process %s for sink %s', process, self) + process.kill() + + def cleanup(self) -> None: + for process in self.__process_queue: + self._clean_process(process) + self.__process_queue.clear() + + for _, buffer in self.__audio_data.items(): + if not buffer.closed: + buffer.close() + + self.__audio_data.clear() + super().cleanup() 
diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index c4d0ed2b63..b84bab40c8 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -22,82 +22,228 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from collections import deque import io +import logging import os import subprocess import time +from typing import TYPE_CHECKING, Literal, overload + +from discord import utils +from discord.file import File +from discord.utils import MISSING + +from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .enums import SinkFilteringMode +from .errors import FFmpegNotFound, MP4SinkError, MaxProcessesCountReached, NoUserAdio + +if TYPE_CHECKING: + from discord import abc + +_log = logging.getLogger(__name__) -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import MP4SinkError +__all__ = ( + 'MP4ConverterHandler', + 'MP4Sink', +) + + +class MP4ConverterHandler(SinkHandler['MP4Sink']): + def handle_packet(self, sink: MP4Sink, user: abc.Snowflake, packet: RawData) -> None: + data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) + data.write(packet.decoded_data) class MP4Sink(Sink): """A special sink for .mp4 files. - .. versionadded:: 2.0 - """ + This is essentially a :class:`~.Sink` with a :class:`~.MP4ConverterHandler` handler + passed as a default. - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + .. versionadded:: 2.0 - self.encoding = "mp4" - self.vc = None - self.audio_data = {} + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. 
If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. + max_audio_processes_count: :class:`int` + The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then + when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. + """ - def format_audio(self, audio): - """Formats the recorded audio. + def __init__( + self, + *, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, + max_audio_processes_count: int = 10, + ) -> None: + self.__audio_data: dict[int, io.BytesIO] = {} + self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque(maxlen=max_audio_processes_count) + handlers = handlers or [] + handlers.append(MP4ConverterHandler()) + + super().__init__( + filters=filters, + filtering_mode=filtering_mode, + handlers=handlers, + ) + + def get_user_audio(self, user_id: int) -> io.BytesIO | None: + """Gets a user's saved audio data, or ``None``.""" + return self.__audio_data.get(user_id) + + def _create_audio_packet_for(self, uid: int) -> io.BytesIO: + data = self.__audio_data[uid] = io.BytesIO() + return data + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[True], + ) -> File: ... + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[False] = ..., + ) -> io.BytesIO: ... 
+ + def format_user_audio( + self, + user_id: int, + *, + executable: str = 'ffmpeg', + as_file: bool = False, + ) -> io.BytesIO | File: + """Formats a user's saved audio data. + + This should be called after the bot has stopped recording. + + If this is called during recording, there could be missing audio + packets. + + After this, the user's audio data will be resetted to 0 bytes and + seeked to 0. + + Parameters + ---------- + user_id: :class:`int` + The user ID of which format the audio data into a file. + executable: :class:`str` + The FFmpeg executable path to use for this formatting. It defaults + to ``ffmpeg``. + as_file: :class:`bool` + Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. + + Returns + ------- + Union[:class:`io.BytesIO`, :class:`~discord.File`] + The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` + object with the buffer set as the audio bytes. Raises - ------ - MP4SinkError - Audio may only be formatted after recording is finished. + ------- + NoUserAudio + You tried to format the audio of a user that was not stored in this sink. + FFmpegNotFound + The provided FFmpeg executable was not found. + MaxProcessesCountReached + You tried to go over the maximum processes count threshold. MP4SinkError - Formatting the audio failed. + Any error raised while formatting, wrapped around MP4SinkError. """ - if self.vc.recording: - raise MP4SinkError( - "Audio may only be formatted after recording is finished." 
- ) - mp4_file = f"{time.time()}.tmp" + + if len(self.__process_queue) >= 10: + raise MaxProcessesCountReached + + try: + data = self.__audio_data.pop(user_id) + except KeyError: + _log.info('There is no audio data for %s, ignoring.', user_id) + raise NoUserAdio + + temp_path = f'{user_id}-{time.time()}-recording.mp4.tmp' args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "mp4", - mp4_file, + executable, + '-f', + 's16le', + '-ar', + '48000', + '-loglevel', + 'error', + '-ac', + '2', + '-i', + '-', + '-f', + 'mp4', + temp_path, ] - if os.path.exists(mp4_file): - os.remove( - mp4_file - ) # process will get stuck asking whether to overwrite, if file already exists. + + if os.path.exists(temp_path): + found = utils.find(lambda d: d[0] == temp_path, self.__process_queue) + if found: + _, old_process = found + old_process.kill() + _log.info('Killing old process (%s) to write in %s', old_process, temp_path) + + os.remove(temp_path) # process would get stuck asking whether to overwrite, if file already exists. 
+ try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE - ) - except FileNotFoundError: - raise MP4SinkError("ffmpeg was not found.") from None + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE) + self.__process_queue.append((temp_path, process)) + except FileNotFoundError as exc: + raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise MP4SinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc + raise MP4SinkError(f'Audio formatting for user {user_id} failed') from exc - process.communicate(audio.file.read()) + process.communicate(data.read()) - with open(mp4_file, "rb") as f: - audio.file = io.BytesIO(f.read()) - audio.file.seek(0) - os.remove(mp4_file) + with open(temp_path, 'rb') as file: + buffer = io.BytesIO(file.read()) + buffer.seek(0) - audio.on_format(self.encoding) + try: + self.__process_queue.remove((temp_path, process)) + except ValueError: + pass + + if as_file: + return File(buffer, filename=f'{user_id}-{time.time()}-recording.mp4') + return buffer + + def _clean_process(self, path: str, process: subprocess.Popen) -> None: + _log.debug('Cleaning process %s for sink %s (with temporary file at %s)', process, self, path) + process.kill() + if os.path.exists(path): + os.remove(path) + + def cleanup(self) -> None: + for path, process in self.__process_queue: + self._clean_process(path, process) + self.__process_queue.clear() + + for _, buffer in self.__audio_data.items(): + if not buffer.closed: + buffer.close() + + self.__audio_data.clear() + super().cleanup() diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 7b531464bd..9755c34571 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -22,75 +22,212 @@ DEALINGS IN THE SOFTWARE. 
""" +from __future__ import annotations + +from collections import deque import io +import logging import subprocess +import time +from typing import TYPE_CHECKING, Literal, overload + +from discord.file import File +from discord.utils import MISSING + +from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .enums import SinkFilteringMode +from .errors import FFmpegNotFound, OGGSinkError, MaxProcessesCountReached, NoUserAdio + +if TYPE_CHECKING: + from discord import abc -from .core import CREATE_NO_WINDOW, Filters, Sink, default_filters -from .errors import OGGSinkError +_log = logging.getLogger(__name__) + +__all__ = ( + 'OGGConverterHandler', + 'OGGSink', +) + + +class OGGConverterHandler(SinkHandler['OGGSink']): + def handle_packet(self, sink: OGGSink, user: abc.Snowflake, packet: RawData) -> None: + data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) + data.write(packet.decoded_data) class OGGSink(Sink): """A special sink for .ogg files. + This is essentially a :class:`~.Sink` with a :class:`~.OGGConverterHandler` handler + passed as a default. + .. versionadded:: 2.0 + + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. + max_audio_processes_count: :class:`int` + The maximum of audio conversion processes that can be active concurrently. 
If this limit is exceeded, then + when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + def __init__( + self, + *, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, + max_audio_processes_count: int = 10, + ) -> None: + self.__audio_data: dict[int, io.BytesIO] = {} + self.__process_queue: deque[subprocess.Popen] = deque(maxlen=max_audio_processes_count) + handlers = handlers or [] + handlers.append(OGGConverterHandler()) + + super().__init__( + filters=filters, + filtering_mode=filtering_mode, + handlers=handlers, + ) + + def get_user_audio(self, user_id: int) -> io.BytesIO | None: + """Gets a user's saved audio data, or ``None``.""" + return self.__audio_data.get(user_id) + + def _create_audio_packet_for(self, uid: int) -> io.BytesIO: + data = self.__audio_data[uid] = io.BytesIO() + return data + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[True], + ) -> File: ... + + @overload + def format_user_audio( + self, + user_id: int, + *, + executable: str = ..., + as_file: Literal[False] = ..., + ) -> io.BytesIO: ... + + def format_user_audio( + self, + user_id: int, + *, + executable: str = 'ffmpeg', + as_file: bool = False, + ) -> io.BytesIO | File: + """Formats a user's saved audio data. + + This should be called after the bot has stopped recording. - self.encoding = "ogg" - self.vc = None - self.audio_data = {} + If this is called during recording, there could be missing audio + packets. - def format_audio(self, audio): - """Formats the recorded audio. + After this, the user's audio data will be resetted to 0 bytes and + seeked to 0. 
+ + Parameters + ---------- + user_id: :class:`int` + The user ID of which format the audio data into a file. + executable: :class:`str` + The FFmpeg executable path to use for this formatting. It defaults + to ``ffmpeg``. + as_file: :class:`bool` + Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. + + Returns + ------- + Union[:class:`io.BytesIO`, :class:`~discord.File`] + The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` + object with the buffer set as the audio bytes. Raises - ------ + ------- + NoUserAudio + You tried to format the audio of a user that was not stored in this sink. + FFmpegNotFound + The provided FFmpeg executable was not found. + MaxProcessesCountReached + You tried to go over the maximum processes count threshold. OGGSinkError - Audio may only be formatted after recording is finished. - OGGSinkError - Formatting the audio failed. + Any error raised while formatting, wrapped around OGGSinkError. """ - if self.vc.recording: - raise OGGSinkError( - "Audio may only be formatted after recording is finished." 
- ) + + if len(self.__process_queue) >= 10: + raise MaxProcessesCountReached + + try: + data = self.__audio_data.pop(user_id) + except KeyError: + _log.info('There is no audio data for %s, ignoring.', user_id) + raise NoUserAdio + args = [ - "ffmpeg", - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "ogg", - "pipe:1", + executable, + '-f', + 's16le', + '-ar', + '48000', + '-loglevel', + 'error', + '-ac', + '2', + '-i', + '-', + '-f', + 'ogg', + 'pipe:1', ] + try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - except FileNotFoundError: - raise OGGSinkError("ffmpeg was not found.") from None + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + self.__process_queue.append(process) + except FileNotFoundError as exc: + raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise OGGSinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc - - out = process.communicate(audio.file.read())[0] - out = io.BytesIO(out) - out.seek(0) - audio.file = out - audio.on_format(self.encoding) + raise OGGSinkError(f'Audio formatting for user {user_id} failed') from exc + + out = process.communicate(data.read())[0] + buffer = io.BytesIO(out) + buffer.seek(0) + + try: + self.__process_queue.remove(process) + except ValueError: + pass + + if as_file: + return File(buffer, filename=f'{user_id}-{time.time()}-recording.ogg') + return buffer + + def _clean_process(self, process: subprocess.Popen) -> None: + _log.debug('Cleaning process %s for sink %s', process, self) + process.kill() + + def cleanup(self) -> None: + for process in self.__process_queue: + self._clean_process(process) + self.__process_queue.clear() + + for _, buffer in self.__audio_data.items(): + if not buffer.closed: + buffer.close() + + self.__audio_data.clear() + super().cleanup() 
diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index c587da349a..88fcd6c29c 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -22,24 +22,145 @@ DEALINGS IN THE SOFTWARE. """ -from .core import Filters, Sink, default_filters +from __future__ import annotations + +import io +from typing import TYPE_CHECKING, Literal, overload + +from discord.file import File +from discord.utils import MISSING + +from .core import RawData, Sink, SinkHandler, SinkFilter +from .errors import NoUserAdio +from .enums import SinkFilteringMode + +if TYPE_CHECKING: + from discord import abc + +__all__ = ( + 'PCMConverterHandler', + 'PCMSink', +) + + +class PCMConverterHandler(SinkHandler['PCMSink']): + def handle_packet(self, sink: PCMSink, user: abc.Snowflake, packet: RawData) -> None: + data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) + data.write(packet.decoded_data) class PCMSink(Sink): """A special sink for .pcm files. + This is essentially a :class:`~.Sink` with a :class:`.PCMConverterHandler` handler + passed as a default. + .. versionadded:: 2.0 + + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. 
""" - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + def __init__( + self, + *, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, + ) -> None: + self.__audio_data: dict[int, io.BytesIO] = {} + handlers = handlers or [] + handlers.append(PCMConverterHandler()) + + super().__init__( + filters=filters, + filtering_mode=filtering_mode, + handlers=handlers, + ) + + def get_user_audio(self, user_id: int) -> io.BytesIO | None: + """Gets a user's saved audiop data, or ``None``.""" + return self.__audio_data.get(user_id) + + def _create_audio_packet_for(self, uid: int) -> io.BytesIO: + data = self.__audio_data[uid] = io.BytesIO() + return data + + @overload + def format_user_audio( + self, + user_id: int, + *, + as_file: Literal[True], + ) -> File: ... + + @overload + def format_user_audio( + self, + user_id: int, + *, + as_file: Literal[False] = ..., + ) -> io.BytesIO: ... + + def format_user_audio( + self, + user_id: int, + *, + as_file: bool = False, + ) -> io.BytesIO | File: + """Formats a user's saved audio data. + + This should be called after the bot has stopped recording. + + If this is called during recording, there could be missing audio + packets. + + After this, the user's audio data will be resetted to 0 bytes and + seeked to 0. + + Parameters + ---------- + user_id: :class:`int` + The user ID of which format the audio data into a file. + as_file: :class:`bool` + Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. + + Returns + ------- + Union[:class:`io.BytesIO`, :class:`~discord.File`] + The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` + object with the buffer set as the audio bytes. + + Raises + ------- + NoUserAudio + You tried to format the audio of a user that was not stored in this sink. 
+ """ + + try: + data = self.__audio_data.pop(user_id) + except KeyError: + raise NoUserAdio + + data.seek(0) + + if as_file: + return File(data, filename=f'{user_id}-recording.pcm') + return data - self.encoding = "pcm" - self.vc = None - self.audio_data = {} + def cleanup(self) -> None: + for _, buffer in self.__audio_data.items(): + if not buffer.closed: + buffer.close() - def format_audio(self, audio): - return + self.__audio_data.clear() + super().cleanup() diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index 37f5aac933..0cb14d158c 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -22,48 +22,164 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +import io +from typing import TYPE_CHECKING, Literal, overload import wave -from .core import Filters, Sink, default_filters -from .errors import WaveSinkError +from discord.file import File +from discord.utils import MISSING + +from .core import SinkFilter, SinkHandler, RawData, Sink +from .enums import SinkFilteringMode +from .errors import NoUserAdio + +if TYPE_CHECKING: + from discord import abc + +__all__ = ( + 'WaveConverterHandler', + 'WavConverterHandler', + 'WaveSink', + 'WavSink', +) + + +class WaveConverterHandler(SinkHandler['WaveSink']): + def handle_packet(self, sink: WaveSink, user: abc.Snowflake, packet: RawData) -> None: + data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) + data.write(packet.decoded_data) + + +WavConverterHandler: SinkHandler['WavSink'] = WaveConverterHandler # type: ignore class WaveSink(Sink): - """A special sink for .wav(wave) files. + """A special sink for .wav(e) files. + + This is essentially a :class:`~.Sink` with a :class:`.WaveConverterHandler` handler. .. versionadded:: 2.0 + + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. 
If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. """ - def __init__(self, *, filters=None): - if filters is None: - filters = default_filters - self.filters = filters - Filters.__init__(self, **self.filters) + def __init__( + self, + *, + filters: list[SinkFilter] = MISSING, + filtering_mode: SinkFilteringMode = SinkFilteringMode.all, + handlers: list[SinkHandler] = MISSING, + ) -> None: + self.__audio_data: dict[int, io.BytesIO] = {} + handlers = handlers or [] + handlers.append(WaveConverterHandler()) + + super().__init__( + filters=filters, + filtering_mode=filtering_mode, + handlers=handlers, + ) + + def get_user_audio(self, user_id: int) -> io.BytesIO | None: + """Gets a user's saved audiop data, or ``None``.""" + return self.__audio_data.get(user_id) + + def _create_audio_packet_for(self, uid: int) -> io.BytesIO: + data = self.__audio_data[uid] = io.BytesIO() + return data + + @overload + def format_user_audio( + self, + user_id: int, + *, + as_file: Literal[True], + ) -> File: ... - self.encoding = "wav" - self.vc = None - self.audio_data = {} + @overload + def format_user_audio( + self, + user_id: int, + *, + as_file: Literal[False] = ..., + ) -> io.BytesIO: ... - def format_audio(self, audio): - """Formats the recorded audio. + def format_user_audio( + self, + user_id: int, + *, + as_file: bool = False, + ) -> io.BytesIO | File: + """Formats a user's saved audio data. + + This should be called after the bot has stopped recording. + + If this is called during recording, there could be missing audio + packets. 
+ + After this, the user's audio data will be resetted to 0 bytes and + seeked to 0. + + Parameters + ---------- + user_id: :class:`int` + The user ID of which format the audio data into a file. + as_file: :class:`bool` + Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. + + Returns + ------- + Union[:class:`io.BytesIO`, :class:`~discord.File`] + The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` + object with the buffer set as the audio bytes. Raises - ------ - WaveSinkError - Audio may only be formatted after recording is finished. - WaveSinkError - Formatting the audio failed. + ------- + NoUserAudio + You tried to format the audio of a user that was not stored in this sink. """ - if self.vc.recording: - raise WaveSinkError( - "Audio may only be formatted after recording is finished." - ) - data = audio.file - with wave.open(data, "wb") as f: - f.setnchannels(self.vc.decoder.CHANNELS) - f.setsampwidth(self.vc.decoder.SAMPLE_SIZE // self.vc.decoder.CHANNELS) - f.setframerate(self.vc.decoder.SAMPLING_RATE) + try: + data = self.__audio_data.pop(user_id) + except KeyError: + raise NoUserAdio + + decoder = self.client.decoder + + with wave.open(data, 'wb') as f: + f.setnchannels(decoder.CHANNELS) + f.setsampwidth(decoder.SAMPLE_SIZE // decoder.CHANNELS) + f.setframerate(decoder.SAMPLING_RATE) data.seek(0) - audio.on_format(self.encoding) + + if as_file: + return File(data, filename=f'{user_id}-recording.pcm') + return data + + def cleanup(self) -> None: + for _, buffer in self.__audio_data.items(): + if not buffer.closed: + buffer.close() + + self.__audio_data.clear() + super().cleanup() + + +WavSink = WaveSink +"""An alias for :class:`~.WaveSink`. + +.. 
versionadded:: 2.7 +""" diff --git a/discord/voice/client.py b/discord/voice/client.py index 5a43c6db9d..27d9b53cb8 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -29,12 +29,13 @@ import datetime import logging import struct -from collections.abc import Callable +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any, Literal, overload from discord import opus from discord.errors import ClientException from discord.player import AudioPlayer, AudioSource +from discord.sinks.errors import RecordingException from discord.utils import MISSING from ._types import VoiceProtocol @@ -596,3 +597,95 @@ def elapsed(self) -> datetime.timedelta: if self._player: return datetime.timedelta(milliseconds=self._player.played_frames() * 20) return datetime.timedelta() + + def start_recording( + self, + sink: Sink, + callback: Callable[..., Coroutine[Any, Any, Any]] = MISSING, + *args: Any, + sync_start: bool = MISSING, + ) -> None: + r"""Start recording the audio from the current connected channel to the provided sink. + + .. versionadded:: 2.0 + .. versionchanged:: 2.7 + You can now have multiple concurrent recording sinks in the same voice client. + + Parameters + ---------- + sink: :class:`~.Sink` + A Sink in which all audio packets will be processed in. + callback: :ref:`coroutine ` + A function which is called after the bot has stopped recording. + + .. versionchanged:: 2.7 + This parameter is now optional. + \*args: + The arguments to pass to the callback coroutine. + sync_start: :class:`bool` + If ``True``, the recordings of subsequent users will start with silence. + This is useful for recording audio just as it was heard. + + .. warning:: + + This is a global voice client variable, this means, you can't have individual + sinks with different ``sync_start`` values. If you are willing to have such + functionality, you should consider creating your own :class:`discord.SinkHandler`. + + .. 
versionchanged:: 2.7 + This now defaults to ``MISSING``. + + Raises + ------ + RecordingException + Not connected to a voice channel + TypeError + You did not provide a Sink object. + """ + + if not self.is_connected(): + raise RecordingException('not connected to a voice channel') + if not isinstance(sink, Sink): + raise TypeError(f'expected a Sink object, got {sink.__class__.__name__}') + + if sync_start is not MISSING: + self._connection.sync_recording_start = sync_start + + sink.client = self + self._connection.add_sink(sink) + if callback is not MISSING: + self._connection.recording_done_callbacks.append((callback, args)) + + def stop_recording( + self, + *, + sink: Sink | None = None, + ) -> None: + """Stops the recording of the provided ``sink``, or all recording sinks. + + .. versionadded:: 2.0 + + Paremeters + ---------- + sink: :class:`discord.Sink` + The sink to stop recording. + + Raises + ------ + RecordingException + The provided sink is not currently recording, or if ``None``, you are not recording. 
+ """ + + if sink is not None: + try: + self._connection.sinks.remove(sink) + except ValueError: + raise RecordingException('the provided sink is not currently recording') + + sink.stop() + return + self._connection.stop_record_socket() + + def is_recording(self) -> bool: + """Whether the current client is recording in any sink.""" + return self._connection.is_recording() diff --git a/discord/voice/flags.py b/discord/voice/flags.py deleted file mode 100644 index 45c75fd21e..0000000000 --- a/discord/voice/flags.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-2021 Rapptz -Copyright (c) 2021-present Pycord Development - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" -from __future__ import annotations - -from discord.flags import BaseFlags, fill_with_flags, flag_value - - -@fill_with_flags() -class SpeakingFlags(BaseFlags): - r"""Wraps up a Discord user speaking state flag value. - - .. container:: operations - - .. 
describe:: x == y - - Checks if two flags are equal. - .. describe:: x != y - - Checks if two flags are not equal. - .. describe:: x + y - - Adds two flags together. Equivalent to ``x | y``. - .. describe:: x - y - - Substract two flags from each other. - .. describe:: x | y - - Returns the union of two flags. Equivalent to ``x + y``. - .. describe:: x & y - - Returns the intersection of two flags. - .. describe:: ~x - - Returns the inverse of a flag. - .. describe:: hash(x) - - Returns the flag's hash. - .. describe:: iter(x) - - Returns an iterator of ``(name, value)`` pairs. This allows it - to be, for example, constructed as a dict or a list of pairs. - - .. versionadded:: 2.7 - - Attributes - ---------- - value: :class:`int` - The raw value. This value is a bit array field of a 53-bit integer - representing the currently available flags. You should query - flags via the properties rather than using this raw value. - """ - - @flag_value - def voice(self): - """:class:`bool`: Normal transmission of voice audio""" - return 1 << 0 - - @flag_value - def soundshare(self): - """:class:`bool`: Transmission of context audio for video, no speaking indicator""" - return 1 << 1 - - @flag_value - def priority(self): - """:class:`bool`: Priority speaker, lowering audio of other speakers""" - return 1 << 2 diff --git a/discord/voice/state.py b/discord/voice/state.py index acb16dbcac..f6e0ebb174 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -26,21 +26,25 @@ from __future__ import annotations import asyncio +from collections import deque import logging import select import socket +import struct import threading from collections.abc import Callable, Coroutine +import time from typing import TYPE_CHECKING, Any, TypedDict -from discord import utils +from discord import utils, opus from discord.backoff import ExponentialBackoff +from discord.enums import try_enum, SpeakingState from discord.errors import ConnectionClosed from discord.object import Object 
+from discord.sinks import RawData, Sink from .enums import ConnectionFlowState, OpCodes from .gateway import VoiceWebSocket -from .flags import SpeakingFlags if TYPE_CHECKING: from discord import abc @@ -49,6 +53,7 @@ from discord.raw_models import RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent from discord.types.voice import SupportedModes from discord.user import ClientUser + from discord.state import ConnectionState from .client import VoiceClient @@ -62,6 +67,7 @@ def __init__( self, state: VoiceConnectionState, name: str, + buffer_size: int, *, start_paused: bool = True, ) -> None: @@ -70,6 +76,7 @@ def __init__( name=name, ) + self.buffer_size: int = buffer_size self.state: VoiceConnectionState = state self.start_paused: bool = start_paused self._callbacks: list[SocketReaderCallback] = [] @@ -101,6 +108,9 @@ def pause(self) -> None: self._idle_paused = False self._running.clear() + def is_paused(self) -> bool: + return self._idle_paused or (not self._running.is_set() and not self._end.is_set()) + def resume(self, *, force: bool = False) -> None: if self._running.is_set(): return @@ -137,20 +147,6 @@ def run(self) -> None: self._running.clear() self._callbacks.clear() - def _do_run(self) -> None: - raise NotImplementedError - - -class SocketVoiceRecvReader(SocketReader): - def __init__( - self, state: VoiceConnectionState, *, start_paused: bool = True, - ) -> None: - super().__init__( - state, - f'voice-recv-socket-reader:{id(self):#x}', - start_paused=start_paused, - ) - def _do_run(self) -> None: while not self._end.is_set(): if not self._running.is_set(): @@ -171,7 +167,7 @@ def _do_run(self) -> None: continue try: - data = self.state.socket.recv(4096) + data = self.state.socket.recv(self.buffer_size) except OSError: _log.debug( 'Error reading from socket in %s, this should be safe to ignore', @@ -190,6 +186,18 @@ def _do_run(self) -> None: ) +class SocketVoiceRecvReader(SocketReader): + def __init__( + self, state: VoiceConnectionState, *, 
start_paused: bool = True, + ) -> None: + super().__init__( + state, + f'voice-recv-socket-reader:{id(self):#x}', + 4096, + start_paused=start_paused, + ) + + class SocketEventReader(SocketReader): def __init__( self, state: VoiceConnectionState, *, start_paused: bool = True @@ -197,9 +205,88 @@ def __init__( super().__init__( state, f'voice-socket-event-reader:{id(self):#x}', + 2048, start_paused=start_paused, ) + +class DecoderThread(threading.Thread, opus._OpusStruct): + def __init__( + self, state: VoiceConnectionState, *, start_paused: bool = True, + ) -> None: + super().__init__( + daemon=True, + name=f'voice-recv-decoder-thread:{id(self):#x}', + ) + + self.state: VoiceConnectionState = state + self.client: VoiceClient = state.client + self.start_paused: bool = start_paused + self._idle_paused: bool = True + self._started: threading.Event = threading.Event() + self._running: threading.Event = threading.Event() + self._end: threading.Event = threading.Event() + + self.decode_queue: deque[RawData] = deque() + self.decoders: dict[int, opus.Decoder] = {} + + self._end: threading.Event = threading.Event() + + def decode(self, frame: RawData) -> None: + if not isinstance(frame, RawData): + raise TypeError(f'expected a RawData object, got {frame.__class__.__name__}') + self.decode_queue.append(frame) + + def is_running(self) -> bool: + return self._started.is_set() + + def pause(self) -> None: + self._idle_paused = False + self._running.clear() + + def resume(self, *, force: bool = False) -> None: + if self._running.is_set(): + return + + if not force and not self.decode_queue: + self._idle_paused = True + return + + self._idle_paused = False + self._running.set() + + def stop(self) -> None: + self._started.clear() + self._end.set() + self._running.set() + + def run(self) -> None: + self._started.set() + self._end.clear() + self._running.set() + + if self.start_paused: + self.pause() + + try: + self._do_run() + except Exception: + _log.exception( + 'An error ocurred 
while running the decoder thread %s', + self.name, + ) + finally: + self.stop() + self._running.clear() + self.decode_queue.clear() + + def get_decoder(self, ssrc: int) -> opus.Decoder: + try: + return self.decoders[ssrc] + except KeyError: + d = self.decoders[ssrc] = opus.Decoder() + return d + def _do_run(self) -> None: while not self._end.is_set(): if not self._running.is_set(): @@ -207,41 +294,29 @@ def _do_run(self) -> None: continue try: - readable, _, _ = select.select([self.state.socket], [], [], 30) - except (ValueError, TypeError, OSError) as e: - _log.debug( - "Select error handling socket in reader, this should be safe to ignore: %s: %s", - e.__class__.__name__, - e, - ) - continue - - if not readable: + data = self.decode_queue.popleft() + except IndexError: continue try: - data = self.state.socket.recv(2048) - except OSError: - _log.debug( - "Error reading from socket in %s, this should be safe to ignore.", - self, + if data.decrypted_data is None: + continue + else: + data.decoded_data = self.get_decoder(data.ssrc).decode( + data.decrypted_data, + ) + except opus.OpusError: + _log.exception( + 'Error ocurred while decoding opus frame', exc_info=True, ) - else: - for cb in self._callbacks: - try: - cb(data) - except Exception: - _log.exception( - "Error while calling %s in %s", - cb, - self, - ) + + self.state.dispatch_packet_sinks(data) class SSRC(TypedDict): user_id: int - speaking: SpeakingFlags + speaking: SpeakingState class VoiceConnectionState: @@ -255,6 +330,7 @@ def __init__( ) -> None: self.client: VoiceClient = client self.hook = hook + self.loop: asyncio.AbstractEventLoop = client.loop self.timeout: float = 30.0 self.reconnect: bool = True @@ -274,6 +350,7 @@ def __init__( self.session_id: str | None = None self.token: str | None = None + self._connection: ConnectionState = client._state self._state: ConnectionFlowState = ConnectionFlowState.disconnected self._expecting_disconnect: bool = False self._connected = threading.Event() @@ -284,7 
+361,15 @@ def __init__( self._socket_reader = SocketEventReader(self) self._socket_reader.start() self._voice_recv_socket = SocketVoiceRecvReader(self) + self._voice_recv_socket.register(self.handle_voice_recv_packet) + self._decoder_thread = DecoderThread(self) self.user_ssrc_map: dict[int, SSRC] = {} + self.user_voice_timestamps: dict[int, tuple[int, float]] = {} + self.sync_recording_start: bool = False + self.first_received_packet_ts: float = MISSING + self.sinks: list[Sink] = [] + self.recording_done_callbacks: list[tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]]] = [] + self.__sink_dispatch_task_set: set[asyncio.Task[Any]] = set() def start_record_socket(self) -> None: if self._voice_recv_socket.is_running(): @@ -295,13 +380,142 @@ def stop_record_socket(self) -> None: if self._voice_recv_socket.is_running(): self._voice_recv_socket.stop() + for cb, args in self.recording_done_callbacks: + task = self.loop.create_task(cb(*args)) + self.__sink_dispatch_task_set.add(task) + task.add_done_callback(self.__sink_dispatch_task_set.remove) + + for sink in self.sinks: + sink.stop() + + self.recording_done_callbacks.clear() + self.sinks.clear() + + def handle_voice_recv_packet(self, packet: bytes) -> None: + if packet[1] != 0x78: + # We should ignore any payload types we do not understand + # Ref: RFC 3550 5.1 payload type + # At some point we noted that we should ignore only types 200 - 204 inclusive. 
+ # They were marked as RTCP: provides information about the connection + # this was too broad of a whitelist, it is unclear if this is too narrow of a whitelist + return + + if self.paused_recording(): + return + + data = RawData(packet, self.client) + + if data.decrypted_data != opus.OPUS_SILENCE: + return + + self._decoder_thread.decode(data) + + def is_first_packet(self) -> bool: + return not self.user_voice_timestamps or not self.sync_recording_start + + def dispatch_packet_sinks(self, data: RawData) -> None: + + if data.ssrc not in self.user_ssrc_map: + if self.is_first_packet(): + self.first_received_packet_ts = data.receive_time + silence = 0 + else: + silence = ( + (data.receive_time - self.first_received_packet_ts) * 48000 + ) + else: + stored_timestamp, stored_recv_time = self.user_voice_timestamps[data.ssrc] + dRT = data.receive_time - stored_recv_time * 48000 + dT = data.timestamp - stored_timestamp + diff = abs(100 - dT * 100 / dRT) + + if diff > 60 and dT != 960: + silence = dRT - 960 + else: + silence = dT - 960 + + self.user_voice_timestamps[data.ssrc] = (data.timestamp, data.receive_time) + + data.decoded_data = ( + struct.pack(' None: + user = self.get_user_by_ssrc(data.ssrc) + if not user: + _log.debug( + 'Ignoring received packet %s because the SSRC was waited for but was not found', + data, + ) + return + + data.user_id = user.id + + for sink in self.sinks: + if sink.is_paused(): + continue + + sink.dispatch('unfiltered_voice_packet_receive', user, data) + + futures = [ + self.loop.create_task(utils.maybe_coroutine(fil.filter_packet, sink, user, data)) + for fil in sink._filters + ] + strat = sink._filter_strat + + done, pending = await asyncio.wait(futures) + + if pending: + for task in pending: + task.set_result(False) + + done = (*done, *pending) + + if strat([f.result() for f in done]): + sink.dispatch('voice_packet_receive', user, data) + sink._call_voice_packet_handlers(user, data) + + def is_recording(self) -> bool: + return 
self._voice_recv_socket.is_running() + + def paused_recording(self) -> bool: + return self._voice_recv_socket.is_paused() + + def add_sink(self, sink: Sink) -> None: + self.sinks.append(sink) + self.start_record_socket() + + def remove_sink(self, sink: Sink) -> None: + try: + self.sinks.remove(sink) + except ValueError: + pass + def get_user_by_ssrc(self, ssrc: int) -> abc.Snowflake | None: data = self.user_ssrc_map.get(ssrc) if data is None: return None user = int(data['user_id']) - return self.guild.get_member(user) or self.client._state.get_user(user) or Object(id=user) + return self.get_user(user) + + def get_user(self, id: int) -> abc.Snowflake: + state = self._connection + return ( + self.guild.get_member(id) or + state.get_user(id) or + Object(id=id) + ) def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: op = msg['op'] @@ -310,16 +524,48 @@ def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: if op == OpCodes.speaking: ssrc = data['ssrc'] user = int(data['user_id']) - speaking = data['speaking'] + raw_speaking = data['speaking'] + speaking = try_enum(SpeakingState, raw_speaking) + old_data = self.user_ssrc_map.get(ssrc) + old_speaking = (old_data or {}).get('speaking', SpeakingState.none) + + self._dispatch_speaking_state(old_speaking, speaking, user) - if ssrc in self.user_ssrc_map: - self.user_ssrc_map[ssrc]['speaking'].value = speaking + if old_data is None: + self.user_ssrc_map[ssrc]['speaking'] = speaking else: self.user_ssrc_map[ssrc] = { 'user_id': user, - 'speaking': SpeakingFlags._from_value(speaking), + 'speaking': speaking, } + def _dispatch_speaking_state(self, before: SpeakingState, after: SpeakingState, uid: int) -> None: + resolved = self.get_user(uid) + + for sink in self.sinks: + if sink.is_paused(): + continue + + sink.dispatch('unfiltered_speaking_state_update', resolved, before, after) + + futures = [ + self.loop.create_task(utils.maybe_coroutine(fil.filter_packet, sink, user, data)) + for fil in 
sink._filters + ] + strat = sink._filter_strat + + done, pending = await asyncio.wait(futures) + + if pending: + for task in pending: + task.set_result(False) + + done = (*done, *pending) + + if strat([f.result() for f in done]): + sink.dispatch('speaking_state_update', resolved, before, after) + sink._call_speaking_state_handlers(resolved, before, after) + @property def state(self) -> ConnectionFlowState: return self._state From 83036da0fe40430db62acf689cc044894cb1ae2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Aug 2025 00:07:01 +0000 Subject: [PATCH 23/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/sinks/__init__.py | 2 +- discord/sinks/core.py | 158 ++++++++++++++++++++++++-------------- discord/sinks/enums.py | 5 +- discord/sinks/m4a.py | 76 ++++++++++-------- discord/sinks/mka.py | 65 +++++++++------- discord/sinks/mkv.py | 62 ++++++++------- discord/sinks/mp3.py | 65 +++++++++------- discord/sinks/mp4.py | 77 +++++++++++-------- discord/sinks/ogg.py | 65 +++++++++------- discord/sinks/pcm.py | 18 +++-- discord/sinks/wave.py | 26 ++++--- discord/voice/client.py | 16 ++-- discord/voice/state.py | 102 +++++++++++++----------- 13 files changed, 431 insertions(+), 306 deletions(-) diff --git a/discord/sinks/__init__.py b/discord/sinks/__init__.py index 6d38ae74f0..15498aff27 100644 --- a/discord/sinks/__init__.py +++ b/discord/sinks/__init__.py @@ -9,8 +9,8 @@ """ from .core import * -from .errors import * from .enums import * +from .errors import * from .m4a import * from .mka import * from .mkv import * diff --git a/discord/sinks/core.py b/discord/sinks/core.py index eb044ea600..b7af990692 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -26,12 +26,12 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine, Iterable -from functools import partial import logging import struct import sys import time 
+from collections.abc import Callable, Coroutine, Iterable +from functools import partial from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload from discord import utils @@ -44,16 +44,17 @@ from typing_extensions import ParamSpec from discord import abc + from ..voice.client import VoiceClient - R = TypeVar('R') - P = ParamSpec('P') + R = TypeVar("R") + P = ParamSpec("P") __all__ = ( "Sink", "RawData", - 'SinkFilter', - 'SinkHandler', + "SinkFilter", + "SinkHandler", ) @@ -63,7 +64,7 @@ CREATE_NO_WINDOW = 0x08000000 -S = TypeVar('S', bound='Sink') +S = TypeVar("S", bound="Sink") _log = logging.getLogger(__name__) @@ -76,12 +77,16 @@ class SinkFilter(Generic[S]): """ @overload - async def filter_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> bool: ... + async def filter_packet( + self, sink: S, user: abc.Snowflake, packet: RawData + ) -> bool: ... @overload def filter_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> bool: ... - def filter_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> bool | Coroutine[Any, Any, bool]: + def filter_packet( + self, sink: S, user: abc.Snowflake, packet: RawData + ) -> bool | Coroutine[Any, Any, bool]: """|maybecoro| This is called automatically everytime a voice packet is received. @@ -102,15 +107,21 @@ def filter_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> bool | :class:`bool` Whether the filter was successful. """ - raise NotImplementedError('subclasses must implement this') + raise NotImplementedError("subclasses must implement this") @overload - async def filter_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> bool: ... + async def filter_speaking_state( + self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> bool: ... @overload - def filter_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> bool: ... 
+ def filter_speaking_state( + self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> bool: ... - def filter_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> bool | Coroutine[Any, Any, bool]: + def filter_speaking_state( + self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> bool | Coroutine[Any, Any, bool]: """|maybecoro| This is called automatically everytime a speaking state is updated. @@ -133,11 +144,10 @@ def filter_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingSt :class:`bool` Whether the filter was successful. """ - raise NotImplementedError('subclasses must implement this') + raise NotImplementedError("subclasses must implement this") def cleanup(self) -> None: """A function called when the filter is ready for cleanup.""" - pass class SinkHandler(Generic[S]): @@ -149,12 +159,16 @@ class SinkHandler(Generic[S]): """ @overload - async def handle_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> Any: ... + async def handle_packet( + self, sink: S, user: abc.Snowflake, packet: RawData + ) -> Any: ... @overload def handle_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> Any: ... - def handle_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> Any | Coroutine[Any, Any, Any]: + def handle_packet( + self, sink: S, user: abc.Snowflake, packet: RawData + ) -> Any | Coroutine[Any, Any, Any]: """|maybecoro| This is called automatically everytime a voice packet which has successfully passed the filters is received. @@ -168,15 +182,20 @@ def handle_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> Any | packet: :class:`~.RawData` The raw data packet. """ - pass @overload - async def handle_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> Any: ... 
+ async def handle_speaking_state( + self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> Any: ... @overload - def handle_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> Any: ... + def handle_speaking_state( + self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> Any: ... - def handle_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> Any | Coroutine[Any, Any, Any]: + def handle_speaking_state( + self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> Any | Coroutine[Any, Any, Any]: """|maybecoro| This is called automatically everytime a speaking state update is received which has successfully passed the filters. @@ -192,11 +211,9 @@ def handle_speaking_state(self, sink: S, user: abc.Snowflake, before: SpeakingSt after: :class:`~discord.SpeakingState` The speaking state after the update. """ - pass def cleanup(self) -> None: """A function called when the handler is ready for cleanup.""" - pass class RawData: @@ -340,10 +357,12 @@ def __init_subclass__(cls) -> None: elif isinstance(value, classmethod): value = partial(value.__func__, cls) - if not hasattr(value, '__listener__'): + if not hasattr(value, "__listener__"): continue - event_name = getattr(value, '__listener_name__', elem).removeprefix('on_') + event_name = getattr(value, "__listener_name__", elem).removeprefix( + "on_" + ) try: listeners[event_name].append(value) @@ -364,7 +383,9 @@ def __init__( self._filters: list[SinkFilter] = filters or [] self._handlers: list[SinkHandler] = handlers or [] self.__dispatch_set: set[asyncio.Task[Any]] = set() - self._listeners: dict[str, list[Callable[[Iterable[object]], bool]]] = self.__listeners__ + self._listeners: dict[str, list[Callable[[Iterable[object]], bool]]] = ( + self.__listeners__ + ) @property def filtering_mode(self) -> SinkFilteringMode: @@ -377,13 +398,15 @@ 
def filtering_mode(self, value: SinkFilteringMode) -> None: elif value is SinkFilteringMode.any: self._filter_strat = any else: - raise TypeError(f'expected a FilteringMode enum member, got {value.__class__.__name__}') + raise TypeError( + f"expected a FilteringMode enum member, got {value.__class__.__name__}" + ) self.__filtering_mode = value def dispatch(self, event: str, *args: Any, **kwargs: Any) -> Any: - _log.debug('Dispatching sink %s event %s', self.__class__.__name__, event) - method = f'on_{event}' + _log.debug("Dispatching sink %s event %s", self.__class__.__name__, event) + method = f"on_{event}" listeners = self.__listeners__.get(event, []) for coro in listeners: @@ -401,7 +424,7 @@ async def _run_event( coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, - **kwargs: Any + **kwargs: Any, ) -> None: try: await coro(*args, **kwargs) @@ -426,7 +449,9 @@ def _call_voice_packet_handlers(self, user: abc.Snowflake, packet: RawData) -> N self.__dispatch_set.add(task) task.add_done_callback(self.__dispatch_set.remove) - def _call_speaking_state_handlers(self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> None: + def _call_speaking_state_handlers( + self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> None: for handler in self._handlers: task = asyncio.create_task( utils.maybe_coroutine( @@ -449,13 +474,13 @@ def _schedule_event( ) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) - task = asyncio.create_task(wrapped, name=f'sinks: {event_name}') + task = asyncio.create_task(wrapped, name=f"sinks: {event_name}") self.__dispatch_set.add(task) task.add_done_callback(self.__dispatch_set.discard) return task def __repr__(self) -> str: - return f'<{self.__class__.__name__} id={id(self):#x}>' + return f"<{self.__class__.__name__} id={id(self):#x}>" def stop(self) -> None: """Stops this sink's recording. 
@@ -499,7 +524,9 @@ def add_filter(self, filter: SinkFilter, /) -> None: """ if not isinstance(filter, SinkFilter): - raise TypeError(f'expected a Filter object, not {filter.__class__.__name__}') + raise TypeError( + f"expected a Filter object, not {filter.__class__.__name__}" + ) self._filters.append(filter) def remove_filter(self, filter: SinkFilter, /) -> None: @@ -531,7 +558,9 @@ def add_handler(self, handler: SinkHandler, /) -> None: """ if not isinstance(handler, SinkHandler): - raise TypeError(f'expected a Handler object, not {handler.__class__.__name__}') + raise TypeError( + f"expected a Handler object, not {handler.__class__.__name__}" + ) self._handlers.append(handler) def remove_handler(self, handler: SinkHandler, /) -> None: @@ -549,22 +578,16 @@ def remove_handler(self, handler: SinkHandler, /) -> None: pass @staticmethod - def listener(event: str = MISSING) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: + def listener( + event: str = MISSING, + ) -> Callable[ + [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] + ]: """Registers a function to be an event listener for this sink. The events must be a :ref:`coroutine `, if not, :exc:`TypeError` is raised; and also must be inside a sink class. - Example - ------- - - .. code-block:: python3 - - class MySink(Sink): - @Sink.listener() - async def on_member_speaking_state_update(member, ssrc, state): - pass - Parameters ---------- event: :class:`str` @@ -574,41 +597,64 @@ async def on_member_speaking_state_update(member, ssrc, state): ------ TypeError The coroutine passed is not actually a coroutine, or the listener is not in a sink class. + + Example + ------- + + .. 
code-block:: python3 + + class MySink(Sink): + @Sink.listener() + async def on_member_speaking_state_update(member, ssrc, state): + pass """ - def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: - parts = func.__qualname__.split('.') + def decorator( + func: Callable[P, Coroutine[Any, Any, R]], + ) -> Callable[P, Coroutine[Any, Any, R]]: + parts = func.__qualname__.split(".") if not parts or not len(parts) > 1: - raise TypeError('event listeners must be declared in a Sink class') + raise TypeError("event listeners must be declared in a Sink class") if parts[-1] != func.__name__: - raise NameError('qualified name and function name mismatch, this should not happen') + raise NameError( + "qualified name and function name mismatch, this should not happen" + ) if not asyncio.iscoroutinefunction(func): - raise TypeError('event listeners must be coroutine functions') + raise TypeError("event listeners must be coroutine functions") func.__listener__ = True if event is not MISSING: func.__listener_name__ = event return func + return decorator async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: pass - async def on_unfiltered_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + async def on_unfiltered_voice_packet_receive( + self, user: abc.Snowflake, data: RawData + ) -> None: pass - async def on_speaking_state_update(self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> None: + async def on_speaking_state_update( + self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> None: pass - async def on_unfiltered_speaking_state_update(self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState) -> None: + async def on_unfiltered_speaking_state_update( + self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState + ) -> None: pass - async def on_error(self, event: str, exception: Exception, *args: Any, **kwargs: Any) -> 
None: + async def on_error( + self, event: str, exception: Exception, *args: Any, **kwargs: Any + ) -> None: _log.exception( - 'An error ocurred in sink %s while dispatching the event %s', + "An error ocurred in sink %s while dispatching the event %s", self, event, exc_info=exception, diff --git a/discord/sinks/enums.py b/discord/sinks/enums.py index 9dfb95e4cd..f09daf8b56 100644 --- a/discord/sinks/enums.py +++ b/discord/sinks/enums.py @@ -22,13 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations from discord.enums import Enum -__all__ = ( - 'SinkFilteringMode', -) +__all__ = ("SinkFilteringMode",) class SinkFilteringMode(Enum): diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 66376228f7..dd5ff7fb8d 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -21,21 +21,22 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations -from collections import deque import io import logging import os import subprocess import time +from collections import deque from typing import TYPE_CHECKING, Literal, overload from discord import utils from discord.file import File from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode from .errors import FFmpegNotFound, M4ASinkError, MaxProcessesCountReached, NoUserAdio @@ -45,13 +46,15 @@ _log = logging.getLogger(__name__) __all__ = ( - 'M4AConverterHandler', - 'M4ASink', + "M4AConverterHandler", + "M4ASink", ) -class M4AConverterHandler(SinkHandler['M4ASink']): - def handle_packet(self, sink: M4ASink, user: abc.Snowflake, packet: RawData) -> None: +class M4AConverterHandler(SinkHandler["M4ASink"]): + def handle_packet( + self, sink: M4ASink, user: abc.Snowflake, packet: RawData + ) 
-> None: data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) data.write(packet.decoded_data) @@ -89,7 +92,9 @@ def __init__( max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque(maxlen=max_audio_processes_count) + self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque( + maxlen=max_audio_processes_count + ) handlers = handlers or [] handlers.append(M4AConverterHandler()) @@ -129,7 +134,7 @@ def format_user_audio( self, user_id: int, *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", as_file: bool = False, ) -> io.BytesIO | File: """Formats a user's saved audio data. @@ -159,7 +164,7 @@ def format_user_audio( object with the buffer set as the audio bytes. Raises - ------- + ------ NoUserAudio You tried to format the audio of a user that was not stored in this sink. FFmpegNotFound @@ -176,24 +181,24 @@ def format_user_audio( try: data = self.__audio_data.pop(user_id) except KeyError: - _log.info('There is no audio data for %s, ignoring.', user_id) + _log.info("There is no audio data for %s, ignoring.", user_id) raise NoUserAdio - temp_path = f'{user_id}-{time.time()}-recording.m4a.tmp' + temp_path = f"{user_id}-{time.time()}-recording.m4a.tmp" args = [ executable, - '-f', - 's16le', - '-ar', - '48000', - '-loglevel', - 'error', - '-ac', - '2', - '-i', - '-', - '-f', - 'ipod', + "-f", + "s16le", + "-ar", + "48000", + "-loglevel", + "error", + "-ac", + "2", + "-i", + "-", + "-f", + "ipod", temp_path, ] @@ -202,21 +207,27 @@ def format_user_audio( if found: _, old_process = found old_process.kill() - _log.info('Killing old process (%s) to write in %s', old_process, temp_path) + _log.info( + "Killing old process (%s) to write in %s", old_process, temp_path + ) - os.remove(temp_path) # process would get stuck asking whether to overwrite, if file already exists. 
+ os.remove( + temp_path + ) # process would get stuck asking whether to overwrite, if file already exists. try: - process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE) + process = subprocess.Popen( + args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE + ) self.__process_queue.append((temp_path, process)) except FileNotFoundError as exc: raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise M4ASinkError(f'Audio formatting for user {user_id} failed') from exc + raise M4ASinkError(f"Audio formatting for user {user_id} failed") from exc process.communicate(data.read()) - with open(temp_path, 'rb') as file: + with open(temp_path, "rb") as file: buffer = io.BytesIO(file.read()) buffer.seek(0) @@ -226,11 +237,16 @@ def format_user_audio( pass if as_file: - return File(buffer, filename=f'{user_id}-{time.time()}-recording.m4a') + return File(buffer, filename=f"{user_id}-{time.time()}-recording.m4a") return buffer def _clean_process(self, path: str, process: subprocess.Popen) -> None: - _log.debug('Cleaning process %s for sink %s (with temporary file at %s)', process, self, path) + _log.debug( + "Cleaning process %s for sink %s (with temporary file at %s)", + process, + self, + path, + ) process.kill() if os.path.exists(path): os.remove(path) diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index 8dccea1d7c..eac6fbab00 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -24,19 +24,19 @@ from __future__ import annotations -from collections import deque import io import logging import subprocess import time +from collections import deque from typing import TYPE_CHECKING, Literal, overload from discord.file import File from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MKASinkError, 
MaxProcessesCountReached, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, MKASinkError, NoUserAdio if TYPE_CHECKING: from discord import abc @@ -44,13 +44,15 @@ _log = logging.getLogger(__name__) __all__ = ( - 'MKAConverterHandler', - 'MKASink', + "MKAConverterHandler", + "MKASink", ) -class MKAConverterHandler(SinkHandler['MKASink']): - def handle_packet(self, sink: MKASink, user: abc.Snowflake, packet: RawData) -> None: +class MKAConverterHandler(SinkHandler["MKASink"]): + def handle_packet( + self, sink: MKASink, user: abc.Snowflake, packet: RawData + ) -> None: data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) data.write(packet.decoded_data) @@ -88,7 +90,9 @@ def __init__( max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[subprocess.Popen] = deque(maxlen=max_audio_processes_count) + self.__process_queue: deque[subprocess.Popen] = deque( + maxlen=max_audio_processes_count + ) handlers = handlers or [] handlers.append(MKAConverterHandler()) @@ -128,7 +132,7 @@ def format_user_audio( self, user_id: int, *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", as_file: bool = False, ) -> io.BytesIO | File: """Formats a user's saved audio data. @@ -158,7 +162,7 @@ def format_user_audio( object with the buffer set as the audio bytes. Raises - ------- + ------ NoUserAudio You tried to format the audio of a user that was not stored in this sink. 
FFmpegNotFound @@ -175,33 +179,38 @@ def format_user_audio( try: data = self.__audio_data.pop(user_id) except KeyError: - _log.info('There is no audio data for %s, ignoring.', user_id) + _log.info("There is no audio data for %s, ignoring.", user_id) raise NoUserAdio args = [ executable, - '-f', - 's16le', - '-ar', - '48000', - '-loglevel', - 'error', - '-ac', - '2', - '-i', - '-', - '-f', - 'matroska', - 'pipe:1', + "-f", + "s16le", + "-ar", + "48000", + "-loglevel", + "error", + "-ac", + "2", + "-i", + "-", + "-f", + "matroska", + "pipe:1", ] try: - process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + process = subprocess.Popen( + args, + creationflags=CREATE_NO_WINDOW, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) self.__process_queue.append(process) except FileNotFoundError as exc: raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise MKASinkError(f'Audio formatting for user {user_id} failed') from exc + raise MKASinkError(f"Audio formatting for user {user_id} failed") from exc out = process.communicate(data.read())[0] buffer = io.BytesIO(out) @@ -213,11 +222,11 @@ def format_user_audio( pass if as_file: - return File(buffer, filename=f'{user_id}-{time.time()}-recording.mka') + return File(buffer, filename=f"{user_id}-{time.time()}-recording.mka") return buffer def _clean_process(self, process: subprocess.Popen) -> None: - _log.debug('Cleaning process %s for sink %s', process, self) + _log.debug("Cleaning process %s for sink %s", process, self) process.kill() def cleanup(self) -> None: diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index 283ac76344..e1ccfc8299 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -24,19 +24,19 @@ from __future__ import annotations -from collections import deque import io import logging import subprocess import time +from collections import deque from typing import TYPE_CHECKING, Literal, overload from discord.file 
import File from discord.utils import MISSING -from .core import SinkHandler, Sink, SinkFilter, RawData +from .core import RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MKVSinkError, MaxProcessesCountReached, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, MKVSinkError, NoUserAdio if TYPE_CHECKING: from discord import abc @@ -44,13 +44,15 @@ _log = logging.getLogger(__name__) __all__ = ( - 'MKVConverterHandler', - 'MKVSink', + "MKVConverterHandler", + "MKVSink", ) -class MKVConverterHandler(SinkHandler['MKVSink']): - def handle_packet(self, sink: MKVSink, user: abc.Snowflake, packet: RawData) -> None: +class MKVConverterHandler(SinkHandler["MKVSink"]): + def handle_packet( + self, sink: MKVSink, user: abc.Snowflake, packet: RawData + ) -> None: data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) data.write(packet.decoded_data) @@ -88,7 +90,9 @@ def __init__( max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[subprocess.Popen] = deque(maxlen=max_audio_processes_count) + self.__process_queue: deque[subprocess.Popen] = deque( + maxlen=max_audio_processes_count + ) handlers = handlers or [] handlers.append(MKVConverterHandler()) @@ -128,7 +132,7 @@ def format_user_audio( self, user_id: int, *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", as_file: bool = False, ) -> io.BytesIO | File: """Formats a user's saved audio data. @@ -158,7 +162,7 @@ def format_user_audio( object with the buffer set as the audio bytes. Raises - ------- + ------ NoUserAudio You tried to format the audio of a user that was not stored in this sink. 
FFmpegNotFound @@ -175,33 +179,35 @@ def format_user_audio( try: data = self.__audio_data.pop(user_id) except KeyError: - _log.info('There is no audio data for %s, ignoring.', user_id) + _log.info("There is no audio data for %s, ignoring.", user_id) raise NoUserAdio args = [ executable, - '-f', - 's16le', - '-ar', - '48000', - '-loglevel', - 'error', - '-ac', - '2', - '-i', - '-', - '-f', - 'matroska', - 'pipe:1', + "-f", + "s16le", + "-ar", + "48000", + "-loglevel", + "error", + "-ac", + "2", + "-i", + "-", + "-f", + "matroska", + "pipe:1", ] try: - process = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + process = subprocess.Popen( + args, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) self.__process_queue.append(process) except FileNotFoundError as exc: raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise MKVSinkError(f'Audio formatting for user {user_id} failed') from exc + raise MKVSinkError(f"Audio formatting for user {user_id} failed") from exc out = process.communicate(data.read())[0] buffer = io.BytesIO(out) @@ -213,11 +219,11 @@ def format_user_audio( pass if as_file: - return File(buffer, filename=f'{user_id}-{time.time()}-recording.mkv') + return File(buffer, filename=f"{user_id}-{time.time()}-recording.mkv") return buffer def _clean_process(self, process: subprocess.Popen) -> None: - _log.debug('Cleaning process %s for sink %s', process, self) + _log.debug("Cleaning process %s for sink %s", process, self) process.kill() def cleanup(self) -> None: diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 21de10fc95..2c8ef4faaa 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -24,19 +24,19 @@ from __future__ import annotations -from collections import deque import io import logging import subprocess import time +from collections import deque from typing import TYPE_CHECKING, Literal, overload from discord.file import File from discord.utils import MISSING -from .core import 
CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MP3SinkError, MaxProcessesCountReached, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, MP3SinkError, NoUserAdio if TYPE_CHECKING: from discord import abc @@ -44,13 +44,15 @@ _log = logging.getLogger(__name__) __all__ = ( - 'MP3ConverterHandler', - 'MP3Sink', + "MP3ConverterHandler", + "MP3Sink", ) -class MP3ConverterHandler(SinkHandler['MP3Sink']): - def handle_packet(self, sink: MP3Sink, user: abc.Snowflake, packet: RawData) -> None: +class MP3ConverterHandler(SinkHandler["MP3Sink"]): + def handle_packet( + self, sink: MP3Sink, user: abc.Snowflake, packet: RawData + ) -> None: data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) data.write(packet.decoded_data) @@ -88,7 +90,9 @@ def __init__( max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[subprocess.Popen] = deque(maxlen=max_audio_processes_count) + self.__process_queue: deque[subprocess.Popen] = deque( + maxlen=max_audio_processes_count + ) handlers = handlers or [] handlers.append(MP3ConverterHandler()) @@ -128,7 +132,7 @@ def format_user_audio( self, user_id: int, *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", as_file: bool = False, ) -> io.BytesIO | File: """Formats a user's saved audio data. @@ -158,7 +162,7 @@ def format_user_audio( object with the buffer set as the audio bytes. Raises - ------- + ------ NoUserAudio You tried to format the audio of a user that was not stored in this sink. 
FFmpegNotFound @@ -175,33 +179,38 @@ def format_user_audio( try: data = self.__audio_data.pop(user_id) except KeyError: - _log.info('There is no audio data for %s, ignoring.', user_id) + _log.info("There is no audio data for %s, ignoring.", user_id) raise NoUserAdio args = [ executable, - '-f', - 's16le', - '-ar', - '48000', - '-loglevel', - 'error', - '-ac', - '2', - '-i', - '-', - '-f', - 'mp3', - 'pipe:1', + "-f", + "s16le", + "-ar", + "48000", + "-loglevel", + "error", + "-ac", + "2", + "-i", + "-", + "-f", + "mp3", + "pipe:1", ] try: - process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + process = subprocess.Popen( + args, + creationflags=CREATE_NO_WINDOW, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) self.__process_queue.append(process) except FileNotFoundError as exc: raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise MP3SinkError(f'Audio formatting for user {user_id} failed') from exc + raise MP3SinkError(f"Audio formatting for user {user_id} failed") from exc out = process.communicate(data.read())[0] buffer = io.BytesIO(out) @@ -213,11 +222,11 @@ def format_user_audio( pass if as_file: - return File(buffer, filename=f'{user_id}-{time.time()}-recording.mp3') + return File(buffer, filename=f"{user_id}-{time.time()}-recording.mp3") return buffer def _clean_process(self, process: subprocess.Popen) -> None: - _log.debug('Cleaning process %s for sink %s', process, self) + _log.debug("Cleaning process %s for sink %s", process, self) process.kill() def cleanup(self) -> None: diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index b84bab40c8..6c1d0ff6ab 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -24,21 +24,21 @@ from __future__ import annotations -from collections import deque import io import logging import os import subprocess import time +from collections import deque from typing import TYPE_CHECKING, Literal, overload from discord 
import utils from discord.file import File from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MP4SinkError, MaxProcessesCountReached, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, MP4SinkError, NoUserAdio if TYPE_CHECKING: from discord import abc @@ -46,13 +46,15 @@ _log = logging.getLogger(__name__) __all__ = ( - 'MP4ConverterHandler', - 'MP4Sink', + "MP4ConverterHandler", + "MP4Sink", ) -class MP4ConverterHandler(SinkHandler['MP4Sink']): - def handle_packet(self, sink: MP4Sink, user: abc.Snowflake, packet: RawData) -> None: +class MP4ConverterHandler(SinkHandler["MP4Sink"]): + def handle_packet( + self, sink: MP4Sink, user: abc.Snowflake, packet: RawData + ) -> None: data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) data.write(packet.decoded_data) @@ -90,7 +92,9 @@ def __init__( max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque(maxlen=max_audio_processes_count) + self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque( + maxlen=max_audio_processes_count + ) handlers = handlers or [] handlers.append(MP4ConverterHandler()) @@ -130,7 +134,7 @@ def format_user_audio( self, user_id: int, *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", as_file: bool = False, ) -> io.BytesIO | File: """Formats a user's saved audio data. @@ -160,7 +164,7 @@ def format_user_audio( object with the buffer set as the audio bytes. Raises - ------- + ------ NoUserAudio You tried to format the audio of a user that was not stored in this sink. 
FFmpegNotFound @@ -177,24 +181,24 @@ def format_user_audio( try: data = self.__audio_data.pop(user_id) except KeyError: - _log.info('There is no audio data for %s, ignoring.', user_id) + _log.info("There is no audio data for %s, ignoring.", user_id) raise NoUserAdio - temp_path = f'{user_id}-{time.time()}-recording.mp4.tmp' + temp_path = f"{user_id}-{time.time()}-recording.mp4.tmp" args = [ executable, - '-f', - 's16le', - '-ar', - '48000', - '-loglevel', - 'error', - '-ac', - '2', - '-i', - '-', - '-f', - 'mp4', + "-f", + "s16le", + "-ar", + "48000", + "-loglevel", + "error", + "-ac", + "2", + "-i", + "-", + "-f", + "mp4", temp_path, ] @@ -203,21 +207,27 @@ def format_user_audio( if found: _, old_process = found old_process.kill() - _log.info('Killing old process (%s) to write in %s', old_process, temp_path) + _log.info( + "Killing old process (%s) to write in %s", old_process, temp_path + ) - os.remove(temp_path) # process would get stuck asking whether to overwrite, if file already exists. + os.remove( + temp_path + ) # process would get stuck asking whether to overwrite, if file already exists. 
try: - process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE) + process = subprocess.Popen( + args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE + ) self.__process_queue.append((temp_path, process)) except FileNotFoundError as exc: raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise MP4SinkError(f'Audio formatting for user {user_id} failed') from exc + raise MP4SinkError(f"Audio formatting for user {user_id} failed") from exc process.communicate(data.read()) - with open(temp_path, 'rb') as file: + with open(temp_path, "rb") as file: buffer = io.BytesIO(file.read()) buffer.seek(0) @@ -227,11 +237,16 @@ def format_user_audio( pass if as_file: - return File(buffer, filename=f'{user_id}-{time.time()}-recording.mp4') + return File(buffer, filename=f"{user_id}-{time.time()}-recording.mp4") return buffer def _clean_process(self, path: str, process: subprocess.Popen) -> None: - _log.debug('Cleaning process %s for sink %s (with temporary file at %s)', process, self, path) + _log.debug( + "Cleaning process %s for sink %s (with temporary file at %s)", + process, + self, + path, + ) process.kill() if os.path.exists(path): os.remove(path) diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 9755c34571..3fd2fc4727 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -24,19 +24,19 @@ from __future__ import annotations -from collections import deque import io import logging import subprocess import time +from collections import deque from typing import TYPE_CHECKING, Literal, overload from discord.file import File from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, SinkHandler, Sink, SinkFilter, RawData +from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, OGGSinkError, MaxProcessesCountReached, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, 
NoUserAdio, OGGSinkError if TYPE_CHECKING: from discord import abc @@ -44,13 +44,15 @@ _log = logging.getLogger(__name__) __all__ = ( - 'OGGConverterHandler', - 'OGGSink', + "OGGConverterHandler", + "OGGSink", ) -class OGGConverterHandler(SinkHandler['OGGSink']): - def handle_packet(self, sink: OGGSink, user: abc.Snowflake, packet: RawData) -> None: +class OGGConverterHandler(SinkHandler["OGGSink"]): + def handle_packet( + self, sink: OGGSink, user: abc.Snowflake, packet: RawData + ) -> None: data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) data.write(packet.decoded_data) @@ -88,7 +90,9 @@ def __init__( max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[subprocess.Popen] = deque(maxlen=max_audio_processes_count) + self.__process_queue: deque[subprocess.Popen] = deque( + maxlen=max_audio_processes_count + ) handlers = handlers or [] handlers.append(OGGConverterHandler()) @@ -128,7 +132,7 @@ def format_user_audio( self, user_id: int, *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", as_file: bool = False, ) -> io.BytesIO | File: """Formats a user's saved audio data. @@ -158,7 +162,7 @@ def format_user_audio( object with the buffer set as the audio bytes. Raises - ------- + ------ NoUserAudio You tried to format the audio of a user that was not stored in this sink. 
FFmpegNotFound @@ -175,33 +179,38 @@ def format_user_audio( try: data = self.__audio_data.pop(user_id) except KeyError: - _log.info('There is no audio data for %s, ignoring.', user_id) + _log.info("There is no audio data for %s, ignoring.", user_id) raise NoUserAdio args = [ executable, - '-f', - 's16le', - '-ar', - '48000', - '-loglevel', - 'error', - '-ac', - '2', - '-i', - '-', - '-f', - 'ogg', - 'pipe:1', + "-f", + "s16le", + "-ar", + "48000", + "-loglevel", + "error", + "-ac", + "2", + "-i", + "-", + "-f", + "ogg", + "pipe:1", ] try: - process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + process = subprocess.Popen( + args, + creationflags=CREATE_NO_WINDOW, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) self.__process_queue.append(process) except FileNotFoundError as exc: raise FFmpegNotFound from exc except subprocess.SubprocessError as exc: - raise OGGSinkError(f'Audio formatting for user {user_id} failed') from exc + raise OGGSinkError(f"Audio formatting for user {user_id} failed") from exc out = process.communicate(data.read())[0] buffer = io.BytesIO(out) @@ -213,11 +222,11 @@ def format_user_audio( pass if as_file: - return File(buffer, filename=f'{user_id}-{time.time()}-recording.ogg') + return File(buffer, filename=f"{user_id}-{time.time()}-recording.ogg") return buffer def _clean_process(self, process: subprocess.Popen) -> None: - _log.debug('Cleaning process %s for sink %s', process, self) + _log.debug("Cleaning process %s for sink %s", process, self) process.kill() def cleanup(self) -> None: diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index 88fcd6c29c..0af4012c04 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -30,21 +30,23 @@ from discord.file import File from discord.utils import MISSING -from .core import RawData, Sink, SinkHandler, SinkFilter -from .errors import NoUserAdio +from .core import RawData, Sink, SinkFilter, SinkHandler from .enums import 
SinkFilteringMode +from .errors import NoUserAdio if TYPE_CHECKING: from discord import abc __all__ = ( - 'PCMConverterHandler', - 'PCMSink', + "PCMConverterHandler", + "PCMSink", ) -class PCMConverterHandler(SinkHandler['PCMSink']): - def handle_packet(self, sink: PCMSink, user: abc.Snowflake, packet: RawData) -> None: +class PCMConverterHandler(SinkHandler["PCMSink"]): + def handle_packet( + self, sink: PCMSink, user: abc.Snowflake, packet: RawData + ) -> None: data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) data.write(packet.decoded_data) @@ -141,7 +143,7 @@ def format_user_audio( object with the buffer set as the audio bytes. Raises - ------- + ------ NoUserAudio You tried to format the audio of a user that was not stored in this sink. """ @@ -154,7 +156,7 @@ def format_user_audio( data.seek(0) if as_file: - return File(data, filename=f'{user_id}-recording.pcm') + return File(data, filename=f"{user_id}-recording.pcm") return data def cleanup(self) -> None: diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index 0cb14d158c..fce203deba 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -25,13 +25,13 @@ from __future__ import annotations import io -from typing import TYPE_CHECKING, Literal, overload import wave +from typing import TYPE_CHECKING, Literal, overload from discord.file import File from discord.utils import MISSING -from .core import SinkFilter, SinkHandler, RawData, Sink +from .core import RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode from .errors import NoUserAdio @@ -39,20 +39,22 @@ from discord import abc __all__ = ( - 'WaveConverterHandler', - 'WavConverterHandler', - 'WaveSink', - 'WavSink', + "WaveConverterHandler", + "WavConverterHandler", + "WaveSink", + "WavSink", ) -class WaveConverterHandler(SinkHandler['WaveSink']): - def handle_packet(self, sink: WaveSink, user: abc.Snowflake, packet: RawData) -> None: +class WaveConverterHandler(SinkHandler["WaveSink"]): + 
def handle_packet( + self, sink: WaveSink, user: abc.Snowflake, packet: RawData + ) -> None: data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) data.write(packet.decoded_data) -WavConverterHandler: SinkHandler['WavSink'] = WaveConverterHandler # type: ignore +WavConverterHandler: SinkHandler[WavSink] = WaveConverterHandler # type: ignore class WaveSink(Sink): @@ -146,7 +148,7 @@ def format_user_audio( object with the buffer set as the audio bytes. Raises - ------- + ------ NoUserAudio You tried to format the audio of a user that was not stored in this sink. """ @@ -158,7 +160,7 @@ def format_user_audio( decoder = self.client.decoder - with wave.open(data, 'wb') as f: + with wave.open(data, "wb") as f: f.setnchannels(decoder.CHANNELS) f.setsampwidth(decoder.SAMPLE_SIZE // decoder.CHANNELS) f.setframerate(decoder.SAMPLING_RATE) @@ -166,7 +168,7 @@ def format_user_audio( data.seek(0) if as_file: - return File(data, filename=f'{user_id}-recording.pcm') + return File(data, filename=f"{user_id}-recording.pcm") return data def cleanup(self) -> None: diff --git a/discord/voice/client.py b/discord/voice/client.py index 27d9b53cb8..4073b935ae 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -644,9 +644,9 @@ def start_recording( """ if not self.is_connected(): - raise RecordingException('not connected to a voice channel') + raise RecordingException("not connected to a voice channel") if not isinstance(sink, Sink): - raise TypeError(f'expected a Sink object, got {sink.__class__.__name__}') + raise TypeError(f"expected a Sink object, got {sink.__class__.__name__}") if sync_start is not MISSING: self._connection.sync_recording_start = sync_start @@ -665,22 +665,22 @@ def stop_recording( .. versionadded:: 2.0 - Paremeters - ---------- - sink: :class:`discord.Sink` - The sink to stop recording. - Raises ------ RecordingException The provided sink is not currently recording, or if ``None``, you are not recording. 
+ + Paremeters + ---------- + sink: :class:`discord.Sink` + The sink to stop recording. """ if sink is not None: try: self._connection.sinks.remove(sink) except ValueError: - raise RecordingException('the provided sink is not currently recording') + raise RecordingException("the provided sink is not currently recording") sink.stop() return diff --git a/discord/voice/state.py b/discord/voice/state.py index f6e0ebb174..ec16428c31 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -26,19 +26,19 @@ from __future__ import annotations import asyncio -from collections import deque import logging import select import socket import struct import threading -from collections.abc import Callable, Coroutine import time +from collections import deque +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any, TypedDict -from discord import utils, opus +from discord import opus, utils from discord.backoff import ExponentialBackoff -from discord.enums import try_enum, SpeakingState +from discord.enums import SpeakingState, try_enum from discord.errors import ConnectionClosed from discord.object import Object from discord.sinks import RawData, Sink @@ -51,9 +51,9 @@ from discord.guild import Guild from discord.member import VoiceState from discord.raw_models import RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent + from discord.state import ConnectionState from discord.types.voice import SupportedModes from discord.user import ClientUser - from discord.state import ConnectionState from .client import VoiceClient @@ -109,7 +109,9 @@ def pause(self) -> None: self._running.clear() def is_paused(self) -> bool: - return self._idle_paused or (not self._running.is_set() and not self._end.is_set()) + return self._idle_paused or ( + not self._running.is_set() and not self._end.is_set() + ) def resume(self, *, force: bool = False) -> None: if self._running.is_set(): @@ -139,7 +141,7 @@ def run(self) -> None: self._do_run() except Exception: 
_log.exception( - 'An error ocurred while running the socket reader %s', + "An error ocurred while running the socket reader %s", self.name, ) finally: @@ -170,7 +172,7 @@ def _do_run(self) -> None: data = self.state.socket.recv(self.buffer_size) except OSError: _log.debug( - 'Error reading from socket in %s, this should be safe to ignore', + "Error reading from socket in %s, this should be safe to ignore", self, exc_info=True, ) @@ -180,7 +182,7 @@ def _do_run(self) -> None: cb(data) except Exception: _log.exception( - 'Error while calling %s in %s', + "Error while calling %s in %s", cb, self, ) @@ -188,11 +190,14 @@ def _do_run(self) -> None: class SocketVoiceRecvReader(SocketReader): def __init__( - self, state: VoiceConnectionState, *, start_paused: bool = True, + self, + state: VoiceConnectionState, + *, + start_paused: bool = True, ) -> None: super().__init__( state, - f'voice-recv-socket-reader:{id(self):#x}', + f"voice-recv-socket-reader:{id(self):#x}", 4096, start_paused=start_paused, ) @@ -204,7 +209,7 @@ def __init__( ) -> None: super().__init__( state, - f'voice-socket-event-reader:{id(self):#x}', + f"voice-socket-event-reader:{id(self):#x}", 2048, start_paused=start_paused, ) @@ -212,11 +217,14 @@ def __init__( class DecoderThread(threading.Thread, opus._OpusStruct): def __init__( - self, state: VoiceConnectionState, *, start_paused: bool = True, + self, + state: VoiceConnectionState, + *, + start_paused: bool = True, ) -> None: super().__init__( daemon=True, - name=f'voice-recv-decoder-thread:{id(self):#x}', + name=f"voice-recv-decoder-thread:{id(self):#x}", ) self.state: VoiceConnectionState = state @@ -234,7 +242,9 @@ def __init__( def decode(self, frame: RawData) -> None: if not isinstance(frame, RawData): - raise TypeError(f'expected a RawData object, got {frame.__class__.__name__}') + raise TypeError( + f"expected a RawData object, got {frame.__class__.__name__}" + ) self.decode_queue.append(frame) def is_running(self) -> bool: @@ -272,7 +282,7 
@@ def run(self) -> None: self._do_run() except Exception: _log.exception( - 'An error ocurred while running the decoder thread %s', + "An error ocurred while running the decoder thread %s", self.name, ) finally: @@ -307,7 +317,7 @@ def _do_run(self) -> None: ) except opus.OpusError: _log.exception( - 'Error ocurred while decoding opus frame', + "Error ocurred while decoding opus frame", exc_info=True, ) @@ -368,7 +378,9 @@ def __init__( self.sync_recording_start: bool = False self.first_received_packet_ts: float = MISSING self.sinks: list[Sink] = [] - self.recording_done_callbacks: list[tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]]] = [] + self.recording_done_callbacks: list[ + tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]] + ] = [] self.__sink_dispatch_task_set: set[asyncio.Task[Any]] = set() def start_record_socket(self) -> None: @@ -420,9 +432,7 @@ def dispatch_packet_sinks(self, data: RawData) -> None: self.first_received_packet_ts = data.receive_time silence = 0 else: - silence = ( - (data.receive_time - self.first_received_packet_ts) * 48000 - ) + silence = (data.receive_time - self.first_received_packet_ts) * 48000 else: stored_timestamp, stored_recv_time = self.user_voice_timestamps[data.ssrc] dRT = data.receive_time - stored_recv_time * 48000 @@ -437,7 +447,7 @@ def dispatch_packet_sinks(self, data: RawData) -> None: self.user_voice_timestamps[data.ssrc] = (data.timestamp, data.receive_time) data.decoded_data = ( - struct.pack(' None: user = self.get_user_by_ssrc(data.ssrc) if not user: _log.debug( - 'Ignoring received packet %s because the SSRC was waited for but was not found', + "Ignoring received packet %s because the SSRC was waited for but was not found", data, ) return @@ -465,10 +475,12 @@ async def _dispatch_packet(self, data: RawData) -> None: if sink.is_paused(): continue - sink.dispatch('unfiltered_voice_packet_receive', user, data) + sink.dispatch("unfiltered_voice_packet_receive", user, data) futures = [ - 
self.loop.create_task(utils.maybe_coroutine(fil.filter_packet, sink, user, data)) + self.loop.create_task( + utils.maybe_coroutine(fil.filter_packet, sink, user, data) + ) for fil in sink._filters ] strat = sink._filter_strat @@ -482,7 +494,7 @@ async def _dispatch_packet(self, data: RawData) -> None: done = (*done, *pending) if strat([f.result() for f in done]): - sink.dispatch('voice_packet_receive', user, data) + sink.dispatch("voice_packet_receive", user, data) sink._call_voice_packet_handlers(user, data) def is_recording(self) -> bool: @@ -506,50 +518,50 @@ def get_user_by_ssrc(self, ssrc: int) -> abc.Snowflake | None: if data is None: return None - user = int(data['user_id']) + user = int(data["user_id"]) return self.get_user(user) def get_user(self, id: int) -> abc.Snowflake: state = self._connection - return ( - self.guild.get_member(id) or - state.get_user(id) or - Object(id=id) - ) + return self.guild.get_member(id) or state.get_user(id) or Object(id=id) def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: - op = msg['op'] - data = msg.get('d', {}) + op = msg["op"] + data = msg.get("d", {}) if op == OpCodes.speaking: - ssrc = data['ssrc'] - user = int(data['user_id']) - raw_speaking = data['speaking'] + ssrc = data["ssrc"] + user = int(data["user_id"]) + raw_speaking = data["speaking"] speaking = try_enum(SpeakingState, raw_speaking) old_data = self.user_ssrc_map.get(ssrc) - old_speaking = (old_data or {}).get('speaking', SpeakingState.none) + old_speaking = (old_data or {}).get("speaking", SpeakingState.none) self._dispatch_speaking_state(old_speaking, speaking, user) if old_data is None: - self.user_ssrc_map[ssrc]['speaking'] = speaking + self.user_ssrc_map[ssrc]["speaking"] = speaking else: self.user_ssrc_map[ssrc] = { - 'user_id': user, - 'speaking': speaking, + "user_id": user, + "speaking": speaking, } - def _dispatch_speaking_state(self, before: SpeakingState, after: SpeakingState, uid: int) -> None: + def _dispatch_speaking_state( + 
self, before: SpeakingState, after: SpeakingState, uid: int + ) -> None: resolved = self.get_user(uid) for sink in self.sinks: if sink.is_paused(): continue - sink.dispatch('unfiltered_speaking_state_update', resolved, before, after) + sink.dispatch("unfiltered_speaking_state_update", resolved, before, after) futures = [ - self.loop.create_task(utils.maybe_coroutine(fil.filter_packet, sink, user, data)) + self.loop.create_task( + utils.maybe_coroutine(fil.filter_packet, sink, user, data) + ) for fil in sink._filters ] strat = sink._filter_strat @@ -563,7 +575,7 @@ def _dispatch_speaking_state(self, before: SpeakingState, after: SpeakingState, done = (*done, *pending) if strat([f.result() for f in done]): - sink.dispatch('speaking_state_update', resolved, before, after) + sink.dispatch("speaking_state_update", resolved, before, after) sink._call_speaking_state_handlers(resolved, before, after) @property From 4277b49d7643c0d5051087365c57ec3b1a4390a0 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Sat, 30 Aug 2025 02:13:42 +0200 Subject: [PATCH 24/40] document default handlers --- discord/sinks/m4a.py | 6 ++++++ discord/sinks/mka.py | 6 ++++++ discord/sinks/mkv.py | 6 ++++++ discord/sinks/mp3.py | 6 ++++++ discord/sinks/mp4.py | 6 ++++++ discord/sinks/ogg.py | 6 ++++++ discord/sinks/pcm.py | 6 ++++++ discord/sinks/wave.py | 10 ++++++++++ docs/api/sinks.rst | 40 ++++++++++++++++++++++++++++++++++------ 9 files changed, 86 insertions(+), 6 deletions(-) diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index dd5ff7fb8d..8dbcf6a3f4 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -52,6 +52,12 @@ class M4AConverterHandler(SinkHandler["M4ASink"]): + """Default handler to add received voice packets to the audio cache data in + a :class:`~.M4ASink`. + + .. 
versionadded:: 2.7 + """ + def handle_packet( self, sink: M4ASink, user: abc.Snowflake, packet: RawData ) -> None: diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index eac6fbab00..06df9b5a6f 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -50,6 +50,12 @@ class MKAConverterHandler(SinkHandler["MKASink"]): + """Default handler to add received voice packets to the audio cache data in + a :class:`~.MKASink`. + + .. versionadded:: 2.7 + """ + def handle_packet( self, sink: MKASink, user: abc.Snowflake, packet: RawData ) -> None: diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index e1ccfc8299..538182ad79 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -50,6 +50,12 @@ class MKVConverterHandler(SinkHandler["MKVSink"]): + """Default handler to add received voice packets to the audio cache data in + a :class:`~.MKVSink`. + + .. versionadded:: 2.7 + """ + def handle_packet( self, sink: MKVSink, user: abc.Snowflake, packet: RawData ) -> None: diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 2c8ef4faaa..c19eddf9c1 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -50,6 +50,12 @@ class MP3ConverterHandler(SinkHandler["MP3Sink"]): + """Default handler to add received voice packets to the audio cache data in + a :class:`~.MP3Sink`. + + .. versionadded:: 2.7 + """ + def handle_packet( self, sink: MP3Sink, user: abc.Snowflake, packet: RawData ) -> None: diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index 6c1d0ff6ab..00c75c5d7f 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -52,6 +52,12 @@ class MP4ConverterHandler(SinkHandler["MP4Sink"]): + """Default handler to add received voice packets to the audio cache data in + a :class:`~.MP4Sink`. + + .. 
versionadded:: 2.7 + """ + def handle_packet( self, sink: MP4Sink, user: abc.Snowflake, packet: RawData ) -> None: diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 3fd2fc4727..102b62db09 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -50,6 +50,12 @@ class OGGConverterHandler(SinkHandler["OGGSink"]): + """Default handler to add received voice packets to the audio cache data in + a :class:`~.OGGSink`. + + .. versionadded:: 2.7 + """ + def handle_packet( self, sink: OGGSink, user: abc.Snowflake, packet: RawData ) -> None: diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index 0af4012c04..740cfbd2f9 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -44,6 +44,12 @@ class PCMConverterHandler(SinkHandler["PCMSink"]): + """Default handler to add received voice packets to the audio cache data in + a :class:`~.PCMSink`. + + .. versionadded:: 2.7 + """ + def handle_packet( self, sink: PCMSink, user: abc.Snowflake, packet: RawData ) -> None: diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index fce203deba..b8a71e4d54 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -47,6 +47,12 @@ class WaveConverterHandler(SinkHandler["WaveSink"]): + """Default handler to add received voice packets to the audio cache data in + a :class:`~.WaveSink`. + + .. versionadded:: 2.7 + """ + def handle_packet( self, sink: WaveSink, user: abc.Snowflake, packet: RawData ) -> None: @@ -55,6 +61,10 @@ def handle_packet( WavConverterHandler: SinkHandler[WavSink] = WaveConverterHandler # type: ignore +"""An alias for :class:`~.WaveConverterHandler` + +.. versionadded:: 2.7 +""" class WaveSink(Sink): diff --git a/docs/api/sinks.rst b/docs/api/sinks.rst index 0ee8ca3a73..af02a59c08 100644 --- a/docs/api/sinks.rst +++ b/docs/api/sinks.rst @@ -6,25 +6,28 @@ Sinks Core ---- -.. autoclass:: discord.sinks.Filters +.. autoclass:: discord.sinks.Sink :members: -.. autoclass:: discord.sinks.Sink +.. 
autoclass:: discord.sinks.RawData :members: -.. autoclass:: discord.sinks.AudioData +.. autoclass:: discord.sinks.SinkHandler :members: -.. autoclass:: discord.sinks.RawData +.. autoclass:: discord.sinks.SinkFilter :members: -Sink Classes ------------- +Default Sinks +------------- .. autoclass:: discord.sinks.WaveSink :members: +.. autoclass:: discord.sinks.WavSink + :members: + .. autoclass:: discord.sinks.MP3Sink :members: @@ -42,3 +45,28 @@ Sink Classes .. autoclass:: discord.sinks.OGGSink :members: + + +Default Handlers +---------------- + +.. autoclass:: discord.sinks.WaveConverterHandler() + :members: + +.. autoclass:: discord.sinks.MP3ConverterHandler() + :members: + +.. autoclass:: discord.sinks.MP4ConverterHandler() + :members: + +.. autoclass:: discord.sinks.M4AConverterHandler() + :members: + +.. autoclass:: discord.sinks.MKVConverterHandler() + :members: + +.. autoclass:: discord.sinks.MKAConverterHandler() + :members: + +.. autoclass:: discord.sinks.OGGConverterHandler() + :members: From 46978ecb81d954fdd20c2d36ef3dde5e721fd44c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Aug 2025 00:14:21 +0000 Subject: [PATCH 25/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/sinks/m4a.py | 2 +- discord/sinks/mka.py | 2 +- discord/sinks/mkv.py | 2 +- discord/sinks/mp3.py | 2 +- discord/sinks/mp4.py | 2 +- discord/sinks/ogg.py | 2 +- discord/sinks/pcm.py | 2 +- discord/sinks/wave.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 8dbcf6a3f4..a174f87f49 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -54,7 +54,7 @@ class M4AConverterHandler(SinkHandler["M4ASink"]): """Default handler to add received voice packets to the audio cache data in a :class:`~.M4ASink`. - + .. 
versionadded:: 2.7 """ diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index 06df9b5a6f..4cda2a9020 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -52,7 +52,7 @@ class MKAConverterHandler(SinkHandler["MKASink"]): """Default handler to add received voice packets to the audio cache data in a :class:`~.MKASink`. - + .. versionadded:: 2.7 """ diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index 538182ad79..0c50d0af9d 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -52,7 +52,7 @@ class MKVConverterHandler(SinkHandler["MKVSink"]): """Default handler to add received voice packets to the audio cache data in a :class:`~.MKVSink`. - + .. versionadded:: 2.7 """ diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index c19eddf9c1..14442afa29 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -52,7 +52,7 @@ class MP3ConverterHandler(SinkHandler["MP3Sink"]): """Default handler to add received voice packets to the audio cache data in a :class:`~.MP3Sink`. - + .. versionadded:: 2.7 """ diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index 00c75c5d7f..d090584f81 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -54,7 +54,7 @@ class MP4ConverterHandler(SinkHandler["MP4Sink"]): """Default handler to add received voice packets to the audio cache data in a :class:`~.MP4Sink`. - + .. versionadded:: 2.7 """ diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 102b62db09..cce7ad348a 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -52,7 +52,7 @@ class OGGConverterHandler(SinkHandler["OGGSink"]): """Default handler to add received voice packets to the audio cache data in a :class:`~.OGGSink`. - + .. 
versionadded:: 2.7 """ diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index 740cfbd2f9..a1a52beecc 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -46,7 +46,7 @@ class PCMConverterHandler(SinkHandler["PCMSink"]): """Default handler to add received voice packets to the audio cache data in a :class:`~.PCMSink`. - + .. versionadded:: 2.7 """ diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index b8a71e4d54..d23ed4c0b5 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -49,7 +49,7 @@ class WaveConverterHandler(SinkHandler["WaveSink"]): """Default handler to add received voice packets to the audio cache data in a :class:`~.WaveSink`. - + .. versionadded:: 2.7 """ From 6c124bf78438b4a7b3fc633cae1a35ff5843655e Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Sat, 30 Aug 2025 02:18:27 +0200 Subject: [PATCH 26/40] fix await outside async --- discord/voice/state.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/discord/voice/state.py b/discord/voice/state.py index ec16428c31..4adefccecf 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -537,7 +537,7 @@ def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: old_data = self.user_ssrc_map.get(ssrc) old_speaking = (old_data or {}).get("speaking", SpeakingState.none) - self._dispatch_speaking_state(old_speaking, speaking, user) + self.dispatch_speaking_state(old_speaking, speaking, user) if old_data is None: self.user_ssrc_map[ssrc]["speaking"] = speaking @@ -547,7 +547,14 @@ def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: "speaking": speaking, } - def _dispatch_speaking_state( + def dispatch_speaking_state(self, before: SpeakingState, after: SpeakingState, user_id: int) -> None: + task = self.loop.create_task( + self._dispatch_speaking_state(before, after, user_id), + ) + self.__sink_dispatch_task_set.add(task) + 
task.add_done_callback(self.__sink_dispatch_task_set.remove) + + async def _dispatch_speaking_state( self, before: SpeakingState, after: SpeakingState, uid: int ) -> None: resolved = self.get_user(uid) @@ -560,7 +567,7 @@ def _dispatch_speaking_state( futures = [ self.loop.create_task( - utils.maybe_coroutine(fil.filter_packet, sink, user, data) + utils.maybe_coroutine(fil.filter_speaking_state, sink, resolved, before, after) ) for fil in sink._filters ] From b34c25482306b2e9fda7cdf8b8ae3fb50617c7be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Aug 2025 00:20:19 +0000 Subject: [PATCH 27/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/voice/state.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/discord/voice/state.py b/discord/voice/state.py index 4adefccecf..f77e54866b 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -547,7 +547,9 @@ def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: "speaking": speaking, } - def dispatch_speaking_state(self, before: SpeakingState, after: SpeakingState, user_id: int) -> None: + def dispatch_speaking_state( + self, before: SpeakingState, after: SpeakingState, user_id: int + ) -> None: task = self.loop.create_task( self._dispatch_speaking_state(before, after, user_id), ) @@ -567,7 +569,9 @@ async def _dispatch_speaking_state( futures = [ self.loop.create_task( - utils.maybe_coroutine(fil.filter_speaking_state, sink, resolved, before, after) + utils.maybe_coroutine( + fil.filter_speaking_state, sink, resolved, before, after + ) ) for fil in sink._filters ] From 18ded59469ebd41d6fc24dc9d9cf078351e5fadb Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Sat, 30 Aug 2025 02:31:29 +0200 Subject: [PATCH 28/40] name error --- discord/voice/client.py | 2 +- discord/voice/state.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff 
--git a/discord/voice/client.py b/discord/voice/client.py index 4073b935ae..8e42d11519 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -35,6 +35,7 @@ from discord import opus from discord.errors import ClientException from discord.player import AudioPlayer, AudioSource +from discord.sinks.core import Sink from discord.sinks.errors import RecordingException from discord.utils import MISSING @@ -54,7 +55,6 @@ RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent, ) - from discord.sinks import Sink from discord.state import ConnectionState from discord.types.voice import SupportedModes from discord.user import ClientUser diff --git a/discord/voice/state.py b/discord/voice/state.py index f77e54866b..33dcb76641 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -384,6 +384,9 @@ def __init__( self.__sink_dispatch_task_set: set[asyncio.Task[Any]] = set() def start_record_socket(self) -> None: + if self._voice_recv_socket.is_paused(): + self._voice_recv_socket.resume() + return if self._voice_recv_socket.is_running(): return self._voice_recv_socket.start() @@ -887,6 +890,7 @@ async def disconnect( if cleanup: self._socket_reader.stop() + self.stop_record_socket() self.client.stop() self._connected.set() @@ -932,6 +936,7 @@ async def soft_disconnect( finally: self.state = with_state self._socket_reader.pause() + self._voice_recv_socket.pause() if self.socket: self.socket.close() From 6f03ace135019d9fc456b7828114cf0027376e0e Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Sat, 30 Aug 2025 02:38:57 +0200 Subject: [PATCH 29/40] fix dispatchs not working ig --- discord/voice/state.py | 66 ++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/discord/voice/state.py b/discord/voice/state.py index 33dcb76641..41a56c9441 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -241,6 +241,7 @@ def __init__( self._end: threading.Event = 
threading.Event() def decode(self, frame: RawData) -> None: + _log.debug('Decoding frame %s', frame) if not isinstance(frame, RawData): raise TypeError( f"expected a RawData object, got {frame.__class__.__name__}" @@ -407,6 +408,7 @@ def stop_record_socket(self) -> None: self.sinks.clear() def handle_voice_recv_packet(self, packet: bytes) -> None: + _log.debug('Handling voice packet %s', packet) if packet[1] != 0x78: # We should ignore any payload types we do not understand # Ref: RFC 3550 5.1 payload type @@ -429,7 +431,7 @@ def is_first_packet(self) -> bool: return not self.user_voice_timestamps or not self.sync_recording_start def dispatch_packet_sinks(self, data: RawData) -> None: - + _log.debug('Dispatching packet %s in all sinks', data) if data.ssrc not in self.user_ssrc_map: if self.is_first_packet(): self.first_received_packet_ts = data.receive_time @@ -480,23 +482,28 @@ async def _dispatch_packet(self, data: RawData) -> None: sink.dispatch("unfiltered_voice_packet_receive", user, data) - futures = [ - self.loop.create_task( - utils.maybe_coroutine(fil.filter_packet, sink, user, data) - ) - for fil in sink._filters - ] - strat = sink._filter_strat + if sink._filters: + futures = [ + self.loop.create_task( + utils.maybe_coroutine(fil.filter_packet, sink, user, data) + ) + for fil in sink._filters + ] + strat = sink._filter_strat - done, pending = await asyncio.wait(futures) + done, pending = await asyncio.wait(futures) - if pending: - for task in pending: - task.set_result(False) + if pending: + for task in pending: + task.set_result(False) + + done = (*done, *pending) - done = (*done, *pending) + result = strat([f.result() for f in done]) + else: + result = True - if strat([f.result() for f in done]): + if result: sink.dispatch("voice_packet_receive", user, data) sink._call_voice_packet_handlers(user, data) @@ -570,25 +577,28 @@ async def _dispatch_speaking_state( sink.dispatch("unfiltered_speaking_state_update", resolved, before, after) - futures = [ - 
self.loop.create_task( - utils.maybe_coroutine( - fil.filter_speaking_state, sink, resolved, before, after + if sink._filters: + futures = [ + self.loop.create_task( + utils.maybe_coroutine(fil.filter_packet, sink, resolved, before, after) ) - ) - for fil in sink._filters - ] - strat = sink._filter_strat + for fil in sink._filters + ] + strat = sink._filter_strat - done, pending = await asyncio.wait(futures) + done, pending = await asyncio.wait(futures) - if pending: - for task in pending: - task.set_result(False) + if pending: + for task in pending: + task.set_result(False) - done = (*done, *pending) + done = (*done, *pending) + + result = strat([f.result() for f in done]) + else: + result = True - if strat([f.result() for f in done]): + if result: sink.dispatch("speaking_state_update", resolved, before, after) sink._call_speaking_state_handlers(resolved, before, after) From 94d30826410beb68f73a3807f2b540662c410157 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Sat, 30 Aug 2025 03:08:41 +0200 Subject: [PATCH 30/40] fix voice recv things fuck --- discord/voice/state.py | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/discord/voice/state.py b/discord/voice/state.py index 41a56c9441..7d73d19656 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -27,6 +27,7 @@ import asyncio import logging +from os import altsep import select import socket import struct @@ -60,6 +61,7 @@ MISSING = utils.MISSING SocketReaderCallback = Callable[[bytes], Any] _log = logging.getLogger(__name__) +_recv_log = logging.getLogger('discord.voice.receiver') class SocketReader(threading.Thread): @@ -84,6 +86,7 @@ def __init__( self._end: threading.Event = threading.Event() self._idle_paused: bool = True self._started: threading.Event = threading.Event() + self._warned_wait: bool = False def is_running(self) -> bool: return self._started.is_set() @@ -152,9 +155,16 @@ def run(self) 
-> None: def _do_run(self) -> None: while not self._end.is_set(): if not self._running.is_set(): + if not self._warned_wait: + _log.warning('Socket reader %s is waiting to be set as running', self.name) + self._warned_wait = True self._running.wait() continue + if self._warned_wait: + _log.info('Socket reader %s was set as running', self.name) + self._warned_wait = False + try: readable, _, _ = select.select([self.state.socket], [], [], 30) except (ValueError, TypeError, OSError) as e: @@ -234,6 +244,7 @@ def __init__( self._started: threading.Event = threading.Event() self._running: threading.Event = threading.Event() self._end: threading.Event = threading.Event() + self._warned_queue: bool = False self.decode_queue: deque[RawData] = deque() self.decoders: dict[int, opus.Decoder] = {} @@ -241,12 +252,12 @@ def __init__( self._end: threading.Event = threading.Event() def decode(self, frame: RawData) -> None: - _log.debug('Decoding frame %s', frame) if not isinstance(frame, RawData): raise TypeError( f"expected a RawData object, got {frame.__class__.__name__}" ) self.decode_queue.append(frame) + _log.debug('Added frame %s to decode queue', frame) def is_running(self) -> bool: return self._started.is_set() @@ -300,17 +311,28 @@ def get_decoder(self, ssrc: int) -> opus.Decoder: def _do_run(self) -> None: while not self._end.is_set(): - if not self._running.is_set(): - self._running.wait() + if not self.decode_queue: + if not self._warned_queue: + _recv_log.warning('No decode queue found, waiting') + self._warned_queue = True + + time.sleep(0.01) continue + if self._warned_queue: + _recv_log.info('Queue was filled') + self._warned_queue = False + try: data = self.decode_queue.popleft() except IndexError: continue + _recv_log.debug('Popped %s from the decode queue', data) + try: if data.decrypted_data is None: + _log.warning('Frame %s has no decrypted data, skipping', data) continue else: data.decoded_data = self.get_decoder(data.ssrc).decode( @@ -390,12 +412,19 @@ def 
start_record_socket(self) -> None: return if self._voice_recv_socket.is_running(): return + + if not self._decoder_thread.is_running(): + self._decoder_thread.start() + self._voice_recv_socket.start() def stop_record_socket(self) -> None: if self._voice_recv_socket.is_running(): self._voice_recv_socket.stop() + if self._decoder_thread.is_running(): + self._decoder_thread.stop() + for cb, args in self.recording_done_callbacks: task = self.loop.create_task(cb(*args)) self.__sink_dispatch_task_set.add(task) @@ -408,7 +437,7 @@ def stop_record_socket(self) -> None: self.sinks.clear() def handle_voice_recv_packet(self, packet: bytes) -> None: - _log.debug('Handling voice packet %s', packet) + _recv_log.debug('Handling voice packet %s', packet) if packet[1] != 0x78: # We should ignore any payload types we do not understand # Ref: RFC 3550 5.1 payload type @@ -418,14 +447,17 @@ def handle_voice_recv_packet(self, packet: bytes) -> None: return if self.paused_recording(): + _log.debug('Ignoring packet %s because recording is stopped', packet) return data = RawData(packet, self.client) - if data.decrypted_data != opus.OPUS_SILENCE: + if data.decrypted_data == opus.OPUS_SILENCE: + _log.debug('Ignoring packet %s because it is an opus silence frame', data) return self._decoder_thread.decode(data) + _recv_log.debug('Submitted frame %s to decoder thread', data) def is_first_packet(self) -> bool: return not self.user_voice_timestamps or not self.sync_recording_start From fe8e6028cb7959229e7bd087421c1afbeb187b00 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Sat, 30 Aug 2025 19:38:53 +0200 Subject: [PATCH 31/40] try to fix things --- discord/opus.py | 67 +++-------------------------------------- discord/sinks/errors.py | 2 +- discord/sinks/m4a.py | 4 +-- discord/sinks/mka.py | 4 +-- discord/sinks/mkv.py | 4 +-- discord/sinks/mp3.py | 4 +-- discord/sinks/mp4.py | 4 +-- discord/sinks/ogg.py | 4 +-- discord/sinks/pcm.py | 4 +-- 
discord/sinks/wave.py | 4 +-- discord/voice/state.py | 61 +++++++++++++++++++++++-------------- 11 files changed, 61 insertions(+), 101 deletions(-) diff --git a/discord/opus.py b/discord/opus.py index 14bfdace2c..caf533586a 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -42,6 +42,8 @@ from .sinks import RawData if TYPE_CHECKING: + from discord.voice.client import VoiceClient + T = TypeVar("T") APPLICATION_CTL = Literal["audio", "voip", "lowdelay"] BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] @@ -346,9 +348,9 @@ class OpusError(DiscordException): The error code returned. """ - def __init__(self, code: int): + def __init__(self, code: int = 0, message: str | None = None): self.code: int = code - msg = _lib.opus_strerror(self.code).decode("utf-8") + msg = message or _lib.opus_strerror(self.code).decode("utf-8") _log.info('"%s" has happened', msg) super().__init__(msg) @@ -523,7 +525,7 @@ def _get_last_packet_duration(self): def decode(self, data, *, fec=False): if data is None and fec: - raise OpusError("Invalid arguments: FEC cannot be used with null data") + raise OpusError(message="Invalid arguments: FEC cannot be used with null data") if data is None: frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME @@ -545,62 +547,3 @@ def decode(self, data, *, fec=False): ) return array.array("h", pcm[: ret * channel_count]).tobytes() - - -class DecodeManager(threading.Thread, _OpusStruct): - def __init__(self, client: VoiceRecorderClient): - super().__init__(daemon=True, name="DecodeManager") - - self.client: VoiceRecorderClient = client - self.decode_queue: list[RawData] = [] - - self.decoder: dict[int, Decoder] = {} - - self._end_thread = threading.Event() - - def decode(self, opus_frame: RawData): - if not isinstance(opus_frame, RawData): - raise TypeError("opus_frame should be a RawData object.") - self.decode_queue.append(opus_frame) - - def run(self): - while not self._end_thread.is_set(): - try: - data = 
self.decode_queue.pop(0) - except IndexError: - time.sleep(0.001) - continue - - try: - if data.decrypted_data is None: - continue - else: - data.decoded_data = self.get_decoder(data.ssrc).decode( - data.decrypted_data - ) - except OpusError: - _log.exception( - "Error occurred while decoding opus frame.", exc_info=True - ) - continue - - self.client.receive_audio(data) - - def stop(self) -> None: - while self.decoding: - time.sleep(0.1) - self.decoder = {} - gc.collect() - _log.debug("Decoder Process Killed") - self._end_thread.set() - - def get_decoder(self, ssrc: int) -> Decoder: - d = self.decoder.get(ssrc) - if d is not None: - return d - self.decoder[ssrc] = Decoder() - return self.decoder[ssrc] - - @property - def decoding(self) -> bool: - return bool(self.decode_queue) diff --git a/discord/sinks/errors.py b/discord/sinks/errors.py index 51e00db73f..9f3d081bc4 100644 --- a/discord/sinks/errors.py +++ b/discord/sinks/errors.py @@ -104,7 +104,7 @@ class FFmpegNotFound(SinkException): """ -class NoUserAdio(SinkException): +class NoUserAudio(SinkException): """Exception thrown when you try to format the audio of a user not saved in a sink. .. 
versionadded:: 2.7 diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index a174f87f49..4afd032071 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -38,7 +38,7 @@ from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, M4ASinkError, MaxProcessesCountReached, NoUserAdio +from .errors import FFmpegNotFound, M4ASinkError, MaxProcessesCountReached, NoUserAudio if TYPE_CHECKING: from discord import abc @@ -188,7 +188,7 @@ def format_user_audio( data = self.__audio_data.pop(user_id) except KeyError: _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAdio + raise NoUserAudio temp_path = f"{user_id}-{time.time()}-recording.m4a.tmp" args = [ diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index 4cda2a9020..4a565b3003 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -36,7 +36,7 @@ from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, MKASinkError, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, MKASinkError, NoUserAudio if TYPE_CHECKING: from discord import abc @@ -186,7 +186,7 @@ def format_user_audio( data = self.__audio_data.pop(user_id) except KeyError: _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAdio + raise NoUserAudio args = [ executable, diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index 0c50d0af9d..ed876d4c20 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -36,7 +36,7 @@ from .core import RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, MKVSinkError, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, MKVSinkError, NoUserAudio if TYPE_CHECKING: from discord import abc @@ -186,7 +186,7 @@ def 
format_user_audio( data = self.__audio_data.pop(user_id) except KeyError: _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAdio + raise NoUserAudio args = [ executable, diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 14442afa29..4cb52d0d42 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -36,7 +36,7 @@ from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, MP3SinkError, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, MP3SinkError, NoUserAudio if TYPE_CHECKING: from discord import abc @@ -186,7 +186,7 @@ def format_user_audio( data = self.__audio_data.pop(user_id) except KeyError: _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAdio + raise NoUserAudio args = [ executable, diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index d090584f81..ba20d1efb9 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -38,7 +38,7 @@ from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, MP4SinkError, NoUserAdio +from .errors import FFmpegNotFound, MaxProcessesCountReached, MP4SinkError, NoUserAudio if TYPE_CHECKING: from discord import abc @@ -188,7 +188,7 @@ def format_user_audio( data = self.__audio_data.pop(user_id) except KeyError: _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAdio + raise NoUserAudio temp_path = f"{user_id}-{time.time()}-recording.mp4.tmp" args = [ diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index cce7ad348a..839fd4cea6 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -36,7 +36,7 @@ from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, 
MaxProcessesCountReached, NoUserAdio, OGGSinkError +from .errors import FFmpegNotFound, MaxProcessesCountReached, NoUserAudio, OGGSinkError if TYPE_CHECKING: from discord import abc @@ -186,7 +186,7 @@ def format_user_audio( data = self.__audio_data.pop(user_id) except KeyError: _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAdio + raise NoUserAudio args = [ executable, diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index a1a52beecc..e1e07d0e89 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -32,7 +32,7 @@ from .core import RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import NoUserAdio +from .errors import NoUserAudio if TYPE_CHECKING: from discord import abc @@ -157,7 +157,7 @@ def format_user_audio( try: data = self.__audio_data.pop(user_id) except KeyError: - raise NoUserAdio + raise NoUserAudio data.seek(0) diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index d23ed4c0b5..9f8199810d 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -33,7 +33,7 @@ from .core import RawData, Sink, SinkFilter, SinkHandler from .enums import SinkFilteringMode -from .errors import NoUserAdio +from .errors import NoUserAudio if TYPE_CHECKING: from discord import abc @@ -166,7 +166,7 @@ def format_user_audio( try: data = self.__audio_data.pop(user_id) except KeyError: - raise NoUserAdio + raise NoUserAudio decoder = self.client.decoder diff --git a/discord/voice/state.py b/discord/voice/state.py index 7d73d19656..40faf30ea2 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -27,7 +27,6 @@ import asyncio import logging -from os import altsep import select import socket import struct @@ -149,6 +148,7 @@ def run(self) -> None: ) finally: self.stop() + self._started.clear() self._running.clear() self._callbacks.clear() @@ -251,13 +251,18 @@ def __init__( self._end: threading.Event = threading.Event() - def decode(self, frame: RawData) -> None: + 
def decode(self, frame: RawData, *, left: bool = False) -> None: if not isinstance(frame, RawData): raise TypeError( f"expected a RawData object, got {frame.__class__.__name__}" ) - self.decode_queue.append(frame) - _log.debug('Added frame %s to decode queue', frame) + + if left: + self.decode_queue.appendleft(frame) + else: + self.decode_queue.append(frame) + self.resume() + _recv_log.debug('Added frame %s to decode queue', frame) def is_running(self) -> bool: return self._started.is_set() @@ -266,6 +271,11 @@ def pause(self) -> None: self._idle_paused = False self._running.clear() + def is_paused(self) -> bool: + return self._idle_paused or ( + not self._running.is_set() and not self._end.is_set() + ) + def resume(self, *, force: bool = False) -> None: if self._running.is_set(): return @@ -299,6 +309,7 @@ def run(self) -> None: ) finally: self.stop() + self._started.clear() self._running.clear() self.decode_queue.clear() @@ -311,12 +322,11 @@ def get_decoder(self, ssrc: int) -> opus.Decoder: def _do_run(self) -> None: while not self._end.is_set(): - if not self.decode_queue: - if not self._warned_queue: + if not self._running.is_set(): + if self._warned_queue: _recv_log.warning('No decode queue found, waiting') self._warned_queue = True - - time.sleep(0.01) + self._running.wait() continue if self._warned_queue: @@ -326,6 +336,8 @@ def _do_run(self) -> None: try: data = self.decode_queue.popleft() except IndexError: + _recv_log.warning('No more voice packets found, idle pausing.') + self.pause() continue _recv_log.debug('Popped %s from the decode queue', data) @@ -344,6 +356,14 @@ def _do_run(self) -> None: exc_info=True, ) + + _recv_log.info('Decoded frame %s to %s', data, data.decoded_data) + + if data.decoded_data is MISSING: + _recv_log.warning('Decoded data %s is still MISSING after decode, appending it to left to decode', data) + self.decode(data, left=True) + continue + self.state.dispatch_packet_sinks(data) @@ -396,6 +416,7 @@ def __init__( 
self._voice_recv_socket = SocketVoiceRecvReader(self) self._voice_recv_socket.register(self.handle_voice_recv_packet) self._decoder_thread = DecoderThread(self) + self.start_record_socket() self.user_ssrc_map: dict[int, SSRC] = {} self.user_voice_timestamps: dict[int, tuple[int, float]] = {} self.sync_recording_start: bool = False @@ -407,23 +428,19 @@ def __init__( self.__sink_dispatch_task_set: set[asyncio.Task[Any]] = set() def start_record_socket(self) -> None: - if self._voice_recv_socket.is_paused(): + try: + self._voice_recv_socket.start() + except RuntimeError: self._voice_recv_socket.resume() - return - if self._voice_recv_socket.is_running(): - return - if not self._decoder_thread.is_running(): + try: self._decoder_thread.start() - - self._voice_recv_socket.start() + except RuntimeError: + self._decoder_thread.resume() def stop_record_socket(self) -> None: - if self._voice_recv_socket.is_running(): - self._voice_recv_socket.stop() - - if self._decoder_thread.is_running(): - self._decoder_thread.stop() + self._voice_recv_socket.stop() + self._decoder_thread.stop() for cb, args in self.recording_done_callbacks: task = self.loop.create_task(cb(*args)) @@ -447,13 +464,13 @@ def handle_voice_recv_packet(self, packet: bytes) -> None: return if self.paused_recording(): - _log.debug('Ignoring packet %s because recording is stopped', packet) + _recv_log.debug('Ignoring packet %s because recording is stopped', packet) return data = RawData(packet, self.client) if data.decrypted_data == opus.OPUS_SILENCE: - _log.debug('Ignoring packet %s because it is an opus silence frame', data) + _recv_log.debug('Ignoring packet %s because it is an opus silence frame', data) return self._decoder_thread.decode(data) From 212ce1d9a1019127e46befbe5490e588a7bf2bc0 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Sun, 31 Aug 2025 18:32:58 +0200 Subject: [PATCH 32/40] voice recv is no longer funny --- discord/opus.py | 9 +- discord/sinks/core.py | 
279 +++++++++++++++++++++++++++------------- discord/sinks/m4a.py | 31 ++--- discord/sinks/mka.py | 31 ++--- discord/sinks/mkv.py | 31 ++--- discord/sinks/mp3.py | 36 ++---- discord/sinks/mp4.py | 31 ++--- discord/sinks/ogg.py | 31 ++--- discord/sinks/pcm.py | 31 ++--- discord/sinks/wave.py | 38 ++---- discord/voice/client.py | 138 ++++++++++++++++---- discord/voice/state.py | 239 ++++++++++------------------------ 12 files changed, 444 insertions(+), 481 deletions(-) diff --git a/discord/opus.py b/discord/opus.py index caf533586a..7d44925c32 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -536,10 +536,11 @@ def decode(self, data, *, fec=False): samples_per_frame = self.packet_get_samples_per_frame(data) frame_size = frames * samples_per_frame - pcm = ( - ctypes.c_int16 - * (frame_size * channel_count * ctypes.sizeof(ctypes.c_int16)) - )() + # pcm = ( + # ctypes.c_int16 + # * (frame_size * channel_count * ctypes.sizeof(ctypes.c_int16)) + # )() + pcm = (ctypes.c_int16 * (frame_size * channel_count))() pcm_ptr = ctypes.cast(pcm, c_int16_ptr) ret = _lib.opus_decode( diff --git a/discord/sinks/core.py b/discord/sinks/core.py index b7af990692..ac2feda13a 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -26,15 +26,17 @@ from __future__ import annotations import asyncio +from collections import namedtuple import logging import struct import sys import time from collections.abc import Callable, Coroutine, Iterable from functools import partial +import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload -from discord import utils +from discord import utils, opus from discord.enums import SpeakingState from discord.utils import MISSING @@ -68,6 +70,10 @@ _log = logging.getLogger(__name__) +def is_rtcp(data: bytes) -> bool: + return 200 <= data[1] <= 204 + + class SinkFilter(Generic[S]): """Represents a filter for a :class:`~.Sink`. 
@@ -146,6 +152,42 @@ def filter_speaking_state( """ raise NotImplementedError("subclasses must implement this") + @overload + async def filter_user_connect( + self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + ) -> bool: ... + + @overload + def filter_user_connect( + self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + ) -> bool: ... + + def filter_user_connect( + self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + ) -> bool | Coroutine[Any, Any, bool]: + """|maybecoro| + + This is called automatically everytime a speaking state is updated. + + Depending on what bool-like this returns, it will dispatch some events in the parent ``sink``. + + Parameters + ---------- + sink: :class:`~.Sink` + The sink the packet was received from, if the filter check goes through. + user: :class:`~discord.abc.Snowflake` + The user that the packet was received from. + channel: :class:`~discord.abc.Snowflake` + The channel the user has connected to. This is usually resolved into the proper guild channel type, but + defaults to a :class:`~discord.Object` when not found. + + Returns + ------- + :class:`bool` + Whether the filter was successful. + """ + raise NotImplementedError("subclasses must implement this") + def cleanup(self) -> None: """A function called when the filter is ready for cleanup.""" @@ -212,6 +254,34 @@ def handle_speaking_state( The speaking state after the update. """ + @overload + async def handle_user_connect( + self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + ) -> Any: ... + + @overload + def handle_user_connect( + self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + ) -> Any: ... + + def handle_user_connect( + self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + ) -> Any | Coroutine[Any, Any, Any]: + """|maybecoro| + + This is called automatically everytime a user has connected a voice channel which has successfully passed the filters. 
+ + Parameters + ---------- + sink: :class:`~.Sink` + The sink the packet was received from, if the filter check goes through. + user: :class:`~discord.abc.Snowflake` + The user that the packet was received from. + channel: :class:`~discord.abc.Snowflake` + The channel the user has connected to. This is usually resolved into the proper guild channel type, but + defaults to a :class:`~discord.Object` when not found. + """ + def cleanup(self) -> None: """A function called when the handler is ready for cleanup.""" @@ -222,119 +292,128 @@ class RawData: .. versionadded:: 2.0 """ + unpacker = struct.Struct('>xxHII') + _ext_header = namedtuple('Extension', 'profile length values') + _ext_magic = b'\xbe\xde' + if TYPE_CHECKING: sequence: int timestamp: int ssrc: int - def __init__(self, data: bytes, client: VoiceClient): - self.data: bytearray = bytearray(data) + def __init__(self, raw_data: bytes, client: VoiceClient): + data: bytearray = bytearray(raw_data) self.client: VoiceClient = client - unpacker = struct.Struct(">xxHII") - self.sequence, self.timestamp, self.ssrc = unpacker.unpack_from(self.data[:12]) + self.version: int = data[0] >> 6 + self.padding: bool = bool(data[0] & 0b00100000) + self.extended: bool = bool(data[0] & 0b00010000) + self.cc: int = data[0] & 0b00001111 + self.marker: bool = bool(data[1] & 0b10000000) + self.payload: int = data[1] & 0b01111111 + + self.sequence, self.timestamp, self.ssrc = self.unpacker.unpack_from(data) + self.csrcs: tuple[int, ...] 
= () + self.extension = None + self.extension_data: dict[int, bytes] = {} + + self.header = data[:12] + self.data = data[12:] + self.decrypted_data: bytes | None = None + self.decoded_data: bytes = MISSING - # RFC3550 5.1: RTP Fixed Header Fields - if self.client.mode.endswith("_rtpsize"): - # If It Has CSRC Chunks - cutoff = 12 + (data[0] & 0b00_0_0_1111) * 4 - # If It Has A Extension - if data[0] & 0b00_0_1_0000: - cutoff += 4 - else: - cutoff = 12 + self.nonce: bytes = b'' + self._rtpsize: bool = False - self.header: bytes = data[:cutoff] - self.data = self.data[cutoff:] + self._decoder: opus.Decoder = opus.Decoder() + self.receive_time: float = time.perf_counter() - self.decrypted_data: bytes = getattr( - self.client, f"_decrypt_{self.client.mode}" - )(self.header, self.data) - self.decoded_data: bytes = MISSING + if self.cc: + fmt = '>%sI' % self.cc + offset = struct.calcsize(fmt) + 12 + self.csrcs = struct.unpack(fmt, data[12:offset]) + self.data = data[offset:] - self.user_id: int | None = None - self.receive_time: float = time.perf_counter() + def adjust_rtpsize(self) -> None: + self._rtpsize = True + self.nonce = self.data[-4:] + if not self.extended: + self.data = self.data[:-4] -class Sink: - r"""Represents a sink for voice recording. + self.header += self.data[:4] + self.data = self.data[4:-4] - This is used as a way of "storing" the recordings. + def update_headers(self, data: bytes) -> int: + if not self.extended: + return 0 - This class is abstracted, and must be subclassed in order to apply functionalities to - it. + if self._rtpsize: + data = self.header[-4:] + data - Parameters - ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. 
If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. + profile, length = struct.unpack_from('>2sH', data) - Events - ------ + if profile == self._ext_magic: + self._parse_bede_header(data, length) - These section outlines all the available sink events. + values = struct.unpack('>%sI' % length, data[4 : 4 + length * 4]) + self.extension = self._ext_header(profile, length, values) - .. function:: on_voice_packet_receive(user, data) - Called when a voice packet is received from a member. + offset = 4 + length * 4 + if self._rtpsize: + offset -= 4 + return offset - This is called **after** the filters went through. + def _parse_bede_header(self, data: bytes, length: int) -> None: + offset = 4 + n = 0 - :param user: The user the packet is from. This can sometimes be a :class:`~discord.Object` object. - :type user: :class:`~discord.abc.Snowflake` - :param data: The RawData of the packet. - :type data: :class:`~.RawData` + while n < length: + next_byte = data[offset : offset + 1] - .. function:: on_unfiltered_voice_packet_receive(user, data) - Called when a voice packet is received from a member. + if next_byte == b'\x00': + offset += 1 + continue - Unlike ``on_voice_packet_receive``, this is called **before any filters** are called. + header = struct.unpack('>B', next_byte)[0] - :param user: The user the packet is from. This can sometimes be a :class:`~discord.Object` object. - :type user: :class:`~discord.abc.Snowflake` - :param data: The RawData of the packet. 
- :type data: :class:`~.RawData` + element_id = header >> 4 + element_len = 1 + (header & 0b0000_1111) - .. function:: on_speaking_state_update(user, before, after) - Called when a member's voice state changes. + self.extension_data[element_id] = data[offset + 1 : offset + 1 + element_len] + offset += 1 + element_len + n += 1 - This is called **after** the filters went through. + async def decode(self) -> bytes: + if not self.decrypted_data: + _log.debug('Attempted to decode an empty decrypted data frame') + return b'' - :param user: The user which speaking state has changed. This can sometimes be a :class:`~discord.Object` object. - :type user: :class:`~discord.abc.Snowflake` - :param before: The user's state before it was updated. - :type before: :class:`~discord.SpeakingFlags` - :param after: The user's state after it was updated. - :type after: :class:`~discord.SpeakingFlags` + return await asyncio.to_thread( + self._decoder.decode, + self.decrypted_data, + ) - .. function:: on_unfiltered_speaking_state_update(user, before, after) - Called when a voice packet is received from a member. - Unlike ``on_speaking_state_update``, this is called **before any filters** are called. +class Sink: + r"""Represents a sink for voice recording. - :param user: The user which speaking state has changed. This can sometimes be a :class:`~discord.Object` object. - :type user: :class:`~discord.abc.Snowflake` - :param before: The user's state before it was updated. - :type before: :class:`~discord.SpeakingFlags` - :param after: The user's state after it was updated. - :type after: :class:`~discord.SpeakingFlags` + This is used as a way of "storing" the recordings. - .. function:: on_error(event, exception, \*args, \*\*kwargs) - Called when an error ocurrs in any of the events above. The default implementation logs the exception - to stdout. + This class is abstracted, and must be subclassed in order to apply functionalities to + it. 
- :param event: The event in which the error ocurred. - :type event: :class:`str` - :param exception: The exception that ocurred. - :type exception: :class:`Exception` - :param \*args: The arguments that were passed to the event. - :param \*\*kwargs: The key-word arguments that were passed to the event. + Parameters + ---------- + filters: List[:class:`~.SinkFilter`] + The filters to apply to this sink recorder. + filtering_mode: :class:`~.SinkFilteringMode` + How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through + in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, + only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. + handlers: List[:class:`~.SinkHandler`] + The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example + store a certain packet data in a file, or local mapping. """ if TYPE_CHECKING: @@ -447,7 +526,22 @@ def _call_voice_packet_handlers(self, user: abc.Snowflake, packet: RawData) -> N ) ) self.__dispatch_set.add(task) - task.add_done_callback(self.__dispatch_set.remove) + task.add_done_callback(self.__dispatch_set.discard) + + def _call_user_connect_handlers( + self, user: abc.Snowflake, channel: abc.Snowflake, + ) -> None: + for handler in self._handlers: + task = asyncio.create_task( + utils.maybe_coroutine( + handler.handle_user_connect, + self, + user, + channel, + ), + ) + self.__dispatch_set.add(task) + task.add_done_callback(self.__dispatch_set.discard) def _call_speaking_state_handlers( self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState @@ -463,7 +557,7 @@ def _call_speaking_state_handlers( ), ) self.__dispatch_set.add(task) - task.add_done_callback(self.__dispatch_set.remove) + task.add_done_callback(self.__dispatch_set.discard) def _schedule_event( self, @@ -498,7 +592,7 @@ def cleanup(self) -> None: for task in 
list(self.__dispatch_set): if task.done(): continue - task.set_result(None) + task.cancel() for filter in self._filters: filter.cleanup() @@ -506,9 +600,6 @@ def cleanup(self) -> None: for handler in self._handlers: handler.cleanup() - def __del__(self) -> None: - self.cleanup() - def add_filter(self, filter: SinkFilter, /) -> None: """Adds a filter to this sink. @@ -650,6 +741,16 @@ async def on_unfiltered_speaking_state_update( ) -> None: pass + async def on_user_connect( + self, user: abc.Snowflake, channel: abc.Snowflake, + ) -> None: + pass + + async def on_unfiltered_user_connect( + self, user: abc.Snowflake, channel: abc.Snowflake + ) -> None: + pass + async def on_error( self, event: str, exception: Exception, *args: Any, **kwargs: Any ) -> None: @@ -663,7 +764,7 @@ async def on_error( def is_recording(self) -> bool: """Whether this sink is currently available to record, and doing so.""" state = self.client._connection - return state.is_recording() and self in state.sinks + return state.is_recording() and id(self) in state._sinks def is_paused(self) -> bool: """Whether this sink is currently paused from recording.""" diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 4afd032071..4b438c7260 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -41,36 +41,20 @@ from .errors import FFmpegNotFound, M4ASinkError, MaxProcessesCountReached, NoUserAudio if TYPE_CHECKING: + from typing_extensions import Self + from discord import abc _log = logging.getLogger(__name__) __all__ = ( - "M4AConverterHandler", "M4ASink", ) -class M4AConverterHandler(SinkHandler["M4ASink"]): - """Default handler to add received voice packets to the audio cache data in - a :class:`~.M4ASink`. - - .. 
versionadded:: 2.7 - """ - - def handle_packet( - self, sink: M4ASink, user: abc.Snowflake, packet: RawData - ) -> None: - data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) - data.write(packet.decoded_data) - - class M4ASink(Sink): """A special sink for .m4a files. - This is essentially a :class:`~.Sink` with a :class:`~.M4AConverterHandler` handler - passed as a default. - .. versionadded:: 2.0 Parameters @@ -92,18 +76,15 @@ class M4ASink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, + filters: list[SinkFilter[Self]] = MISSING, filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + handlers: list[SinkHandler[Self]] = MISSING, max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque( maxlen=max_audio_processes_count ) - handlers = handlers or [] - handlers.append(M4AConverterHandler()) - super().__init__( filters=filters, filtering_mode=filtering_mode, @@ -268,3 +249,7 @@ def cleanup(self) -> None: self.__audio_data.clear() super().cleanup() + + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) + buffer.write(data.decoded_data) diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index 4a565b3003..f39bb6150b 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -39,36 +39,20 @@ from .errors import FFmpegNotFound, MaxProcessesCountReached, MKASinkError, NoUserAudio if TYPE_CHECKING: + from typing_extensions import Self + from discord import abc _log = logging.getLogger(__name__) __all__ = ( - "MKAConverterHandler", "MKASink", ) -class MKAConverterHandler(SinkHandler["MKASink"]): - """Default handler to add received voice packets to the audio cache data in - a :class:`~.MKASink`. - - .. 
versionadded:: 2.7 - """ - - def handle_packet( - self, sink: MKASink, user: abc.Snowflake, packet: RawData - ) -> None: - data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) - data.write(packet.decoded_data) - - class MKASink(Sink): """A special sink for .mka files. - This is essentially a :class:`~.Sink` with a :class:`~.MKAConverterHandler` handler - passed as a default. - .. versionadded:: 2.0 Parameters @@ -90,18 +74,15 @@ class MKASink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, + filters: list[SinkFilter[Self]] = MISSING, filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + handlers: list[SinkHandler[Self]] = MISSING, max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} self.__process_queue: deque[subprocess.Popen] = deque( maxlen=max_audio_processes_count ) - handlers = handlers or [] - handlers.append(MKAConverterHandler()) - super().__init__( filters=filters, filtering_mode=filtering_mode, @@ -246,3 +227,7 @@ def cleanup(self) -> None: self.__audio_data.clear() super().cleanup() + + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) + buffer.write(data.decoded_data) diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index ed876d4c20..ef5cc0c92d 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -39,36 +39,20 @@ from .errors import FFmpegNotFound, MaxProcessesCountReached, MKVSinkError, NoUserAudio if TYPE_CHECKING: + from typing_extensions import Self + from discord import abc _log = logging.getLogger(__name__) __all__ = ( - "MKVConverterHandler", "MKVSink", ) -class MKVConverterHandler(SinkHandler["MKVSink"]): - """Default handler to add received voice packets to the audio cache data in - a :class:`~.MKVSink`. - - .. 
versionadded:: 2.7 - """ - - def handle_packet( - self, sink: MKVSink, user: abc.Snowflake, packet: RawData - ) -> None: - data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) - data.write(packet.decoded_data) - - class MKVSink(Sink): """A special sink for .mkv files. - This is essentially a :class:`~.Sink` with a :class:`~.MKVConverterHandler` handler - passed as a default. - .. versionadded:: 2.0 Parameters @@ -90,18 +74,15 @@ class MKVSink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, + filters: list[SinkFilter[Self]] = MISSING, filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + handlers: list[SinkHandler[Self]] = MISSING, max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} self.__process_queue: deque[subprocess.Popen] = deque( maxlen=max_audio_processes_count ) - handlers = handlers or [] - handlers.append(MKVConverterHandler()) - super().__init__( filters=filters, filtering_mode=filtering_mode, @@ -243,3 +224,7 @@ def cleanup(self) -> None: self.__audio_data.clear() super().cleanup() + + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) + buffer.write(data.decoded_data) diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 4cb52d0d42..89ac17a657 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -39,36 +39,20 @@ from .errors import FFmpegNotFound, MaxProcessesCountReached, MP3SinkError, NoUserAudio if TYPE_CHECKING: + from typing_extensions import Self + from discord import abc _log = logging.getLogger(__name__) __all__ = ( - "MP3ConverterHandler", "MP3Sink", ) -class MP3ConverterHandler(SinkHandler["MP3Sink"]): - """Default handler to add received voice packets to the audio cache data in - a :class:`~.MP3Sink`. - - .. 
versionadded:: 2.7 - """ - - def handle_packet( - self, sink: MP3Sink, user: abc.Snowflake, packet: RawData - ) -> None: - data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) - data.write(packet.decoded_data) - - class MP3Sink(Sink): """A special sink for .mp3 files. - This is essentially a :class:`~.Sink` with a :class:`~.MP3ConverterHandler` handler - passed as a default. - .. versionadded:: 2.0 Parameters @@ -90,18 +74,15 @@ class MP3Sink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, + filters: list[SinkFilter[Self]] = MISSING, filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + handlers: list[SinkHandler[Self]] = MISSING, max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} self.__process_queue: deque[subprocess.Popen] = deque( maxlen=max_audio_processes_count ) - handlers = handlers or [] - handlers.append(MP3ConverterHandler()) - super().__init__( filters=filters, filtering_mode=filtering_mode, @@ -110,10 +91,13 @@ def __init__( def get_user_audio(self, user_id: int) -> io.BytesIO | None: """Gets a user's saved audio data, or ``None``.""" - return self.__audio_data.get(user_id) + ret = self.__audio_data.get(user_id) + _log.debug('Found stored user ID %s with buffer %s', user_id, ret) + return ret def _create_audio_packet_for(self, uid: int) -> io.BytesIO: data = self.__audio_data[uid] = io.BytesIO() + _log.debug('Created user ID %s buffer', uid) return data @overload @@ -246,3 +230,7 @@ def cleanup(self) -> None: self.__audio_data.clear() super().cleanup() + + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) + buffer.write(data.decoded_data) diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index ba20d1efb9..9c5b71d0fc 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -41,36 +41,20 @@ 
from .errors import FFmpegNotFound, MaxProcessesCountReached, MP4SinkError, NoUserAudio if TYPE_CHECKING: + from typing_extensions import Self + from discord import abc _log = logging.getLogger(__name__) __all__ = ( - "MP4ConverterHandler", "MP4Sink", ) -class MP4ConverterHandler(SinkHandler["MP4Sink"]): - """Default handler to add received voice packets to the audio cache data in - a :class:`~.MP4Sink`. - - .. versionadded:: 2.7 - """ - - def handle_packet( - self, sink: MP4Sink, user: abc.Snowflake, packet: RawData - ) -> None: - data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) - data.write(packet.decoded_data) - - class MP4Sink(Sink): """A special sink for .mp4 files. - This is essentially a :class:`~.Sink` with a :class:`~.MP4ConverterHandler` handler - passed as a default. - .. versionadded:: 2.0 Parameters @@ -92,18 +76,15 @@ class MP4Sink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, + filters: list[SinkFilter[Self]] = MISSING, filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + handlers: list[SinkHandler[Self]] = MISSING, max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque( maxlen=max_audio_processes_count ) - handlers = handlers or [] - handlers.append(MP4ConverterHandler()) - super().__init__( filters=filters, filtering_mode=filtering_mode, @@ -268,3 +249,7 @@ def cleanup(self) -> None: self.__audio_data.clear() super().cleanup() + + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) + buffer.write(data.decoded_data) diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 839fd4cea6..fc9fa4a3fe 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -39,36 +39,20 @@ from .errors import FFmpegNotFound, 
MaxProcessesCountReached, NoUserAudio, OGGSinkError if TYPE_CHECKING: + from typing_extensions import Self + from discord import abc _log = logging.getLogger(__name__) __all__ = ( - "OGGConverterHandler", "OGGSink", ) -class OGGConverterHandler(SinkHandler["OGGSink"]): - """Default handler to add received voice packets to the audio cache data in - a :class:`~.OGGSink`. - - .. versionadded:: 2.7 - """ - - def handle_packet( - self, sink: OGGSink, user: abc.Snowflake, packet: RawData - ) -> None: - data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) - data.write(packet.decoded_data) - - class OGGSink(Sink): """A special sink for .ogg files. - This is essentially a :class:`~.Sink` with a :class:`~.OGGConverterHandler` handler - passed as a default. - .. versionadded:: 2.0 Parameters @@ -90,18 +74,15 @@ class OGGSink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, + filters: list[SinkFilter[Self]] = MISSING, filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + handlers: list[SinkHandler[Self]] = MISSING, max_audio_processes_count: int = 10, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} self.__process_queue: deque[subprocess.Popen] = deque( maxlen=max_audio_processes_count ) - handlers = handlers or [] - handlers.append(OGGConverterHandler()) - super().__init__( filters=filters, filtering_mode=filtering_mode, @@ -246,3 +227,7 @@ def cleanup(self) -> None: self.__audio_data.clear() super().cleanup() + + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) + buffer.write(data.decoded_data) diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index e1e07d0e89..5f44641184 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -35,34 +35,18 @@ from .errors import NoUserAudio if TYPE_CHECKING: + from typing_extensions import Self + from discord 
import abc __all__ = ( - "PCMConverterHandler", "PCMSink", ) -class PCMConverterHandler(SinkHandler["PCMSink"]): - """Default handler to add received voice packets to the audio cache data in - a :class:`~.PCMSink`. - - .. versionadded:: 2.7 - """ - - def handle_packet( - self, sink: PCMSink, user: abc.Snowflake, packet: RawData - ) -> None: - data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) - data.write(packet.decoded_data) - - class PCMSink(Sink): """A special sink for .pcm files. - This is essentially a :class:`~.Sink` with a :class:`.PCMConverterHandler` handler - passed as a default. - .. versionadded:: 2.0 Parameters @@ -81,14 +65,11 @@ class PCMSink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, + filters: list[SinkFilter[Self]] = MISSING, filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + handlers: list[SinkHandler[Self]] = MISSING, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} - handlers = handlers or [] - handlers.append(PCMConverterHandler()) - super().__init__( filters=filters, filtering_mode=filtering_mode, @@ -172,3 +153,7 @@ def cleanup(self) -> None: self.__audio_data.clear() super().cleanup() + + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) + buffer.write(data.decoded_data) diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index 9f8199810d..ea588a6fce 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -36,42 +36,19 @@ from .errors import NoUserAudio if TYPE_CHECKING: + from typing_extensions import Self + from discord import abc __all__ = ( - "WaveConverterHandler", - "WavConverterHandler", "WaveSink", "WavSink", ) -class WaveConverterHandler(SinkHandler["WaveSink"]): - """Default handler to add received voice packets to the audio cache data in - a :class:`~.WaveSink`. - - .. 
versionadded:: 2.7 - """ - - def handle_packet( - self, sink: WaveSink, user: abc.Snowflake, packet: RawData - ) -> None: - data = sink.get_user_audio(user.id) or sink._create_audio_packet_for(user.id) - data.write(packet.decoded_data) - - -WavConverterHandler: SinkHandler[WavSink] = WaveConverterHandler # type: ignore -"""An alias for :class:`~.WaveConverterHandler` - -.. versionadded:: 2.7 -""" - - class WaveSink(Sink): """A special sink for .wav(e) files. - This is essentially a :class:`~.Sink` with a :class:`.WaveConverterHandler` handler. - .. versionadded:: 2.0 Parameters @@ -90,14 +67,11 @@ class WaveSink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, + filters: list[SinkFilter[Self]] = MISSING, filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + handlers: list[SinkHandler[Self]] = MISSING, ) -> None: self.__audio_data: dict[int, io.BytesIO] = {} - handlers = handlers or [] - handlers.append(WaveConverterHandler()) - super().__init__( filters=filters, filtering_mode=filtering_mode, @@ -189,6 +163,10 @@ def cleanup(self) -> None: self.__audio_data.clear() super().cleanup() + async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: + buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) + buffer.write(data.decoded_data) + WavSink = WaveSink """An alias for :class:`~.WaveSink`. 
diff --git a/discord/voice/client.py b/discord/voice/client.py index 8e42d11519..c6ca239818 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -35,7 +35,7 @@ from discord import opus from discord.errors import ClientException from discord.player import AudioPlayer, AudioSource -from discord.sinks.core import Sink +from discord.sinks.core import Sink, RawData, is_rtcp from discord.sinks.errors import RecordingException from discord.utils import MISSING @@ -300,6 +300,8 @@ def _get_voice_packet(self, data: Any) -> bytes: encrypt_packet = getattr(self, f"_encrypt_{self.mode}") return encrypt_packet(header, data) + # encryption methods + def _encrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes: # deprecated box = nacl.secret.SecretBox(bytes(self.secret_key)) @@ -335,54 +337,134 @@ def _encrypt_aead_xchacha20_poly1305_rtpsize( + nonce[:4] ) - def _decrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes: - # deprecated + # decryption methods + + def _decrypt_rtp_xsalsa20_poly1305(self, data: bytes) -> bytes: + packet = RawData(data, self) + nonce = bytearray(24) + nonce[:12] = packet.header + box = nacl.secret.SecretBox(bytes(self.secret_key)) + result = box.decrypt(bytes(packet.data), bytes(nonce)) + if packet.extended: + offset = packet.update_headers(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_xsalsa20_poly1305(self, data: bytes) -> bytes: nonce = bytearray(24) - nonce[:12] = header + nonce[:8] = data[:8] - return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) + box = nacl.secret.SecretBox(bytes(self.secret_key)) + result = box.decrypt(data[8:], bytes(nonce)) + + return data[:8] + result + + def _decrypt_xsalsa20_poly1305(self, data: bytes) -> bytes: + if is_rtcp(data): + func = self._decrypt_rtcp_xsalsa20_poly1305 + else: + func = self._decrypt_rtp_xsalsa20_poly1305 + return func(data) + + def _decrypt_rtp_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes: + packet = 
RawData(data, self) + nonce = packet.data[-24:] + voice_data = packet.data[:-24] - def _decrypt_xsalsa20_poly1305_suffix(self, header: bytes, data: Any) -> bytes: - # deprecated box = nacl.secret.SecretBox(bytes(self.secret_key)) + result = box.decrypt(bytes(voice_data), bytes(nonce)) - nonce_size = nacl.secret.SecretBox.NONCE_SIZE - nonce = data[-nonce_size:] + if packet.extended: + offset = packet.update_headers(result) + result = result[offset:] - return self.strip_header_ext(box.decrypt(bytes(data[:-nonce_size]), nonce)) + return result + + def _decrypt_rtcp_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes: + nonce = data[-24:] + header = data[:8] - def _decrypt_xsalsa20_poly1305_lite(self, header: bytes, data: Any) -> bytes: - # deprecated box = nacl.secret.SecretBox(bytes(self.secret_key)) + result = box.decrypt(data[8:-24], nonce) + + return header + result + def _decrypt_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes: + if is_rtcp(data): + func = self._decrypt_rtcp_xsalsa20_poly1305_suffix + else: + func = self._decrypt_rtp_xsalsa20_poly1305_suffix + return func(data) + + def _decrypt_rtp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: + packet = RawData(data, self) + nonce = bytearray(24) + nonce[:4] = packet.data[-4:] + voice_data = packet.data[:-4] + + box = nacl.secret.SecretBox(bytes(self.secret_key)) + result = box.decrypt(bytes(voice_data), bytes(nonce)) + + if packet.extended: + offset = packet.update_headers(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: nonce = bytearray(24) nonce[:4] = data[-4:] - data = data[:-4] + header = data[:8] + + box = nacl.secret.SecretBox(bytes(self.secret_key)) + result = box.decrypt(data[8:-4], bytes(nonce)) - return self.strip_header_ext(box.decrypt(bytes(data), bytes(nonce))) + return header + result - def _decrypt_aead_xchacha20_poly1305_rtpsize( - self, header: bytes, data: Any - ) -> bytes: + def 
_decrypt_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: + if is_rtcp(data): + func = self._decrypt_rtcp_xsalsa20_poly1305_lite + else: + func = self._decrypt_rtp_xsalsa20_poly1305_lite + return func(data) + + def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: + packet = RawData(data, self) + packet.adjust_rtpsize() + + nonce = bytearray(24) + nonce[:4] = packet.nonce + voice_data = packet.data + + # Blob vomit box = nacl.secret.Aead(bytes(self.secret_key)) + result = box.decrypt(bytes(voice_data), bytes(packet.header), bytes(nonce)) + + if packet.extended: + offset = packet.update_headers(result) + result = result[offset:] + + return result + def _decrypt_rtcp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: nonce = bytearray(24) nonce[:4] = data[-4:] - data = data[:-4] + header = data[:8] - return self.strip_header_ext( - box.decrypt(bytes(data), bytes(header), bytes(nonce)) - ) + box = nacl.secret.Aead(bytes(self.secret_key)) + result = box.decrypt(data[8:-4], bytes(header), bytes(nonce)) - @staticmethod - def strip_header_ext(data: bytes) -> bytes: - if len(data) > 4 and data[0] == 0xBE and data[1] == 0xDE: - _, length = struct.unpack_from(">HH", data) - offset = 4 + length * 4 - data = data[offset:] - return data + return header + result + + def _decrypt_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: + if is_rtcp(data): + func = self._decrypt_rtcp_aead_xchacha20_poly1305_rtpsize + else: + func = self._decrypt_rtp_aead_xchacha20_poly1305_rtpsize + return func(data) @overload def play( diff --git a/discord/voice/state.py b/discord/voice/state.py index 40faf30ea2..3329d06c77 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -189,7 +189,9 @@ def _do_run(self) -> None: else: for cb in self._callbacks: try: - cb(data) + task = self.state.loop.create_task(utils.maybe_coroutine(cb, data)) + self.state._sink_dispatch_task_set.add(task) + 
task.add_done_callback(self.state._sink_dispatch_task_set.discard) except Exception: _log.exception( "Error while calling %s in %s", @@ -225,148 +227,6 @@ def __init__( ) -class DecoderThread(threading.Thread, opus._OpusStruct): - def __init__( - self, - state: VoiceConnectionState, - *, - start_paused: bool = True, - ) -> None: - super().__init__( - daemon=True, - name=f"voice-recv-decoder-thread:{id(self):#x}", - ) - - self.state: VoiceConnectionState = state - self.client: VoiceClient = state.client - self.start_paused: bool = start_paused - self._idle_paused: bool = True - self._started: threading.Event = threading.Event() - self._running: threading.Event = threading.Event() - self._end: threading.Event = threading.Event() - self._warned_queue: bool = False - - self.decode_queue: deque[RawData] = deque() - self.decoders: dict[int, opus.Decoder] = {} - - self._end: threading.Event = threading.Event() - - def decode(self, frame: RawData, *, left: bool = False) -> None: - if not isinstance(frame, RawData): - raise TypeError( - f"expected a RawData object, got {frame.__class__.__name__}" - ) - - if left: - self.decode_queue.appendleft(frame) - else: - self.decode_queue.append(frame) - self.resume() - _recv_log.debug('Added frame %s to decode queue', frame) - - def is_running(self) -> bool: - return self._started.is_set() - - def pause(self) -> None: - self._idle_paused = False - self._running.clear() - - def is_paused(self) -> bool: - return self._idle_paused or ( - not self._running.is_set() and not self._end.is_set() - ) - - def resume(self, *, force: bool = False) -> None: - if self._running.is_set(): - return - - if not force and not self.decode_queue: - self._idle_paused = True - return - - self._idle_paused = False - self._running.set() - - def stop(self) -> None: - self._started.clear() - self._end.set() - self._running.set() - - def run(self) -> None: - self._started.set() - self._end.clear() - self._running.set() - - if self.start_paused: - self.pause() - 
- try: - self._do_run() - except Exception: - _log.exception( - "An error ocurred while running the decoder thread %s", - self.name, - ) - finally: - self.stop() - self._started.clear() - self._running.clear() - self.decode_queue.clear() - - def get_decoder(self, ssrc: int) -> opus.Decoder: - try: - return self.decoders[ssrc] - except KeyError: - d = self.decoders[ssrc] = opus.Decoder() - return d - - def _do_run(self) -> None: - while not self._end.is_set(): - if not self._running.is_set(): - if self._warned_queue: - _recv_log.warning('No decode queue found, waiting') - self._warned_queue = True - self._running.wait() - continue - - if self._warned_queue: - _recv_log.info('Queue was filled') - self._warned_queue = False - - try: - data = self.decode_queue.popleft() - except IndexError: - _recv_log.warning('No more voice packets found, idle pausing.') - self.pause() - continue - - _recv_log.debug('Popped %s from the decode queue', data) - - try: - if data.decrypted_data is None: - _log.warning('Frame %s has no decrypted data, skipping', data) - continue - else: - data.decoded_data = self.get_decoder(data.ssrc).decode( - data.decrypted_data, - ) - except opus.OpusError: - _log.exception( - "Error ocurred while decoding opus frame", - exc_info=True, - ) - - - _recv_log.info('Decoded frame %s to %s', data, data.decoded_data) - - if data.decoded_data is MISSING: - _recv_log.warning('Decoded data %s is still MISSING after decode, appending it to left to decode', data) - self.decode(data, left=True) - continue - - self.state.dispatch_packet_sinks(data) - - class SSRC(TypedDict): user_id: int speaking: SpeakingState @@ -415,17 +275,20 @@ def __init__( self._socket_reader.start() self._voice_recv_socket = SocketVoiceRecvReader(self) self._voice_recv_socket.register(self.handle_voice_recv_packet) - self._decoder_thread = DecoderThread(self) self.start_record_socket() self.user_ssrc_map: dict[int, SSRC] = {} self.user_voice_timestamps: dict[int, tuple[int, float]] = {} 
self.sync_recording_start: bool = False self.first_received_packet_ts: float = MISSING - self.sinks: list[Sink] = [] + self._sinks: dict[int, Sink] = {} self.recording_done_callbacks: list[ tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]] ] = [] - self.__sink_dispatch_task_set: set[asyncio.Task[Any]] = set() + self._sink_dispatch_task_set: set[asyncio.Task[Any]] = set() + + @property + def sinks(self) -> list[Sink]: + return list(self._sinks.values()) def start_record_socket(self) -> None: try: @@ -433,19 +296,13 @@ def start_record_socket(self) -> None: except RuntimeError: self._voice_recv_socket.resume() - try: - self._decoder_thread.start() - except RuntimeError: - self._decoder_thread.resume() - def stop_record_socket(self) -> None: self._voice_recv_socket.stop() - self._decoder_thread.stop() for cb, args in self.recording_done_callbacks: task = self.loop.create_task(cb(*args)) - self.__sink_dispatch_task_set.add(task) - task.add_done_callback(self.__sink_dispatch_task_set.remove) + self._sink_dispatch_task_set.add(task) + task.add_done_callback(self._sink_dispatch_task_set.remove) for sink in self.sinks: sink.stop() @@ -453,7 +310,7 @@ def stop_record_socket(self) -> None: self.recording_done_callbacks.clear() self.sinks.clear() - def handle_voice_recv_packet(self, packet: bytes) -> None: + async def handle_voice_recv_packet(self, packet: bytes) -> None: _recv_log.debug('Handling voice packet %s', packet) if packet[1] != 0x78: # We should ignore any payload types we do not understand @@ -473,8 +330,7 @@ def handle_voice_recv_packet(self, packet: bytes) -> None: _recv_log.debug('Ignoring packet %s because it is an opus silence frame', data) return - self._decoder_thread.decode(data) - _recv_log.debug('Submitted frame %s to decoder thread', data) + await data.decode() def is_first_packet(self) -> bool: return not self.user_voice_timestamps or not self.sync_recording_start @@ -505,14 +361,17 @@ def dispatch_packet_sinks(self, data: RawData) -> 
None: + data.decoded_data ) + sleep_time = 0.01 while data.ssrc not in self.user_ssrc_map: - time.sleep(0.05) + time.sleep(sleep_time) + # duplicate sleep time, just for testing + sleep_time *= 2 task = self.loop.create_task( self._dispatch_packet(data), ) - self.__sink_dispatch_task_set.add(task) - task.add_done_callback(self.__sink_dispatch_task_set.remove) + self._sink_dispatch_task_set.add(task) + task.add_done_callback(self._sink_dispatch_task_set.discard) async def _dispatch_packet(self, data: RawData) -> None: user = self.get_user_by_ssrc(data.ssrc) @@ -563,13 +422,13 @@ def paused_recording(self) -> bool: return self._voice_recv_socket.is_paused() def add_sink(self, sink: Sink) -> None: - self.sinks.append(sink) + self._sinks[id(sink)] = sink self.start_record_socket() def remove_sink(self, sink: Sink) -> None: try: - self.sinks.remove(sink) - except ValueError: + self._sinks.pop(id(sink)) + except KeyError: pass def get_user_by_ssrc(self, ssrc: int) -> abc.Snowflake | None: @@ -605,6 +464,12 @@ def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: "user_id": user, "speaking": speaking, } + elif op == OpCodes.client_connect: + user_ids = [int(uid) for uid in data['user_ids']] + + for uid in user_ids: + user = self.get_user(uid) + self.dispatch_user_connect(user) def dispatch_speaking_state( self, before: SpeakingState, after: SpeakingState, user_id: int @@ -612,8 +477,46 @@ def dispatch_speaking_state( task = self.loop.create_task( self._dispatch_speaking_state(before, after, user_id), ) - self.__sink_dispatch_task_set.add(task) - task.add_done_callback(self.__sink_dispatch_task_set.remove) + self._sink_dispatch_task_set.add(task) + task.add_done_callback(self._sink_dispatch_task_set.discard) + + def dispatch_user_connect(self, user: abc.Snowflake) -> None: + task = self.loop.create_task( + self._dispatch_user_connect(self.channel_id, user), + ) + self._sink_dispatch_task_set.add(task) + 
task.add_done_callback(self._sink_dispatch_task_set.discard) + + async def _dispatch_user_connect(self, chid: int | None, user: abc.Snowflake) -> None: + channel = self.guild._resolve_channel(chid) or Object(id=chid or 0) + + for sink in self.sinks: + if sink.is_paused(): + continue + + sink.dispatch('unfiltered_user_connect', user, channel) + + if sink._filters: + futures = [ + self.loop.create_task( + utils.maybe_coroutine(fil.filter_user_connect, user, channel) + ) for fil in sink._filters + ] + strat = sink._filter_strat + + done, pending = await asyncio.wait(futures) + + if pending: + for task in pending: + task.cancel() + + result = strat([f.result() for f in done]) + else: + result = True + + if result: + sink.dispatch('user_connect', user, channel) + sink._call_user_connect_handlers(user, channel) async def _dispatch_speaking_state( self, before: SpeakingState, after: SpeakingState, uid: int @@ -638,10 +541,10 @@ async def _dispatch_speaking_state( done, pending = await asyncio.wait(futures) if pending: + # there should not be any pending futures + # but if there are, simply discard them for task in pending: - task.set_result(False) - - done = (*done, *pending) + task.cancel() result = strat([f.result() for f in done]) else: From 109d27c51dfbbb82d81a22d2c3ae6fd5e25af28a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 31 Aug 2025 16:38:15 +0000 Subject: [PATCH 33/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/gateway.py | 1 - discord/opus.py | 6 ++-- discord/sinks/core.py | 69 +++++++++++++++++++++++++++-------------- discord/sinks/m4a.py | 4 +-- discord/sinks/mka.py | 4 +-- discord/sinks/mkv.py | 4 +-- discord/sinks/mp3.py | 8 ++--- discord/sinks/mp4.py | 4 +-- discord/sinks/ogg.py | 4 +-- discord/sinks/pcm.py | 4 +-- discord/voice/client.py | 4 +-- discord/voice/state.py | 44 ++++++++++++++++---------- docs/api/sinks.rst | 2 +- 13 files changed, 89 insertions(+), 
69 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index de8f9eaa48..2160818bda 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -28,7 +28,6 @@ import asyncio import concurrent.futures import logging -import struct import sys import threading import time diff --git a/discord/opus.py b/discord/opus.py index 7d44925c32..cd096f56d9 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -34,8 +34,6 @@ import os.path import struct import sys -import threading -import time from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar from .errors import DiscordException @@ -525,7 +523,9 @@ def _get_last_packet_duration(self): def decode(self, data, *, fec=False): if data is None and fec: - raise OpusError(message="Invalid arguments: FEC cannot be used with null data") + raise OpusError( + message="Invalid arguments: FEC cannot be used with null data" + ) if data is None: frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME diff --git a/discord/sinks/core.py b/discord/sinks/core.py index ac2feda13a..c828e8034f 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -26,17 +26,16 @@ from __future__ import annotations import asyncio -from collections import namedtuple import logging import struct import sys import time +from collections import namedtuple from collections.abc import Callable, Coroutine, Iterable from functools import partial -import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload -from discord import utils, opus +from discord import opus, utils from discord.enums import SpeakingState from discord.utils import MISSING @@ -154,16 +153,25 @@ def filter_speaking_state( @overload async def filter_user_connect( - self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + self, + sink: S, + user: abc.Snowflake, + channel: abc.Snowflake, ) -> bool: ... 
@overload def filter_user_connect( - self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + self, + sink: S, + user: abc.Snowflake, + channel: abc.Snowflake, ) -> bool: ... def filter_user_connect( - self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + self, + sink: S, + user: abc.Snowflake, + channel: abc.Snowflake, ) -> bool | Coroutine[Any, Any, bool]: """|maybecoro| @@ -256,16 +264,25 @@ def handle_speaking_state( @overload async def handle_user_connect( - self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + self, + sink: S, + user: abc.Snowflake, + channel: abc.Snowflake, ) -> Any: ... @overload def handle_user_connect( - self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + self, + sink: S, + user: abc.Snowflake, + channel: abc.Snowflake, ) -> Any: ... def handle_user_connect( - self, sink: S, user: abc.Snowflake, channel: abc.Snowflake, + self, + sink: S, + user: abc.Snowflake, + channel: abc.Snowflake, ) -> Any | Coroutine[Any, Any, Any]: """|maybecoro| @@ -292,9 +309,9 @@ class RawData: .. 
versionadded:: 2.0 """ - unpacker = struct.Struct('>xxHII') - _ext_header = namedtuple('Extension', 'profile length values') - _ext_magic = b'\xbe\xde' + unpacker = struct.Struct(">xxHII") + _ext_header = namedtuple("Extension", "profile length values") + _ext_magic = b"\xbe\xde" if TYPE_CHECKING: sequence: int @@ -322,14 +339,14 @@ def __init__(self, raw_data: bytes, client: VoiceClient): self.decrypted_data: bytes | None = None self.decoded_data: bytes = MISSING - self.nonce: bytes = b'' + self.nonce: bytes = b"" self._rtpsize: bool = False self._decoder: opus.Decoder = opus.Decoder() self.receive_time: float = time.perf_counter() if self.cc: - fmt = '>%sI' % self.cc + fmt = ">%sI" % self.cc offset = struct.calcsize(fmt) + 12 self.csrcs = struct.unpack(fmt, data[12:offset]) self.data = data[offset:] @@ -351,12 +368,12 @@ def update_headers(self, data: bytes) -> int: if self._rtpsize: data = self.header[-4:] + data - profile, length = struct.unpack_from('>2sH', data) + profile, length = struct.unpack_from(">2sH", data) if profile == self._ext_magic: self._parse_bede_header(data, length) - values = struct.unpack('>%sI' % length, data[4 : 4 + length * 4]) + values = struct.unpack(">%sI" % length, data[4 : 4 + length * 4]) self.extension = self._ext_header(profile, length, values) offset = 4 + length * 4 @@ -371,23 +388,25 @@ def _parse_bede_header(self, data: bytes, length: int) -> None: while n < length: next_byte = data[offset : offset + 1] - if next_byte == b'\x00': + if next_byte == b"\x00": offset += 1 continue - header = struct.unpack('>B', next_byte)[0] + header = struct.unpack(">B", next_byte)[0] element_id = header >> 4 element_len = 1 + (header & 0b0000_1111) - self.extension_data[element_id] = data[offset + 1 : offset + 1 + element_len] + self.extension_data[element_id] = data[ + offset + 1 : offset + 1 + element_len + ] offset += 1 + element_len n += 1 async def decode(self) -> bytes: if not self.decrypted_data: - _log.debug('Attempted to decode an empty 
decrypted data frame') - return b'' + _log.debug("Attempted to decode an empty decrypted data frame") + return b"" return await asyncio.to_thread( self._decoder.decode, @@ -529,7 +548,9 @@ def _call_voice_packet_handlers(self, user: abc.Snowflake, packet: RawData) -> N task.add_done_callback(self.__dispatch_set.discard) def _call_user_connect_handlers( - self, user: abc.Snowflake, channel: abc.Snowflake, + self, + user: abc.Snowflake, + channel: abc.Snowflake, ) -> None: for handler in self._handlers: task = asyncio.create_task( @@ -742,7 +763,9 @@ async def on_unfiltered_speaking_state_update( pass async def on_user_connect( - self, user: abc.Snowflake, channel: abc.Snowflake, + self, + user: abc.Snowflake, + channel: abc.Snowflake, ) -> None: pass diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 4b438c7260..e817cfc540 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -47,9 +47,7 @@ _log = logging.getLogger(__name__) -__all__ = ( - "M4ASink", -) +__all__ = ("M4ASink",) class M4ASink(Sink): diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index f39bb6150b..3018bf84e4 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -45,9 +45,7 @@ _log = logging.getLogger(__name__) -__all__ = ( - "MKASink", -) +__all__ = ("MKASink",) class MKASink(Sink): diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index ef5cc0c92d..5dda65f16c 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -45,9 +45,7 @@ _log = logging.getLogger(__name__) -__all__ = ( - "MKVSink", -) +__all__ = ("MKVSink",) class MKVSink(Sink): diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 89ac17a657..3ce023cbd7 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -45,9 +45,7 @@ _log = logging.getLogger(__name__) -__all__ = ( - "MP3Sink", -) +__all__ = ("MP3Sink",) class MP3Sink(Sink): @@ -92,12 +90,12 @@ def __init__( def get_user_audio(self, user_id: int) -> io.BytesIO | None: """Gets a user's saved audio data, or 
``None``.""" ret = self.__audio_data.get(user_id) - _log.debug('Found stored user ID %s with buffer %s', user_id, ret) + _log.debug("Found stored user ID %s with buffer %s", user_id, ret) return ret def _create_audio_packet_for(self, uid: int) -> io.BytesIO: data = self.__audio_data[uid] = io.BytesIO() - _log.debug('Created user ID %s buffer', uid) + _log.debug("Created user ID %s buffer", uid) return data @overload diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index 9c5b71d0fc..29a47ddece 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -47,9 +47,7 @@ _log = logging.getLogger(__name__) -__all__ = ( - "MP4Sink", -) +__all__ = ("MP4Sink",) class MP4Sink(Sink): diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index fc9fa4a3fe..8075aa3a99 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -45,9 +45,7 @@ _log = logging.getLogger(__name__) -__all__ = ( - "OGGSink", -) +__all__ = ("OGGSink",) class OGGSink(Sink): diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index 5f44641184..ddc156b173 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -39,9 +39,7 @@ from discord import abc -__all__ = ( - "PCMSink", -) +__all__ = ("PCMSink",) class PCMSink(Sink): diff --git a/discord/voice/client.py b/discord/voice/client.py index c6ca239818..41e305e8b1 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -35,7 +35,7 @@ from discord import opus from discord.errors import ClientException from discord.player import AudioPlayer, AudioSource -from discord.sinks.core import Sink, RawData, is_rtcp +from discord.sinks.core import RawData, Sink, is_rtcp from discord.sinks.errors import RecordingException from discord.utils import MISSING @@ -404,7 +404,7 @@ def _decrypt_rtp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: nonce = bytearray(24) nonce[:4] = packet.data[-4:] voice_data = packet.data[:-4] - + box = nacl.secret.SecretBox(bytes(self.secret_key)) result = box.decrypt(bytes(voice_data), 
bytes(nonce)) diff --git a/discord/voice/state.py b/discord/voice/state.py index 3329d06c77..81e9d1734f 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -32,7 +32,6 @@ import struct import threading import time -from collections import deque from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any, TypedDict @@ -60,7 +59,7 @@ MISSING = utils.MISSING SocketReaderCallback = Callable[[bytes], Any] _log = logging.getLogger(__name__) -_recv_log = logging.getLogger('discord.voice.receiver') +_recv_log = logging.getLogger("discord.voice.receiver") class SocketReader(threading.Thread): @@ -156,13 +155,15 @@ def _do_run(self) -> None: while not self._end.is_set(): if not self._running.is_set(): if not self._warned_wait: - _log.warning('Socket reader %s is waiting to be set as running', self.name) + _log.warning( + "Socket reader %s is waiting to be set as running", self.name + ) self._warned_wait = True self._running.wait() continue if self._warned_wait: - _log.info('Socket reader %s was set as running', self.name) + _log.info("Socket reader %s was set as running", self.name) self._warned_wait = False try: @@ -189,9 +190,13 @@ def _do_run(self) -> None: else: for cb in self._callbacks: try: - task = self.state.loop.create_task(utils.maybe_coroutine(cb, data)) + task = self.state.loop.create_task( + utils.maybe_coroutine(cb, data) + ) self.state._sink_dispatch_task_set.add(task) - task.add_done_callback(self.state._sink_dispatch_task_set.discard) + task.add_done_callback( + self.state._sink_dispatch_task_set.discard + ) except Exception: _log.exception( "Error while calling %s in %s", @@ -311,7 +316,7 @@ def stop_record_socket(self) -> None: self.sinks.clear() async def handle_voice_recv_packet(self, packet: bytes) -> None: - _recv_log.debug('Handling voice packet %s', packet) + _recv_log.debug("Handling voice packet %s", packet) if packet[1] != 0x78: # We should ignore any payload types we do not understand # Ref: RFC 3550 5.1 
payload type @@ -321,13 +326,15 @@ async def handle_voice_recv_packet(self, packet: bytes) -> None: return if self.paused_recording(): - _recv_log.debug('Ignoring packet %s because recording is stopped', packet) + _recv_log.debug("Ignoring packet %s because recording is stopped", packet) return data = RawData(packet, self.client) if data.decrypted_data == opus.OPUS_SILENCE: - _recv_log.debug('Ignoring packet %s because it is an opus silence frame', data) + _recv_log.debug( + "Ignoring packet %s because it is an opus silence frame", data + ) return await data.decode() @@ -336,7 +343,7 @@ def is_first_packet(self) -> bool: return not self.user_voice_timestamps or not self.sync_recording_start def dispatch_packet_sinks(self, data: RawData) -> None: - _log.debug('Dispatching packet %s in all sinks', data) + _log.debug("Dispatching packet %s in all sinks", data) if data.ssrc not in self.user_ssrc_map: if self.is_first_packet(): self.first_received_packet_ts = data.receive_time @@ -465,7 +472,7 @@ def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: "speaking": speaking, } elif op == OpCodes.client_connect: - user_ids = [int(uid) for uid in data['user_ids']] + user_ids = [int(uid) for uid in data["user_ids"]] for uid in user_ids: user = self.get_user(uid) @@ -487,20 +494,23 @@ def dispatch_user_connect(self, user: abc.Snowflake) -> None: self._sink_dispatch_task_set.add(task) task.add_done_callback(self._sink_dispatch_task_set.discard) - async def _dispatch_user_connect(self, chid: int | None, user: abc.Snowflake) -> None: + async def _dispatch_user_connect( + self, chid: int | None, user: abc.Snowflake + ) -> None: channel = self.guild._resolve_channel(chid) or Object(id=chid or 0) for sink in self.sinks: if sink.is_paused(): continue - sink.dispatch('unfiltered_user_connect', user, channel) + sink.dispatch("unfiltered_user_connect", user, channel) if sink._filters: futures = [ self.loop.create_task( utils.maybe_coroutine(fil.filter_user_connect, user, 
channel) - ) for fil in sink._filters + ) + for fil in sink._filters ] strat = sink._filter_strat @@ -515,7 +525,7 @@ async def _dispatch_user_connect(self, chid: int | None, user: abc.Snowflake) -> result = True if result: - sink.dispatch('user_connect', user, channel) + sink.dispatch("user_connect", user, channel) sink._call_user_connect_handlers(user, channel) async def _dispatch_speaking_state( @@ -532,7 +542,9 @@ async def _dispatch_speaking_state( if sink._filters: futures = [ self.loop.create_task( - utils.maybe_coroutine(fil.filter_packet, sink, resolved, before, after) + utils.maybe_coroutine( + fil.filter_packet, sink, resolved, before, after + ) ) for fil in sink._filters ] diff --git a/docs/api/sinks.rst b/docs/api/sinks.rst index 6063322944..e5a199489d 100644 --- a/docs/api/sinks.rst +++ b/docs/api/sinks.rst @@ -125,4 +125,4 @@ These section outlines all the available sink events. :param exception: The exception that ocurred. :type exception: :class:`Exception` :param \*args: The arguments that were passed to the event. - :param \*\*kwargs: The key-word arguments that were passed to the event. \ No newline at end of file + :param \*\*kwargs: The key-word arguments that were passed to the event. 
From 8304dba7575df2839ab631541156bbb06f40f3f8 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Tue, 9 Sep 2025 10:04:19 +0200 Subject: [PATCH 34/40] add first pass to implement dave --- discord/client.py | 2 +- discord/gateway.py | 28 ---------- discord/voice/client.py | 14 ++++- discord/voice/enums.py | 13 +++++ discord/voice/gateway.py | 112 ++++++++++++++++++++++++++++++++++++++- discord/voice/state.py | 84 +++++++++++++++++++++++++++++ requirements/voice.txt | 1 + 7 files changed, 223 insertions(+), 31 deletions(-) diff --git a/discord/client.py b/discord/client.py index d86b14077a..97c5f6d557 100644 --- a/discord/client.py +++ b/discord/client.py @@ -50,7 +50,7 @@ from .invite import Invite from .iterators import EntitlementIterator, GuildIterator from .mentions import AllowedMentions -from .monetization import SKU, Entitlement +from .monetization import SKU from .object import Object from .soundboard import SoundboardSound from .stage_instance import StageInstance diff --git a/discord/gateway.py b/discord/gateway.py index 2160818bda..b5138c0813 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -41,7 +41,6 @@ from . import utils from .activity import BaseActivity -from .enums import SpeakingState from .errors import ConnectionClosed, InvalidArgument if TYPE_CHECKING: @@ -55,8 +54,6 @@ __all__ = ( "DiscordWebSocket", "KeepAliveHandler", - "VoiceKeepAliveHandler", - "DiscordVoiceWebSocket", "ReconnectWebSocket", ) @@ -228,31 +225,6 @@ def ack(self) -> None: _log.warning(self.behind_msg, self.shard_id, self.latency) -class VoiceKeepAliveHandler(KeepAliveHandler): - if TYPE_CHECKING: - ws: DiscordVoiceWebSocket - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.recent_ack_latencies = deque(maxlen=20) - self.msg = "Keeping shard ID %s voice websocket alive with timestamp %s." 
- self.block_msg = "Shard ID %s voice heartbeat blocked for more than %s seconds" - self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind" - - def get_payload(self): - return { - "op": self.ws.HEARTBEAT, - "d": {"t": int(time.time() * 1000), "seq_ack": self.ws.seq_ack}, - } - - def ack(self): - ack_time = time.perf_counter() - self._last_ack = ack_time - self._last_recv = ack_time - self.latency = ack_time - self._last_send - self.recent_ack_latencies.append(self.latency) - - class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): async def close(self, *, code: int = 4000, message: bytes = b"") -> bool: return await super().close(code=code, message=message) diff --git a/discord/voice/client.py b/discord/voice/client.py index 41e305e8b1..efd0c71600 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -241,6 +241,14 @@ def average_latency(self) -> float: ws = self.ws return float("inf") if not ws else ws.average_latency + @property + def privacy_code(self) -> str | None: + """Returns the current voice session's privacy code, only available if the call has upgraded to use the + DAVE protocol + """ + session = self._connection.dave_session + return session and session.voice_privacy_code + async def disconnect(self, *, force: bool = False) -> None: """|coro| @@ -288,6 +296,10 @@ def is_paused(self) -> bool: # audio related def _get_voice_packet(self, data: Any) -> bytes: + + session = self._connection.dave_session + packet = session.encrypt_opus(data) if session and session.ready else data + header = bytearray(12) # formulate rtp header @@ -298,7 +310,7 @@ def _get_voice_packet(self, data: Any) -> bytes: struct.pack_into(">I", header, 8, self.ssrc) encrypt_packet = getattr(self, f"_encrypt_{self.mode}") - return encrypt_packet(header, data) + return encrypt_packet(header, packet) # encryption methods diff --git a/discord/voice/enums.py b/discord/voice/enums.py index 2c164e2f7a..5564c690b5 100644 --- 
a/discord/voice/enums.py +++ b/discord/voice/enums.py @@ -42,6 +42,19 @@ class OpCodes(Enum): client_connect = 10 client_disconnect = 11 + # dave protocol stuff + dave_prepare_transition = 21 + dave_execute_transition = 22 + dave_transition_ready = 23 + dave_prepare_epoch = 24 + mls_external_sender_package = 25 + mls_key_package = 26 + mls_proposals = 27 + mls_commit_welcome = 28 + mls_commit_transition = 29 + mls_welcome = 30 + mls_invalid_commit_welcome = 31 + def __eq__(self, other: object) -> bool: if isinstance(other, int): return self.value == other diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index aab003a527..8ce80d71a3 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -36,6 +36,7 @@ import aiohttp +import davey from discord import utils from discord.enums import SpeakingState from discord.errors import ConnectionClosed @@ -117,6 +118,7 @@ def __init__( self.seq_ack: int = -1 self.state: VoiceConnectionState = state self.ssrc_map: dict[str, dict[str, Any]] = {} + self.known_users: dict[int, Any] = {} if hook: self._hook = hook or state.ws_hook # type: ignore @@ -137,9 +139,22 @@ def session_id(self) -> str | None: def session_id(self, value: str | None) -> None: self.state.session_id = value + @property + def dave_session(self) -> davey.DaveSession | None: + return self.state.dave_session + + @property + def self_id(self) -> int: + return self._connection.self_id + async def _hook(self, *args: Any) -> Any: pass + async def send_as_bytes(self, op: int, data: bytes) -> None: + packet = bytes(op) + data + _log.debug("Sending voice websocket binary frame: op: %s data: %s", op, str(data)) + await self.ws.send_bytes(packet) + async def send_as_json(self, data: Any) -> None: _log.debug("Sending voice websocket frame: %s.", data) await self.ws.send_str(utils._to_json(data)) @@ -163,6 +178,7 @@ async def received_message(self, msg: Any, /): op = msg["op"] data = msg.get("d", {}) # this key should ALWAYS be given, but guard 
anyways self.seq_ack = msg.get("seq", self.seq_ack) # keep the seq_ack updated + state = self.state if op == OpCodes.ready: await self.ready(data) @@ -179,8 +195,10 @@ async def received_message(self, msg: Any, /): "successfully RESUMED.", ) elif op == OpCodes.session_description: - self.state.mode = data["mode"] + state.mode = data["mode"] + state.dave_protocol_version = data["dave_protocol_version"] await self.load_secret_key(data) + await state.reinit_dave_session() elif op == OpCodes.hello: interval = data["heartbeat_interval"] / 1000.0 self._keep_alive = KeepAliveHandler( @@ -188,9 +206,88 @@ async def received_message(self, msg: Any, /): interval=min(interval, 5), ) self._keep_alive.start() + elif self.dave_session: + if op == OpCodes.dave_prepare_transition: + _log.info("Preparing to upgrade to a DAVE connection for channel %s", state.channel_id) + state.dave_pending_transition = data + + transition_id = data["transition_id"] + + if transition_id == 0: + await state.execute_dave_transition(data["transition_id"]) + else: + if data["protocol_version"] == 0: + self.dave_session.set_passthrough_mode(True, 120) + await self.send_dave_transition_ready(transition_id) + elif op == OpCodes.dave_execute_transition: + _log.info("Upgrading to DAVE connection for channel %s", state.channel_id) + await state.execute_dave_transition(data["transition_id"]) + elif op == OpCodes.dave_prepare_epoch: + epoch = data["epoch"] + _log.debug("Preparing for DAVE epoch in channel %s: %s", state.channel_id, epoch) + # if epoch is 1 then a new MLS group is to be created for the proto version + if epoch == 1: + state.dave_protocol_version = data["protocol_version"] + await state.reinit_dave_session() + else: + _log.debug("Unhandled op code: %s with data %s", op, data) await utils.maybe_coroutine(self._hook, self, msg) + async def received_binary_message(self, msg: bytes) -> None: + self.seq_ack = struct.unpack_from(">H", msg, 0)[0] + op = msg[2] + _log.debug("Voice websocket binary frame 
received: %d bytes, seq: %s, op: %s", len(msg), self.seq_ack, op) + + state = self.state + + if not self.dave_session: + return + + if op == OpCodes.mls_external_sender_package: + self.dave_session.set_external_sender(msg[3:]) + elif op == OpCodes.mls_proposals: + op_type = msg[3] + result = self.dave_session.process_proposals( + davey.ProposalsOperationType.append if op_type == 0 else davey.ProposalsOperationType.revoke, + msg[4:], + ) + + if isinstance(result, davey.CommitWelcome): + await self.send_as_bytes( + OpCodes.mls_key_package.value, + (result.commit + result.welcome) if result.welcome else result.commit, + ) + _log.debug("Processed MLS proposals for current dave session") + elif op == OpCodes.mls_commit_transition: + transt_id = struct.unpack_from(">H", msg, 3)[0] + try: + self.dave_session.process_commit(msg[5:]) + if transt_id != 0: + state.dave_pending_transition = { + "transition_id": transt_id, + "protocol_version": state.dave_protocol_version, + } + await self.send_dave_transition_ready(transt_id) + _log.debug("Processed MLS commit for transition %s", transt_id) + except Exception as exc: + _log.debug("An exception ocurred while processing a MLS commit, this should be safe to ignore: %s", exc) + await state.recover_dave_from_invalid_commit(transt_id) + elif op == OpCodes.mls_welcome: + transt_id = struct.unpack_from(">H", msg, 3)[0] + try: + self.dave_session.process_welcome(msg[5:]) + if transt_id != 0: + state.dave_pending_transition = { + "transition_id": transt_id, + "protocol_version": state.dave_protocol_version, + } + await self.send_dave_transition_ready(transt_id) + _log.debug("Processed MLS welcome for transition %s", transt_id) + except Exception as exc: + _log.debug("An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s", exc) + await state.recover_dave_from_invalid_commit(transt_id) + async def ready(self, data: dict[str, Any]) -> None: state = self.state @@ -232,6 +329,7 @@ async def 
select_protocol(self, ip: str, port: int, mode: str) -> None: "port": port, "mode": mode, }, + "dave_protocol_version": self.state.dave_protocol_version, }, } await self.send_as_json(payload) @@ -292,6 +390,8 @@ async def poll_event(self) -> None: if msg.type is aiohttp.WSMsgType.TEXT: await self.received_message(utils._from_json(msg.data)) + elif msg.type is aiohttp.WSMsgType.BINARY: + await self.received_binary_message(msg.data) elif msg.type is aiohttp.WSMsgType.ERROR: _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data @@ -355,6 +455,16 @@ async def identify(self) -> None: "user_id": str(state.user.id), "session_id": self.session_id, "token": self.token, + "max_dave_protocol_version": self.state.max_dave_proto_version, + }, + } + await self.send_as_json(payload) + + async def send_dave_transition_ready(self, transition_id: int) -> None: + payload = { + "op": int(OpCodes.dave_transition_ready), + "d": { + "transition_id": transition_id, }, } await self.send_as_json(payload) diff --git a/discord/voice/state.py b/discord/voice/state.py index 81e9d1734f..aa717f0445 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -42,6 +42,8 @@ from discord.object import Object from discord.sinks import RawData, Sink +import davey + from .enums import ConnectionFlowState, OpCodes from .gateway import VoiceWebSocket @@ -60,6 +62,7 @@ SocketReaderCallback = Callable[[bytes], Any] _log = logging.getLogger(__name__) _recv_log = logging.getLogger("discord.voice.receiver") +DAVE_PROTOCOL_VERSION = davey.DAVE_PROTOCOL_VERSION class SocketReader(threading.Thread): @@ -291,10 +294,24 @@ def __init__( ] = [] self._sink_dispatch_task_set: set[asyncio.Task[Any]] = set() + if not self._connection.self_id: + raise RuntimeError("client self ID is not set") + if not self.channel_id: + raise RuntimeError("client channel being connected to is not set") + + self.dave_session: davey.DaveSession | None = None + self.dave_protocol_version: int = 
0 + self.dave_pending_transition: dict[str, int] | None = None + self.downgraded_dave = False + @property def sinks(self) -> list[Sink]: return list(self._sinks.values()) + @property + def max_dave_proto_version(self) -> int: + return davey.DAVE_PROTOCOL_VERSION + def start_record_socket(self) -> None: try: self._voice_recv_socket.start() @@ -1011,6 +1028,7 @@ async def _voice_disconnect(self) -> None: await self.client.channel.guild.change_voice_state(channel=None) self._expecting_disconnect = True self._disconnected.clear() + self.ws._identified = False async def _connect_websocket(self, resume: bool) -> VoiceWebSocket: seq_ack = -1 @@ -1187,3 +1205,69 @@ async def _move_to(self, channel: abc.Snowflake) -> None: def _update_voice_channel(self, channel_id: int | None) -> None: self.client.channel = channel_id and self.guild.get_channel(channel_id) # type: ignore + + async def reinit_dave_session(self) -> None: + session = self.dave_session + + if self.dave_protocol_version > 0: + if session: + session.reinit(self.dave_protocol_version, self.user.id, self.channel_id) + else: + session = self.dave_session = davey.DaveSession( + self.dave_protocol_version, + self.user.id, + self.channel_id, + ) + + await self.ws.send_as_bytes( + int(OpCodes.mls_key_package), + session.get_serialized_key_package(), + ) + elif session: + session.reset() + session.set_passthrough_mode(True, 10) + + async def recover_dave_from_invalid_commit(self, transition: int) -> None: + payload = { + "op": int(OpCodes.mls_invalid_commit_welcome), + "d": {"transition_id": transition}, + } + await self.ws.send_as_json(payload) + await self.reinit_dave_session() + + async def execute_dave_transition(self, transition: int) -> None: + _log.debug("Executing DAVE transition with id %s", transition) + + if not self.dave_pending_transition: + _log.warning( + "Attempted to execute a transition without having a pending transition for id %s, " + "this is a Discord bug.", + transition, + ) + return + + 
pending_transition = self.dave_pending_transition["transition_id"] + pending_proto = self.dave_pending_transition["protocol_version"] + + session = self.dave_session + + if transition == pending_transition: + old_version = self.dave_protocol_version + self.dave_protocol_version = pending_proto + + if old_version != self.dave_protocol_version and self.dave_protocol_version == 0: + _log.warning("DAVE was downgraded, voice client non-e2ee session has been deprecated since 2.7") + self.downgraded_dave = True + elif transition > 0 and self.downgraded_dave: + self.downgraded_dave = False + if session: + session.set_passthrough_mode(True, 10) + _log.info("Upgraded voice session to use DAVE") + else: + _log.debug( + "Received an execute transition id %s when expected was %s, ignoring", + transition, + pending_proto, + ) + + self.dave_pending_transition = None diff --git a/requirements/voice.txt b/requirements/voice.txt index 6382712eac..71df4cdb2d 100644 --- a/requirements/voice.txt +++ b/requirements/voice.txt @@ -1 +1,2 @@ PyNaCl>=1.3.0,<1.6 +davey==0.1.0rc3 From 74d659021b1fb7130c4bce820b6563a2d981ca5d Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Thu, 11 Sep 2025 22:56:45 +0200 Subject: [PATCH 35/40] yay voice recv --- discord/__init__.py | 3 +- discord/_voice_aliases.py | 71 ++ discord/opus.py | 172 ++- discord/player.py | 433 ++++--- discord/sinks/__init__.py | 1 - discord/sinks/core.py | 1012 ++++++----------- discord/voice/__init__.py | 10 +- discord/voice/_types.py | 4 + discord/voice/client.py | 266 ++--- discord/voice/enums.py | 5 +- discord/voice/gateway.py | 46 +- discord/voice/packets/__init__.py | 47 + discord/voice/packets/core.py | 89 ++ discord/voice/packets/rtp.py | 300 +++++ discord/voice/receive/__init__.py | 2 + discord/voice/receive/reader.py | 466 ++++++++ discord/voice/receive/router.py | 222 ++++ discord/voice/state.py | 344 +----- discord/voice/utils/__init__.py | 0 discord/voice/utils/buffer.py | 204 
++++ discord/voice/utils/multidataevent.py | 78 ++ .../enums.py => voice/utils/wrapped.py} | 12 +- docs/api/voice.rst | 4 +- 23 files changed, 2374 insertions(+), 1417 deletions(-) create mode 100644 discord/_voice_aliases.py create mode 100644 discord/voice/packets/__init__.py create mode 100644 discord/voice/packets/core.py create mode 100644 discord/voice/packets/rtp.py create mode 100644 discord/voice/receive/__init__.py create mode 100644 discord/voice/receive/reader.py create mode 100644 discord/voice/receive/router.py create mode 100644 discord/voice/utils/__init__.py create mode 100644 discord/voice/utils/buffer.py create mode 100644 discord/voice/utils/multidataevent.py rename discord/{sinks/enums.py => voice/utils/wrapped.py} (85%) diff --git a/discord/__init__.py b/discord/__init__.py index b0094b6ec9..49389bc245 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -72,9 +72,10 @@ from .template import * from .threads import * from .user import * -from .voice import * from .webhook import * from .welcome_screen import * from .widget import * +from ._voice_aliases import * + logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/discord/_voice_aliases.py b/discord/_voice_aliases.py new file mode 100644 index 0000000000..034ae1849a --- /dev/null +++ b/discord/_voice_aliases.py @@ -0,0 +1,71 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or 
substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .utils import warn_deprecated + +""" +since discord.voice raises an error when importing it without having the +required package (ie davey) installed, we can't import it in __init__ because +that would break the whole library, that is why this file is here. + +the error would still be raised, but at least here we have more freedom on how we are typing it +""" + +__all__ = ("VoiceProtocol", "VoiceClient") + + +if TYPE_CHECKING: + from typing_extensions import deprecated + + from discord.voice import VoiceProtocolC, VoiceClientC + + @deprecated( + "discord.VoiceClient is deprecated in favour " + "of discord.voice.VoiceClient since 2.7 and " + "will be removed in 3.0", + ) + def VoiceClient(client, channel) -> VoiceClientC: + ... + + @deprecated( + "discord.VoiceProtocol is deprecated in favour " + "of discord.voice.VoiceProtocol since 2.7 and " + "will be removed in 3.0", + ) + def VoiceProtocol(client, channel) -> VoiceProtocolC: + ... 
+else: + @warn_deprecated("discord.VoiceClient", "discord.voice.VoiceClient", "2.7", "3.0") + def VoiceClient(client, channel): + from discord.voice import VoiceClient + return VoiceClient(client, channel) + + @warn_deprecated("discord.VoiceProtocol", "discord.voice.VoiceProtocol", "2.7", "3.0") + def VoiceProtocol(client, channel): + from discord.voice import VoiceProtocol + return VoiceProtocol(client, channel) diff --git a/discord/opus.py b/discord/opus.py index cd096f56d9..f3a3ed8986 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -36,11 +36,23 @@ import sys from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar +from discord.voice.packets.rtp import FakePacket +from discord.voice.utils.wrapped import gap_wrapped, add_wrapped +from discord.voice.utils.buffer import JitterBuffer + +import davey + from .errors import DiscordException from .sinks import RawData if TYPE_CHECKING: + from discord.user import User + from discord.member import Member from discord.voice.client import VoiceClient + from discord.voice.receive.router import PacketRouter + from discord.voice.packets.core import Packet + from discord.voice.packets import VoiceData + from discord.sinks.core import Sink T = TypeVar("T") APPLICATION_CTL = Literal["audio", "voip", "lowdelay"] @@ -102,6 +114,12 @@ class DecoderStruct(ctypes.Structure): # Error codes OK = 0 BAD_ARG = -1 +BUFF_TOO_SMALL = -2 +INTERNAL_ERROR = -3 +INVALID_PACKET = -4 +UNIMPLEMENTED = -5 +INVALID_STATE = -6 +ALLOC_FAIL = -7 # Encoder CTLs application_ctl: ApplicationCtl = { @@ -449,12 +467,15 @@ def set_fec(self, enabled: bool = True) -> None: def set_expected_packet_loss_percent(self, percentage: float) -> None: _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore - def encode(self, pcm: bytes, frame_size: int) -> bytes: + def encode(self, pcm: bytes, frame_size: int | None = None) -> bytes: max_data_bytes = len(pcm) # bytes can be used to reference 
pointer pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore data = (ctypes.c_char * max_data_bytes)() + if frame_size is None: + frame_size = self.FRAME_SIZE + ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) # array can be initialized with bytes but mypy doesn't know @@ -462,7 +483,7 @@ def encode(self, pcm: bytes, frame_size: int) -> bytes: class Decoder(_OpusStruct): - def __init__(self): + def __init__(self) -> None: _OpusStruct.get_opus_version() self._state = self._create_state() @@ -521,18 +542,18 @@ def _get_last_packet_duration(self): _lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret)) return ret.value - def decode(self, data, *, fec=False): + def decode(self, data: bytes | None, *, fec: bool = True): if data is None and fec: raise OpusError( message="Invalid arguments: FEC cannot be used with null data" ) + channel_count = self.CHANNELS + if data is None: frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME - channel_count = self.CHANNELS else: frames = self.packet_get_nb_frames(data) - channel_count = self.CHANNELS samples_per_frame = self.packet_get_samples_per_frame(data) frame_size = frames * samples_per_frame @@ -542,9 +563,150 @@ def decode(self, data, *, fec=False): # )() pcm = (ctypes.c_int16 * (frame_size * channel_count))() pcm_ptr = ctypes.cast(pcm, c_int16_ptr) + pcm_ptr = ctypes.cast( + pcm, + c_int16_ptr, + ) ret = _lib.opus_decode( self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec ) return array.array("h", pcm[: ret * channel_count]).tobytes() + + +class PacketDecoder: + def __init__(self, router: PacketRouter, ssrc: int) -> None: + self.router: PacketRouter = router + self.ssrc: int = ssrc + + self._decoder: Decoder | None = None if self.sink.is_opus() else Decoder() + self._buffer: JitterBuffer = JitterBuffer() + self._cached_id: int | None = None + + self._last_seq: int = -1 + self._last_ts: int = -1 + + @property + def sink(self) -> Sink: 
+ return self.router.sink + + def _get_user(self, user_id: int) -> User | Member | None: + vc: VoiceClient = self.sink.client # type: ignore + return vc.guild.get_member(user_id) or vc.client.get_user(user_id) + + def _get_cached_member(self) -> User | Member | None: + return self._get_user(self._cached_id) if self._cached_id else None + + def _flag_ready_state(self) -> None: + if self._buffer.peek(): + self.router.waiter.register(self) + else: + self.router.waiter.unregister(self) + + def push_packet(self, packet: Packet) -> None: + self._buffer.push(packet) + self._flag_ready_state() + + def pop_data(self, *, timeout: float = 0) -> VoiceData | None: + packet = self._get_next_packet(timeout) + self._flag_ready_state() + + if packet is None: + return None + return self._process_packet(packet) + + def set_user_id(self, user_id: int) -> None: + self._cached_id = user_id + + def reset(self) -> None: + self._buffer.reset() + self._decoder = None if self.sink.is_opus() else Decoder() + self._last_seq = self._last_ts = -1 + self._flag_ready_state() + + def destroy(self) -> None: + self._buffer.reset() + self._decoder = None + self._flag_ready_state() + + def _get_next_packet(self, timeout: float) -> Packet | None: + packet = self._buffer.pop(timeout=timeout) + + if packet is None: + if self._buffer: + packets = self._buffer.flush() + if any(packets[1:]): + _log.warning( + "%s packets were lost being flushed in decoder-%s", + len(packets) - 1, + self.ssrc, + ) + return packets[0] + return + elif not packet: + packet = self._make_fakepacket() + return packet + + def _make_fakepacket(self) -> FakePacket: + seq = add_wrapped(self._last_seq, 1) + ts = add_wrapped(self._last_ts, Decoder.SAMPLES_PER_FRAME, wrap=2**32) + return FakePacket(self.ssrc, seq, ts) + + def _process_packet(self, packet: Packet) -> VoiceData: + from discord.object import Object + + pcm = None + + if not self.sink.is_opus(): + packet, pcm = self._decode_packet(packet) + + member = 
self._get_cached_member() + + if member is None: + self._cached_id = self.sink.client._connection._get_id_from_ssrc(self.ssrc) + member = self._get_cached_member() + + # yet still none, use Object + if member is None and self._cached_id: + member = Object(id=self._cached_id) + + data = VoiceData(packet, member, pcm=pcm) + self._last_seq = packet.sequence + self._last_ts = packet.timestamp + return data + + def _decode_packet(self, packet: Packet) -> tuple[Packet, bytes]: + assert self._decoder is not None + assert self.sink.client + + user_id: int | None = self._cached_id + dave: davey.DaveSession | None = self.sink.client._connection.dave_session + in_dave = dave is not None + + # personally, the best variable + other_code = True + + if packet: + other_code = False + pcm = self._decoder.decode(packet.decrypted_data, fec=False) + + if other_code: + next_packet = self._buffer.peek_next() + + if next_packet is not None: + nextdata: bytes = next_packet.decrypted_data # type: ignore + + _log.debug( + "Generating fec packet: fake=%s, fec=%s", + packet.sequence, + next_packet.sequence, + ) + pcm = self._decoder.decode(nextdata, fec=True) + else: + pcm = self._decoder.decode(None, fec=False) + + if user_id is not None and in_dave and dave.can_passthrough(user_id): + pcm = dave.decrypt(user_id, davey.MediaType.audio, pcm) + + return packet, pcm diff --git a/discord/player.py b/discord/player.py index 3bf74a0a8f..9d67c3f0bb 100644 --- a/discord/player.py +++ b/discord/player.py @@ -38,6 +38,7 @@ import time from math import floor from typing import IO, TYPE_CHECKING, Any, Callable, Generic, TypeVar +import warnings from .enums import SpeakingState from .errors import ClientException @@ -47,6 +48,8 @@ from .utils import MISSING if TYPE_CHECKING: + from typing_extensions import Self + from .voice import VoiceClient @@ -97,7 +100,7 @@ def read(self) -> bytes: per frame (20ms worth of audio). 
Returns - ------- + -------- :class:`bytes` A bytes like object that represents the PCM or Opus data. """ @@ -122,7 +125,7 @@ class PCMAudio(AudioSource): """Represents raw 16-bit 48KHz stereo PCM audio source. Attributes - ---------- + ----------- stream: :term:`py:file object` A file-like object that reads byte data representing raw PCM. """ @@ -133,7 +136,7 @@ def __init__(self, stream: io.BufferedIOBase) -> None: def read(self) -> bytes: ret = self.stream.read(OpusEncoder.FRAME_SIZE) if len(ret) != OpusEncoder.FRAME_SIZE: - return b"" + return b'' return ret @@ -146,107 +149,128 @@ class FFmpegAudio(AudioSource): .. versionadded:: 1.3 """ + BLOCKSIZE: int = io.DEFAULT_BUFFER_SIZE + def __init__( self, source: str | io.BufferedIOBase, *, - executable: str = "ffmpeg", + executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any, ): - piping = subprocess_kwargs.get("stdin") == subprocess.PIPE - if piping and isinstance(source, str): - raise TypeError( - "parameter conflict: 'source' parameter cannot be a string when piping" - " to stdin" - ) + piping_stdin = subprocess_kwargs.get('stdin') == subprocess.PIPE + if piping_stdin and isinstance(source, str): + raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin") + + stderr: IO[bytes] | None = subprocess_kwargs.pop('stderr', None) + + if stderr == subprocess.PIPE: + warnings.warn('Passing subprocess.PIPE does nothing', DeprecationWarning, stacklevel=3) + stderr = None + + piping_stderr = False + if stderr is not None: + try: + stderr.fileno() + except Exception: + piping_stderr = True args = [executable, *args] - kwargs = {"stdout": subprocess.PIPE} + kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.PIPE if piping_stderr else stderr} kwargs.update(subprocess_kwargs) - self._process: subprocess.Popen = self._spawn_process(args, **kwargs) - self._stdout: IO[bytes] = self._process.stdout # type: ignore + # Ensure attribute is assigned even in the case of errors + 
self._process: subprocess.Popen = MISSING + self._process = self._spawn_process(args, **kwargs) + self._stdout: IO[bytes] = self._process.stdout # type: ignore # process stdout is explicitly set self._stdin: IO[bytes] | None = None - self._pipe_thread: threading.Thread | None = None + self._stderr: IO[bytes] | None = None + self._pipe_writer_thread: threading.Thread | None = None + self._pipe_reader_thread: threading.Thread | None = None - if piping: - n = f"popen-stdin-writer:{id(self):#x}" + if piping_stdin: + n = f"popen-stdin-writer:pid-{self._process.pid}" self._stdin = self._process.stdin - self._pipe_thread = threading.Thread( - target=self._pipe_writer, args=(source,), daemon=True, name=n - ) - self._pipe_thread.start() + self._pipe_writer_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n) + self._pipe_writer_thread.start() + + if piping_stderr: + n = f"popen-stderr-reader:pid-{self._process.pid}" + self._stderr = self._process.stderr + self._pipe_reader_thread = threading.Thread(target=self._pipe_reader, args=(stderr,), daemon=True, name=n) + self._pipe_reader_thread.start() def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: + _log.debug("Spawning ffmpeg process with command: %s", args) + process = None try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs - ) + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) except FileNotFoundError: executable = args.partition(" ")[0] if isinstance(args, str) else args[0] - raise ClientException(f"{executable} was not found.") from None + raise ClientException(executable + " was not found.") from None except subprocess.SubprocessError as exc: - raise ClientException( - f"Popen failed: {exc.__class__.__name__}: {exc}" - ) from exc + raise ClientException(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc else: return process def _kill_process(self) -> None: - proc = 
self._process + # this function gets called in __del__ so instance attributes might not even exist + proc = getattr(self, "_process", MISSING) if proc is MISSING: return - _log.info("Preparing to terminate ffmpeg process %s.", proc.pid) + _log.debug("Preparing to terminate ffmpeg process %s.", proc.pid) try: proc.kill() except Exception: - _log.exception( - "Ignoring error attempting to kill ffmpeg process %s", proc.pid - ) + _log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid) if proc.poll() is None: - _log.info( - "ffmpeg process %s has not terminated. Waiting to terminate...", - proc.pid, - ) + _log.info("ffmpeg process %s has not terminated. Waiting to terminate...", proc.pid) proc.communicate() - _log.info( - "ffmpeg process %s should have terminated with a return code of %s.", - proc.pid, - proc.returncode, - ) + _log.info("ffmpeg process %s should have terminated with a return code of %s.", proc.pid, proc.returncode) else: - _log.info( - "ffmpeg process %s successfully terminated with return code of %s.", - proc.pid, - proc.returncode, - ) + _log.info("ffmpeg process %s successfully terminated with return code of %s.", proc.pid, proc.returncode) def _pipe_writer(self, source: io.BufferedIOBase) -> None: while self._process: - # arbitrarily large read size - data = source.read(8192) + data = source.read(self.BLOCKSIZE) if not data: - self._stdin.close() + if self._stdin is not None: + self._stdin.close() return try: - self._stdin.write(data) + if self._stdin is not None: + self._stdin.write(data) except Exception: - _log.debug( - "Write error for %s, this is probably not a problem", - self, - exc_info=True, - ) + _log.debug('Write error for %s, this is probably not a problem', self, exc_info=True) # at this point the source data is either exhausted or the process is fubar self._process.terminate() return + def _pipe_reader(self, dest: IO[bytes]) -> None: + while self._process: + if self._stderr is None: + return + try: + data: bytes 
= self._stderr.read(self.BLOCKSIZE) + except Exception: + _log.debug("Read error for %s, this is probably not a problem", self, exc_info=True) + return + if data is None: + return + try: + dest.write(data) + except Exception: + _log.exception("Write error for %s", self) + self._stderr.close() + return + def cleanup(self) -> None: self._kill_process() - self._process = self._stdout = self._stdin = MISSING + self._process = self._stdout = self._stdin = self._stderr = MISSING class FFmpegPCMAudio(FFmpegAudio): @@ -260,26 +284,31 @@ class FFmpegPCMAudio(FFmpegAudio): variable in order for this to work. Parameters - ---------- + ------------ source: Union[:class:`str`, :class:`io.BufferedIOBase`] The input that ffmpeg will take and convert to PCM bytes. If ``pipe`` is ``True`` then this is a file-like object that is passed to the stdin of ffmpeg. executable: :class:`str` The executable name (and path) to use. Defaults to ``ffmpeg``. + + .. warning:: + + Since this class spawns a subprocess, care should be taken to not + pass in an arbitrary executable name when using this parameter. + pipe: :class:`bool` If ``True``, denotes that ``source`` parameter will be passed to the stdin of ffmpeg. Defaults to ``False``. stderr: Optional[:term:`py:file object`] A file-like object to pass to the Popen constructor. - Could also be an instance of ``subprocess.PIPE``. before_options: Optional[:class:`str`] Extra command line arguments to pass to ffmpeg before the ``-i`` flag. options: Optional[:class:`str`] Extra command line arguments to pass to ffmpeg after the ``-i`` flag. Raises - ------ + -------- ClientException The subprocess failed to be created. 
""" @@ -288,24 +317,26 @@ def __init__( self, source: str | io.BufferedIOBase, *, - executable: str = "ffmpeg", + executable: str = 'ffmpeg', pipe: bool = False, - stderr: IO[str] | None = None, + stderr: IO[bytes] | None = None, before_options: str | None = None, options: str | None = None, ) -> None: args = [] - subprocess_kwargs = { - "stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, - "stderr": stderr, - } + subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr} if isinstance(before_options, str): args.extend(shlex.split(before_options)) args.append("-i") args.append("-" if pipe else source) - args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning")) + + args.extend(("-f", "s16le", + "-ar", "48000", + "-ac", "2", + "-loglevel", "warning", + "-blocksize", str(self.BLOCKSIZE))) if isinstance(options, str): args.extend(shlex.split(options)) @@ -346,7 +377,7 @@ class FFmpegOpusAudio(FFmpegAudio): variable in order for this to work. Parameters - ---------- + ------------ source: Union[:class:`str`, :class:`io.BufferedIOBase`] The input that ffmpeg will take and convert to Opus bytes. If ``pipe`` is ``True`` then this is a file-like object that is @@ -357,9 +388,8 @@ class FFmpegOpusAudio(FFmpegAudio): The codec to use to encode the audio data. Normally this would be just ``libopus``, but is used by :meth:`FFmpegOpusAudio.from_probe` to opportunistically skip pointlessly re-encoding Opus audio data by passing - ``copy`` as the codec value. Any values other than ``copy``, or - ``libopus`` will be considered ``libopus``. ``opus`` will also be considered - ``libopus`` since the ``opus`` encoder is still in development. Defaults to ``libopus``. + ``copy`` as the codec value. Any values other than ``copy``, ``opus``, or + ``libopus`` will be considered ``libopus``. Defaults to ``libopus``. .. warning:: @@ -374,14 +404,13 @@ class FFmpegOpusAudio(FFmpegAudio): to the stdin of ffmpeg. 
Defaults to ``False``. stderr: Optional[:term:`py:file object`] A file-like object to pass to the Popen constructor. - Could also be an instance of ``subprocess.PIPE``. before_options: Optional[:class:`str`] Extra command line arguments to pass to ffmpeg before the ``-i`` flag. options: Optional[:class:`str`] Extra command line arguments to pass to ffmpeg after the ``-i`` flag. Raises - ------ + -------- ClientException The subprocess failed to be created. """ @@ -390,19 +419,16 @@ def __init__( self, source: str | io.BufferedIOBase, *, - bitrate: int = 128, + bitrate: int | None = None, codec: str | None = None, - executable: str = "ffmpeg", - pipe=False, - stderr=None, - before_options=None, - options=None, + executable: str = 'ffmpeg', + pipe: bool = False, + stderr: IO[bytes] | None = None, + before_options: str | None = None, + options: str | None = None, ) -> None: args = [] - subprocess_kwargs = { - "stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, - "stderr": stderr, - } + subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr} if isinstance(before_options, str): args.extend(shlex.split(before_options)) @@ -410,35 +436,19 @@ def __init__( args.append("-i") args.append("-" if pipe else source) - # use "libopus" when "opus" is specified since the "opus" encoder is incomplete - # link to ffmpeg docs: https://www.ffmpeg.org/ffmpeg-codecs.html#opus - codec = "copy" if codec == "copy" else "libopus" - - args.extend( - ( - "-map_metadata", - "-1", - "-f", - "opus", - "-c:a", - codec, - "-loglevel", - "warning", - ) - ) - - # only pass in bitrate, sample rate, channels arguments when actually encoding to avoid ffmpeg warnings - if codec != "copy": - args.extend( - ( - "-ar", - "48000", - "-ac", - "2", - "-b:a", - f"{bitrate}k", - ) - ) + codec = "copy" if codec in ("opus", "libopus", "copy") else "libopus" + bitrate = bitrate if bitrate is not None else 128 + + args.extend(("-map_metadata", "-1", + "-f", "opus", + 
"-c:a", codec, + "-ar", "48000", + "-ac", "2", + "-b:a", f"{bitrate}k", + "-loglevel", "warning", + "-fec", "true", + "-packet_loss", "15", + "-blocksize", str(self.BLOCKSIZE))) if isinstance(options, str): args.extend(shlex.split(options)) @@ -450,45 +460,19 @@ def __init__( @classmethod async def from_probe( - cls: type[FT], + cls, source: str, *, method: str | Callable[[str, str], tuple[str | None, int | None]] | None = None, **kwargs: Any, - ) -> FT: - """|coro| + ) -> Self: + r"""|coro| A factory method that creates a :class:`FFmpegOpusAudio` after probing the input source for audio codec and bitrate information. - Parameters - ---------- - source - Identical to the ``source`` parameter for the constructor. - method: Optional[Union[:class:`str`, Callable[:class:`str`, :class:`str`]]] - The probing method used to determine bitrate and codec information. As a string, valid - values are ``native`` to use ffprobe (or avprobe) and ``fallback`` to use ffmpeg - (or avconv). As a callable, it must take two string arguments, ``source`` and - ``executable``. Both parameters are the same values passed to this factory function. - ``executable`` will default to ``ffmpeg`` if not provided as a keyword argument. - kwargs - The remaining parameters to be passed to the :class:`FFmpegOpusAudio` constructor, - excluding ``bitrate`` and ``codec``. - - Returns - ------- - :class:`FFmpegOpusAudio` - An instance of this class. - - Raises - ------ - AttributeError - Invalid probe method, must be ``'native'`` or ``'fallback'``. - TypeError - Invalid value for ``probe`` parameter, must be :class:`str` or a callable. 
- Examples - -------- + ---------- Use this function to create an :class:`FFmpegOpusAudio` instance instead of the constructor: :: @@ -509,13 +493,37 @@ def custom_probe(source, executable): source = await discord.FFmpegOpusAudio.from_probe("song.webm", method=custom_probe) voice_client.play(source) + + Parameters + ------------ + source + Identical to the ``source`` parameter for the constructor. + method: Optional[Union[:class:`str`, Callable[:class:`str`, :class:`str`]]] + The probing method used to determine bitrate and codec information. As a string, valid + values are ``native`` to use ffprobe (or avprobe) and ``fallback`` to use ffmpeg + (or avconv). As a callable, it must take two string arguments, ``source`` and + ``executable``. Both parameters are the same values passed to this factory function. + ``executable`` will default to ``ffmpeg`` if not provided as a keyword argument. + \*\*kwargs + The remaining parameters to be passed to the :class:`FFmpegOpusAudio` constructor, + excluding ``bitrate`` and ``codec``. + + Raises + -------- + AttributeError + Invalid probe method, must be ``'native'`` or ``'fallback'``. + TypeError + Invalid value for ``probe`` parameter, must be :class:`str` or a callable. + + Returns + -------- + :class:`FFmpegOpusAudio` + An instance of this class. """ - executable = kwargs.get("executable") + executable = kwargs.get('executable') codec, bitrate = await cls.probe(source, method=method, executable=executable) - # only re-encode if the source isn't already opus, else directly copy the source audio stream - codec = "copy" if codec in ("opus", "libopus") else "libopus" - return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore + return cls(source, bitrate=bitrate, codec=codec, **kwargs) @classmethod async def probe( @@ -530,7 +538,7 @@ async def probe( Probes the input source for bitrate and codec information. 
Parameters - ---------- + ------------ source Identical to the ``source`` parameter for :class:`FFmpegOpusAudio`. method @@ -538,17 +546,17 @@ async def probe( executable: :class:`str` Identical to the ``executable`` parameter for :class:`FFmpegOpusAudio`. - Returns - ------- - Optional[Tuple[Optional[:class:`str`], Optional[:class:`int`]]] - A 2-tuple with the codec and bitrate of the input source. - Raises - ------ + -------- AttributeError Invalid probe method, must be ``'native'`` or ``'fallback'``. TypeError Invalid value for ``probe`` parameter, must be :class:`str` or a callable. + + Returns + --------- + Optional[Tuple[Optional[:class:`str`], :class:`int`]] + A 2-tuple with the codec and bitrate of the input source. """ method = method or "native" @@ -573,82 +581,60 @@ async def probe( ) codec = bitrate = None - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: - codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore - except Exception: + codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException: if not fallback: _log.exception("Probe '%s' using '%s' failed", method, executable) - return # type: ignore + return None, None - _log.exception( - "Probe '%s' using '%s' failed, trying fallback", method, executable - ) + _log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) try: - codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore - except Exception: + codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException: _log.exception("Fallback probe using '%s' failed", executable) else: - _log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate) + _log.debug("Fallback probe found codec=%s, bitrate=%s", 
codec, bitrate) else: - _log.info("Probe found codec=%s, bitrate=%s", codec, bitrate) - finally: - return codec, bitrate + _log.debug("Probe found codec=%s, bitrate=%s", codec, bitrate) + + return codec, bitrate @staticmethod - def _probe_codec_native( - source, executable: str = "ffmpeg" - ) -> tuple[str | None, int | None]: - exe = ( - f"{executable[:2]}probe" - if executable in {"ffmpeg", "avconv"} - else executable - ) - - args = [ - exe, - "-v", - "quiet", - "-print_format", - "json", - "-show_streams", - "-select_streams", - "a:0", - source, - ] + def _probe_codec_native(source, executable: str = "ffmpeg") -> tuple[str | None, int | None]: + exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable + args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] output = subprocess.check_output(args, timeout=20) codec = bitrate = None if output: data = json.loads(output) - streamdata = data["streams"][0] + streamdata = data['streams'][0] - codec = streamdata.get("codec_name") - bitrate = int(streamdata.get("bit_rate", 0)) + codec = streamdata.get('codec_name') + bitrate = int(streamdata.get('bit_rate', 0)) bitrate = max(round(bitrate / 1000), 512) return codec, bitrate @staticmethod - def _probe_codec_fallback( - source, executable: str = "ffmpeg" - ) -> tuple[str | None, int | None]: - args = [executable, "-hide_banner", "-i", source] - proc = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) + def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: + args = [executable, '-hide_banner', '-i', source] + proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, _ = proc.communicate(timeout=20) - output = out.decode("utf8") + output = out.decode('utf8') codec = bitrate = None - codec_match = re.search(r"Stream #0.*?Audio: 
(\w+)", output) + codec_match = re.search(r'Stream #0.*?Audio: (\w+)', output) if codec_match: codec = codec_match.group(1) - br_match = re.search(r"(\d+) [kK]b/s", output) + br_match = re.search(r'(\d+) [kK]b/s', output) if br_match: bitrate = max(int(br_match.group(1)), 512) @@ -668,7 +654,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): set to ``True``. Parameters - ---------- + ------------ original: :class:`AudioSource` The original AudioSource to transform. volume: :class:`float` @@ -676,7 +662,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): See :attr:`volume` for more info. Raises - ------ + ------- TypeError Not an audio source. ClientException @@ -685,10 +671,10 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): def __init__(self, original: AT, volume: float = 1.0): if not isinstance(original, AudioSource): - raise TypeError(f"expected AudioSource not {original.__class__.__name__}.") + raise TypeError(f'expected AudioSource not {original.__class__.__name__}.') if original.is_opus(): - raise ClientException("AudioSource must not be Opus encoded.") + raise ClientException('AudioSource must not be Opus encoded.') self.original: AT = original self.volume = volume @@ -722,9 +708,14 @@ def read(self) -> bytes: class AudioPlayer(threading.Thread): DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 - def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): - threading.Thread.__init__(self) - self.daemon: bool = True + def __init__( + self, + source: AudioSource, + client: VoiceClient, + *, + after: Callable[[Exception | None], Any] | None = None, + ) -> None: + super().__init__(daemon=True, name=f'audio-player:{id(self):#x}') self.source: AudioSource = source self.client: VoiceClient = client self.after: Callable[[Exception | None], Any] | None = after @@ -775,6 +766,7 @@ def _do_run(self) -> None: _log.debug("Reconnected, resuming playback") self._speak(SpeakingState.voice) # reset our internal data + 
self._played_frames_offset += self.loops self.loops = 0 self._start = time.perf_counter() @@ -805,10 +797,9 @@ def _call_after(self) -> None: self.after(error) except Exception as exc: exc.__context__ = error - _log.exception("Calling the after function failed.", exc_info=exc) + _log.exception('Calling the after function failed.', exc_info=exc) elif error: - msg = f"Exception in voice thread {self.name}" - _log.exception(msg, exc_info=error) + _log.exception('Exception in voice thread %s', self.name, exc_info=error) def stop(self) -> None: self._end.set() @@ -821,7 +812,6 @@ def pause(self, *, update_speaking: bool = True) -> None: self._speak(SpeakingState.none) def resume(self, *, update_speaking: bool = True) -> None: - self._played_frames_offset += self.loops self.loops = 0 self._start = time.perf_counter() self._resumed.set() @@ -834,23 +824,17 @@ def is_playing(self) -> bool: def is_paused(self) -> bool: return not self._end.is_set() and not self._resumed.is_set() - def _set_source(self, source: AudioSource) -> None: + def set_source(self, source: AudioSource) -> None: with self._lock: self.pause(update_speaking=False) self.source = source self.resume(update_speaking=False) - def _speak(self, state: SpeakingState) -> None: + def _speak(self, speaking: SpeakingState) -> None: try: - asyncio.run_coroutine_threadsafe( - self.client.ws.speak(state), self.client.loop - ) - except Exception as e: - _log.info("Speaking call in player failed: %s", e) - - def played_frames(self) -> int: - """Gets the number of 20ms frames played since the start of the audio file.""" - return self._played_frames_offset + self.loops + asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.client.loop) + except Exception as exc: + _log.exception("Speaking call in player failed", exc_info=exc) def send_silence(self, count: int = 5) -> None: try: @@ -858,3 +842,6 @@ def send_silence(self, count: int = 5) -> None: self.client.send_audio_packet(OPUS_SILENCE, 
encode=False) except Exception: pass + + def played_frames(self) -> int: + return self._played_frames_offset + self.loops diff --git a/discord/sinks/__init__.py b/discord/sinks/__init__.py index 15498aff27..6db5209af0 100644 --- a/discord/sinks/__init__.py +++ b/discord/sinks/__init__.py @@ -9,7 +9,6 @@ """ from .core import * -from .enums import * from .errors import * from .m4a import * from .mka import * diff --git a/discord/sinks/core.py b/discord/sinks/core.py index c828e8034f..9c87517c11 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -25,26 +25,24 @@ from __future__ import annotations -import asyncio import logging -import struct import sys -import time -from collections import namedtuple -from collections.abc import Callable, Coroutine, Iterable -from functools import partial -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from collections.abc import Callable, Generator, Sequence +import inspect +import subprocess +import shlex +import threading +from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar -from discord import opus, utils -from discord.enums import SpeakingState -from discord.utils import MISSING - -from .enums import SinkFilteringMode +from discord.utils import MISSING, SequenceProxy +from discord.player import FFmpegAudio if TYPE_CHECKING: - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, Self - from discord import abc + from discord.user import User + from discord.member import Member + from discord.voice.packets import VoiceData from ..voice.client import VoiceClient @@ -54,8 +52,6 @@ __all__ = ( "Sink", "RawData", - "SinkFilter", - "SinkHandler", ) @@ -69,747 +65,465 @@ _log = logging.getLogger(__name__) -def is_rtcp(data: bytes) -> bool: - return 200 <= data[1] <= 204 - - -class SinkFilter(Generic[S]): - """Represents a filter for a :class:`~.Sink`. - - This has to be inherited in order to provide a filter to a sink. - - .. 
versionadded:: 2.7 +class SinkBase: + """Represents an audio sink in which user's audios are stored. """ - @overload - async def filter_packet( - self, sink: S, user: abc.Snowflake, packet: RawData - ) -> bool: ... - - @overload - def filter_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> bool: ... + __sink_listeners__: list[tuple[str, str]] - def filter_packet( - self, sink: S, user: abc.Snowflake, packet: RawData - ) -> bool | Coroutine[Any, Any, bool]: - """|maybecoro| + _client: VoiceClient | None - This is called automatically everytime a voice packet is received. + def __new__(cls) -> Self: + listeners = {} - Depending on what bool-like this returns, it will dispatch some events in the parent ``sink``. - - Parameters - ---------- - sink: :class:`~.Sink` - The sink the packet was received from, if the filter check goes through. - user: :class:`~discord.abc.Snowflake` - The user that the packet was received from. - packet: :class:`~.RawData` - The raw data packet. - - Returns - ------- - :class:`bool` - Whether the filter was successful. - """ - raise NotImplementedError("subclasses must implement this") + for base in reversed(cls.__mro__): + for elem, value in base.__dict__.items(): + if elem in listeners: + del listeners[elem] - @overload - async def filter_speaking_state( - self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> bool: ... + is_static = isinstance(value, staticmethod) + if is_static: + value = value.__func__ - @overload - def filter_speaking_state( - self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> bool: ... + if not hasattr(value, '__sink_listener__'): + continue - def filter_speaking_state( - self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> bool | Coroutine[Any, Any, bool]: - """|maybecoro| + listeners[elem] = value - This is called automatically everytime a speaking state is updated. 
+ listeners_list = [] + for listener in listeners.values(): + for listener_name in listener.__sink_listener_names__: + listeners_list.append((listener_name, listener.__name__)) - Depending on what bool-like this returns, it will dispatch some events in the parent ``sink``. + cls.__sink_listeners__ = listeners_list + return super().__new__(cls) - Parameters - ---------- - sink: :class:`~.Sink` - The sink the packet was received from, if the filter check goes through. - user: :class:`~discord.abc.Snowflake` - The user that the packet was received from. - before: :class:`~discord.SpeakingState` - The speaking state before the update. - after: :class:`~discord.SpeakingState` - The speaking state after the update. - - Returns - ------- - :class:`bool` - Whether the filter was successful. - """ - raise NotImplementedError("subclasses must implement this") + @property + def root(self) -> Sink: + """Returns the root parent of this sink.""" + return self # type: ignore - @overload - async def filter_user_connect( - self, - sink: S, - user: abc.Snowflake, - channel: abc.Snowflake, - ) -> bool: ... + @property + def parent(self) -> Sink | None: + """Returns the parent of this sink.""" + raise NotImplementedError - @overload - def filter_user_connect( - self, - sink: S, - user: abc.Snowflake, - channel: abc.Snowflake, - ) -> bool: ... + @property + def child(self) -> Sink | None: + """Returns this sink's child.""" + raise NotImplementedError - def filter_user_connect( - self, - sink: S, - user: abc.Snowflake, - channel: abc.Snowflake, - ) -> bool | Coroutine[Any, Any, bool]: - """|maybecoro| + @property + def children(self) -> Sequence[Sink]: + """Returns the full list of children of this sink.""" + raise NotImplementedError - This is called automatically everytime a speaking state is updated. 
+ @property + def client(self) -> VoiceClient | None: + """Returns the voice client this sink is connected to.""" + return self._client - Depending on what bool-like this returns, it will dispatch some events in the parent ``sink``. + def is_opus(self) -> bool: + """Returns whether this sink is opus.""" + return False - Parameters - ---------- - sink: :class:`~.Sink` - The sink the packet was received from, if the filter check goes through. - user: :class:`~discord.abc.Snowflake` - The user that the packet was received from. - channel: :class:`~discord.abc.Snowflake` - The channel the user has connected to. This is usually resolved into the proper guild channel type, but - defaults to a :class:`~discord.Object` when not found. - - Returns - ------- - :class:`bool` - Whether the filter was successful. - """ - raise NotImplementedError("subclasses must implement this") + def write(self, user: User | Member | None, data: VoiceData) -> None: + """Writes the provided ``data`` into the ``user`` map.""" + raise NotImplementedError def cleanup(self) -> None: - """A function called when the filter is ready for cleanup.""" + """Cleans this sink.""" + raise NotImplementedError + def _register_child(self, child: Sink) -> None: + """Registers a child to this sink.""" + raise NotImplementedError -class SinkHandler(Generic[S]): - """Represents a handler for a :class:`~.Sink`. + def walk_children(self, *, with_self: bool = False) -> Generator[Sink, None, None]: + """Iterates through all the children of this sink, including nested.""" + if with_self: + yield self # type: ignore - This has to be inherited in order to provide a handler to a sink. + for child in self.children: + yield child + yield from child.walk_children() - .. versionadded:: 2.7 - """ + def __del__(self) -> None: + self.cleanup() - @overload - async def handle_packet( - self, sink: S, user: abc.Snowflake, packet: RawData - ) -> Any: ... 
- @overload - def handle_packet(self, sink: S, user: abc.Snowflake, packet: RawData) -> Any: ... +class Sink(SinkBase): + """Object that stores the recordings of the audio data. - def handle_packet( - self, sink: S, user: abc.Snowflake, packet: RawData - ) -> Any | Coroutine[Any, Any, Any]: - """|maybecoro| + Can be subclassed for extra customizability. - This is called automatically everytime a voice packet which has successfully passed the filters is received. + .. versionadded:: 2.0 + """ - Parameters - ---------- - sink: :class:`~.Sink` - The sink the packet was received from, if the filter check goes through. - user: :class:`~discord.abc.Snowflake` - The user that the packet is from. - packet: :class:`~.RawData` - The raw data packet. - """ + _parent: Sink | None = None + _child: Sink | None = None + _client = None - @overload - async def handle_speaking_state( - self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> Any: ... + def __init__(self, *, dest: Sink | None = None) -> None: + if dest is not None: + self._register_child(dest) + else: + self._child = dest - @overload - def handle_speaking_state( - self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> Any: ... + def _register_child(self, child: Sink) -> None: + if child in self.root.walk_children(): + raise RuntimeError("Sink is already registered") + self._child = child + child._parent = self - def handle_speaking_state( - self, sink: S, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> Any | Coroutine[Any, Any, Any]: - """|maybecoro| + @property + def root(self) -> Sink: + if self.parent is None: + return self + return self.parent - This is called automatically everytime a speaking state update is received which has successfully passed the filters. 
+ @property + def parent(self) -> Sink | None: + return self._parent - Parameters - ---------- - sink: :class:`~.Sink` - The sink the packet was received from, if the filter check goes through. - user: :class:`~discord.abc.Snowflake` - The user that the packet was received from. - before: :class:`~discord.SpeakingState` - The speaking state before the update. - after: :class:`~discord.SpeakingState` - The speaking state after the update. - """ + @property + def child(self) -> Sink | None: + return self._child - @overload - async def handle_user_connect( - self, - sink: S, - user: abc.Snowflake, - channel: abc.Snowflake, - ) -> Any: ... + @property + def children(self) -> Sequence[Sink]: + return [self._child] if self._child else [] - @overload - def handle_user_connect( - self, - sink: S, - user: abc.Snowflake, - channel: abc.Snowflake, - ) -> Any: ... + @property + def client(self) -> VoiceClient | None: + if self.parent is not None: + return self.parent.client + else: + return self._client - def handle_user_connect( - self, - sink: S, - user: abc.Snowflake, - channel: abc.Snowflake, - ) -> Any | Coroutine[Any, Any, Any]: - """|maybecoro| + @classmethod + def listener(cls, name: str = MISSING): + """Registers a sink method as a listener. - This is called automatically everytime a user has connected a voice channel which has successfully passed the filters. + You can stack this decorator and pass the ``name`` parameter to mark the same function + to listen to various events. Parameters ---------- - sink: :class:`~.Sink` - The sink the packet was received from, if the filter check goes through. - user: :class:`~discord.abc.Snowflake` - The user that the packet was received from. - channel: :class:`~discord.abc.Snowflake` - The channel the user has connected to. This is usually resolved into the proper guild channel type, but - defaults to a :class:`~discord.Object` when not found. + name: :class:`str` + The name of the event, must not be prefixed with ``on_``. 
Defaults to the function name. """ - def cleanup(self) -> None: - """A function called when the handler is ready for cleanup.""" - - -class RawData: - """Handles raw data from Discord so that it can be decrypted and decoded to be used. - - .. versionadded:: 2.0 - """ - - unpacker = struct.Struct(">xxHII") - _ext_header = namedtuple("Extension", "profile length values") - _ext_magic = b"\xbe\xde" - - if TYPE_CHECKING: - sequence: int - timestamp: int - ssrc: int - - def __init__(self, raw_data: bytes, client: VoiceClient): - data: bytearray = bytearray(raw_data) - self.client: VoiceClient = client - - self.version: int = data[0] >> 6 - self.padding: bool = bool(data[0] & 0b00100000) - self.extended: bool = bool(data[0] & 0b00010000) - self.cc: int = data[0] & 0b00001111 - self.marker: bool = bool(data[1] & 0b10000000) - self.payload: int = data[1] & 0b01111111 - - self.sequence, self.timestamp, self.ssrc = self.unpacker.unpack_from(data) - self.csrcs: tuple[int, ...] = () - self.extension = None - self.extension_data: dict[int, bytes] = {} - - self.header = data[:12] - self.data = data[12:] - self.decrypted_data: bytes | None = None - self.decoded_data: bytes = MISSING - - self.nonce: bytes = b"" - self._rtpsize: bool = False + if name is not MISSING and not isinstance(name, str): + raise TypeError(f"expected a str for listener name, got {name.__class__.__name__} instead") - self._decoder: opus.Decoder = opus.Decoder() - self.receive_time: float = time.perf_counter() + def decorator(func): + actual = func - if self.cc: - fmt = ">%sI" % self.cc - offset = struct.calcsize(fmt) + 12 - self.csrcs = struct.unpack(fmt, data[12:offset]) - self.data = data[offset:] + if isinstance(actual, staticmethod): + actual = actual.__func__ - def adjust_rtpsize(self) -> None: - self._rtpsize = True - self.nonce = self.data[-4:] + if inspect.iscoroutinefunction(actual): + raise TypeError("listener functions must not be coroutines") - if not self.extended: - self.data = self.data[:-4] + 
actual.__sink_listener__ = True + to_assign = name or actual.__name__.removeprefix("on_") - self.header += self.data[:4] - self.data = self.data[4:-4] - - def update_headers(self, data: bytes) -> int: - if not self.extended: - return 0 - - if self._rtpsize: - data = self.header[-4:] + data - - profile, length = struct.unpack_from(">2sH", data) - - if profile == self._ext_magic: - self._parse_bede_header(data, length) - - values = struct.unpack(">%sI" % length, data[4 : 4 + length * 4]) - self.extension = self._ext_header(profile, length, values) - - offset = 4 + length * 4 - if self._rtpsize: - offset -= 4 - return offset - - def _parse_bede_header(self, data: bytes, length: int) -> None: - offset = 4 - n = 0 - - while n < length: - next_byte = data[offset : offset + 1] - - if next_byte == b"\x00": - offset += 1 - continue - - header = struct.unpack(">B", next_byte)[0] - - element_id = header >> 4 - element_len = 1 + (header & 0b0000_1111) + try: + actual.__sink_listener_names__.append(to_assign) + except AttributeError: + actual.__sink_listener_names__ = [to_assign] - self.extension_data[element_id] = data[ - offset + 1 : offset + 1 + element_len - ] - offset += 1 + element_len - n += 1 + return func + return decorator - async def decode(self) -> bytes: - if not self.decrypted_data: - _log.debug("Attempted to decode an empty decrypted data frame") - return b"" - return await asyncio.to_thread( - self._decoder.decode, - self.decrypted_data, - ) +class MultiSink(Sink): + """A sink that can handle multiple sinks concurrently. + .. versionadded:: 2.7 + """ -class Sink: - r"""Represents a sink for voice recording. + def __init__(self, *destinations: Sink) -> None: + for dest in destinations: + self._register_child(dest) + self._children: list[Sink] = list(destinations) - This is used as a way of "storing" the recordings. 
+ def _register_child(self, child: Sink) -> None: + if child in self.root.walk_children(): + raise RuntimeError("Sink is already registered") + child._parent = self - This class is abstracted, and must be subclassed in order to apply functionalities to - it. + @property + def child(self) -> Sink | None: + return self._children[0] if self._children else None - Parameters - ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. - """ + @property + def children(self) -> Sequence[Sink]: + return SequenceProxy(self._children) - if TYPE_CHECKING: - __filtering_mode: SinkFilteringMode - _filter_strat: Callable[..., bool] - client: VoiceClient + def add_destination(self, dest: Sink, /) -> None: + """Adds a sink to be dispatched in this sink. - __listeners__: dict[str, list[Callable[..., Any]]] = {} + Parameters + ---------- + dest: :class:`Sink` + The sink to register as this one's child. - def __init_subclass__(cls) -> None: - listeners: dict[str, list[Callable[..., Any]]] = {} + Raises + ------ + RuntimeError + The sink is already registered. + """ + self._register_child(dest) - for base in reversed(cls.__mro__): - for elem, value in base.__dict__.items(): - if elem in listeners: - del listeners[elem] + def remove_destination(self, dest: Sink, /) -> None: + """Removes a sink from this sink dispatch. 
- if isinstance(value, staticmethod): - value = value.__func__ - elif isinstance(value, classmethod): - value = partial(value.__func__, cls) + Parameters + ---------- + dest: :class:`Sink` + The sink to remove. + """ - if not hasattr(value, "__listener__"): - continue + try: + self._children.remove(dest) + except ValueError: + pass + else: + dest._parent = None - event_name = getattr(value, "__listener_name__", elem).removeprefix( - "on_" - ) - try: - listeners[event_name].append(value) - except KeyError: - listeners[event_name] = [value] +if TYPE_CHECKING: + from typing_extensions import deprecated + + @deprecated( + "RawData has been deprecated and will be removed in 3.0 in favour of VoiceData", + category=DeprecationWarning, + ) + def RawData(**kwargs: Any) -> Any: + """Deprecated since version 2.7, use :class:`VoiceData` instead.""" +else: + class RawData: + def __init__(self, **kwargs: Any) -> None: + raise DeprecationWarning("RawData has been deprecated in favour of VoiceData") - cls.__listeners__ = listeners +class _FFmpegSink(Sink): def __init__( self, *, - filters: list[SinkFilter] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler] = MISSING, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, + executable: str = 'ffmpeg', + stderr: IO[bytes] | None = None, + before_options: str | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, ) -> None: - self._paused: bool = False - self.filtering_mode = filtering_mode - self._filters: list[SinkFilter] = filters or [] - self._handlers: list[SinkHandler] = handlers or [] - self.__dispatch_set: set[asyncio.Task[Any]] = set() - self._listeners: dict[str, list[Callable[[Iterable[object]], bool]]] = ( - self.__listeners__ - ) - - @property - def filtering_mode(self) -> SinkFilteringMode: - return self.__filtering_mode - - @filtering_mode.setter - def filtering_mode(self, value: 
SinkFilteringMode) -> None: - if value is SinkFilteringMode.all: - self._filter_strat = all - elif value is SinkFilteringMode.any: - self._filter_strat = any - else: - raise TypeError( - f"expected a FilteringMode enum member, got {value.__class__.__name__}" - ) + super().__init__() - self.__filtering_mode = value + self.filename: str = filename or "pipe:1" + self.buffer: IO[bytes] = buffer - def dispatch(self, event: str, *args: Any, **kwargs: Any) -> Any: - _log.debug("Dispatching sink %s event %s", self.__class__.__name__, event) - method = f"on_{event}" + self.on_error = error_hook or self._on_error - listeners = self.__listeners__.get(event, []) - for coro in listeners: - self._schedule_event(coro, method, *args, **kwargs) + args = [executable, "-hide_banner"] + subprocess_kwargs: dict[str, Any] = {"stdin": subprocess.PIPE} + if self.buffer is not MISSING: + subprocess_kwargs["stdout"] = subprocess.PIPE - try: - coro = getattr(self, method) - except AttributeError: - pass - else: - self._schedule_event(coro, method, *args, **kwargs) - - async def _run_event( - self, - coro: Callable[..., Coroutine[Any, Any, Any]], - event_name: str, - *args: Any, - **kwargs: Any, - ) -> None: - try: - await coro(*args, **kwargs) - except asyncio.CancelledError: - pass - except Exception as exc: + piping_stderr = False + if stderr is not None: try: - await self.on_error(event_name, exc, *args, **kwargs) - except asyncio.CancelledError: - pass - - def _call_voice_packet_handlers(self, user: abc.Snowflake, packet: RawData) -> None: - for handler in self._handlers: - task = asyncio.create_task( - utils.maybe_coroutine( - handler.handle_packet, - self, - user, - packet, - ) - ) - self.__dispatch_set.add(task) - task.add_done_callback(self.__dispatch_set.discard) - - def _call_user_connect_handlers( - self, - user: abc.Snowflake, - channel: abc.Snowflake, - ) -> None: - for handler in self._handlers: - task = asyncio.create_task( - utils.maybe_coroutine( - 
handler.handle_user_connect, - self, - user, - channel, - ), - ) - self.__dispatch_set.add(task) - task.add_done_callback(self.__dispatch_set.discard) - - def _call_speaking_state_handlers( - self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> None: - for handler in self._handlers: - task = asyncio.create_task( - utils.maybe_coroutine( - handler.handle_speaking_state, - self, - user, - before, - after, - ), - ) - self.__dispatch_set.add(task) - task.add_done_callback(self.__dispatch_set.discard) - - def _schedule_event( - self, - coro: Callable[..., Coroutine[Any, Any, Any]], - event_name: str, - *args: Any, - **kwargs: Any, - ) -> asyncio.Task: - wrapped = self._run_event(coro, event_name, *args, **kwargs) + stderr.fileno() + except Exception: + piping_stderr = True + subprocess_kwargs["stderr"] = subprocess.PIPE + + if isinstance(before_options, str): + args.extend(shlex.split(before_options)) + + args.extend({ + "-f": "s161e", + "-ar": "48000", + "-ac": "2", + "-i": "pipe:0", + "-loglevel": "warning", + "-blocksize": str(FFmpegAudio.BLOCKSIZE) + }) + + if isinstance(options, str): + args.extend(shlex.split(options)) + + args.append(self.filename) + + self._process: subprocess.Popen = MISSING + self._process = self._spawn_process(args, **subprocess_kwargs) + + self._stdin: IO[bytes] = self._process.stdin # type: ignore + self._stdout: IO[bytes] | None = None + self._stderr: IO[bytes] | None = None + self._stdout_reader_thread: threading.Thread | None = None + self._stderr_reader_thread: threading.Thread | None = None + + if self.buffer: + n = f"popen-stdout-reader:pid-{self._process.pid}" + self._stdout = self._process.stdout + _args = (self._stdout, self.buffer) + self._stdout_reader_thread = threading.Thread(target=self._pipe_reader, args=_args, daemon=True, name=n) + self._stdout_reader_thread.start() + + if piping_stderr: + n = f"popen-stderr-reader:pid-{self._process.pid}" + self._stderr = self._process.stderr + _args = 
(self._stderr, stderr) + self._stderr_reader_thread = threading.Thread(target=self._pipe_reader, args=_args, daemon=True, name=n) + self._stderr_reader_thread.start() - task = asyncio.create_task(wrapped, name=f"sinks: {event_name}") - self.__dispatch_set.add(task) - task.add_done_callback(self.__dispatch_set.discard) - return task - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} id={id(self):#x}>" + @staticmethod + def _on_error(_self: _FFmpegSink, error: Exception, data: VoiceData | None) -> None: + _self.client.stop_recording() # type: ignore - def stop(self) -> None: - """Stops this sink's recording. - - This is the place where :meth:`.cleanup` should be called. - """ - self.cleanup() + def is_opus(self) -> bool: + return False def cleanup(self) -> None: - """Cleans all the data in this sink. - - This should be called when you won't be performing any more operations in this sink. - """ + self._kill_processes() + self._process = self._stdout = self._stdin = self._stderr = MISSING - for task in list(self.__dispatch_set): - if task.done(): - continue - task.cancel() + def write(self, user: User | Member | None, data: VoiceData) -> None: + if self._process and not self._stdin.closed: + audio = data.opus if self.is_opus() else data.pcm + assert audio is not None - for filter in self._filters: - filter.cleanup() - - for handler in self._handlers: - handler.cleanup() - - def add_filter(self, filter: SinkFilter, /) -> None: - """Adds a filter to this sink. + try: + self._stdin.write(audio) + except Exception as exc: + _log.exception("Error while writing audio data to stdin ffmpeg") + self._kill_processes() + self.on_error(self, exc, data) - Parameters - ---------- - filter: :class:`~.SinkFilter` - The filter to add. 
+ def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: + _log.debug("Spawning ffmpeg process with command %s and kwargs %s", args, subprocess_kwargs) + process = None - Raises - ------ - TypeError - You did not provide a Filter object. - """ + try: + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) + except FileNotFoundError: + executable = args.partition(' ')[0] if isinstance(args, str) else args[0] + raise Exception(f"{executable!r} executable was not found") from None + except subprocess.SubprocessError as exc: + raise Exception(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc + else: + return process - if not isinstance(filter, SinkFilter): - raise TypeError( - f"expected a Filter object, not {filter.__class__.__name__}" - ) - self._filters.append(filter) + def _kill_processes(self) -> None: + proc: subprocess.Popen = getattr(self, "_process", MISSING) - def remove_filter(self, filter: SinkFilter, /) -> None: - """Removes a filter from this sink. + if proc is MISSING: + return - Parameters - ---------- - filter: :class:`~.SinkFilter` - The filter to remove. - """ + _log.debug("Terminating ffmpeg process %s", proc.pid) try: - self._filters.remove(filter) - except ValueError: + self._stdin.close() + except Exception: pass - def add_handler(self, handler: SinkHandler, /) -> None: - """Adds a handler to this sink. - - Parameters - ---------- - handler: :class:`~.SinkHandler` - The handler to add. - - Raises - ------ - TypeError - You did not provide a Handler object. - """ - - if not isinstance(handler, SinkHandler): - raise TypeError( - f"expected a Handler object, not {handler.__class__.__name__}" - ) - self._handlers.append(handler) - - def remove_handler(self, handler: SinkHandler, /) -> None: - """Removes a handler from this sink. - - Parameters - ---------- - handler: :class:`~.SinkHandler` - The handler to remove. 
- """ + _log.debug("Waiting for ffmpeg process %s", proc.pid) try: - self._handlers.remove(handler) - except ValueError: + proc.wait(5) + except Exception: pass - @staticmethod - def listener( - event: str = MISSING, - ) -> Callable[ - [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] - ]: - """Registers a function to be an event listener for this sink. - - The events must be a :ref:`coroutine `, if not, :exc:`TypeError` is raised; and - also must be inside a sink class. - - Parameters - ---------- - event: :class:`str` - The event name to listen to. If not provided, defaults to the function name. - - Raises - ------ - TypeError - The coroutine passed is not actually a coroutine, or the listener is not in a sink class. - - Example - ------- - - .. code-block:: python3 - - class MySink(Sink): - @Sink.listener() - async def on_member_speaking_state_update(member, ssrc, state): - pass - """ + try: + proc.kill() + except Exception as exc: + _log.exception( + "Ignoring exception while killing Popen process %s", + proc.pid, + exc_info=exc, + ) - def decorator( - func: Callable[P, Coroutine[Any, Any, R]], - ) -> Callable[P, Coroutine[Any, Any, R]]: - parts = func.__qualname__.split(".") + if proc.poll() is None: + _log.info("ffmpeg process %s has not terminated. 
Waiting to terminate...", proc.pid) + proc.communicate() + _log.info("ffmpeg process %s should have terminated with a return code of %s", proc.pid, proc.returncode) + else: + _log.info("ffmpeg process %s successfully terminated with return code of %s", proc.pid, proc.returncode) - if not parts or not len(parts) > 1: - raise TypeError("event listeners must be declared in a Sink class") + self._process = MISSING - if parts[-1] != func.__name__: - raise NameError( - "qualified name and function name mismatch, this should not happen" - ) + def _pipe_reader(self, source: IO[bytes], dest: IO[bytes]) -> None: + while self._process: + if source.closed: + return - if not asyncio.iscoroutinefunction(func): - raise TypeError("event listeners must be coroutine functions") + try: + data = source.read(FFmpegAudio.BLOCKSIZE) + except (OSError, ValueError) as exc: + _log.debug("FFmpeg stdin pipe closed with exception %s", exc) + return + except Exception: + _log.debug("An error ocurred in %s, this can be ignored", self, exc_info=True) + return - func.__listener__ = True - if event is not MISSING: - func.__listener_name__ = event - return func + if data is None: + return - return decorator + try: + dest.write(data) + except Exception as exc: + _log.exception("Error while writing to destination pipe %s", self, exc_info=exc) + self._kill_processes() + self.on_error(self, exc, None) + return - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - pass - async def on_unfiltered_voice_packet_receive( - self, user: abc.Snowflake, data: RawData - ) -> None: - pass +class FilterSink(Sink): + r"""A sink that calls filtering callbacks before writing. - async def on_speaking_state_update( - self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> None: - pass + .. 
versionadded:: 2.7 - async def on_unfiltered_speaking_state_update( - self, user: abc.Snowflake, before: SpeakingState, after: SpeakingState - ) -> None: - pass + Parameters + ---------- + destination: :class:`Sink` + The sink that is being filtered. + filters: Sequence[Callable[[:class:`User` | :class`Member` | :data:`None`, :class:`VoiceData`], :class:`bool`]] + The filters of this sink. + filtering_mode: Literal["all", "any"] + How the filters should work, if ``all`, all filters must be successful in order for + a voice data packet to be written. Using ``any`` will make it so only one filter is + required to be successful in order for a voice data packet to be written. + """ - async def on_user_connect( + def __init__( self, - user: abc.Snowflake, - channel: abc.Snowflake, - ) -> None: - pass - - async def on_unfiltered_user_connect( - self, user: abc.Snowflake, channel: abc.Snowflake - ) -> None: - pass - - async def on_error( - self, event: str, exception: Exception, *args: Any, **kwargs: Any + destination: Sink, + filters: Sequence[Callable[[User | Member | None, VoiceData], bool]], + *, + filtering_mode: Literal["all", "any"] = "all", ) -> None: - _log.exception( - "An error ocurred in sink %s while dispatching the event %s", - self, - event, - exc_info=exception, - ) + if not filters: + raise ValueError("filters must have at least one callback") - def is_recording(self) -> bool: - """Whether this sink is currently available to record, and doing so.""" - state = self.client._connection - return state.is_recording() and id(self) in state._sinks + if not isinstance(destination, SinkBase): + raise TypeError(f"expected a Sink object, got {destination.__class__.__name__}") - def is_paused(self) -> bool: - """Whether this sink is currently paused from recording.""" - return self._paused + self._filter_strat = all if filtering_mode == "all" else any + self.filters: Sequence[Callable[[User | Member | None, VoiceData], bool]] = filters + self.destination: Sink = 
destination + super().__init__(dest=destination) - def pause(self) -> None: - """Pauses the recording of this sink. + def is_opus(self) -> bool: + return self.destination.is_opus() - No filter or handlers will be called when a sink is paused, and no - event will be dispatched. + def write(self, user: User | Member | None, data: VoiceData) -> None: + if self._filter_strat(f(user, data) for f in self.filters): + self.destination.write(user, data) - Pending events _could still be called_ even when a sink is paused, - so make sure you pause a sink when there are not current packets being - handled. - - You can resume the recording of this sink with :meth:`.resume`. - """ - self._paused = True - - def resume(self) -> None: - """Resumes the recording of this sink. - - You can pause the recording of this sink with :meth:`.pause`. - """ - self._paused = False + def cleanup(self) -> None: + self.filters = [] + self.destination.cleanup() diff --git a/discord/voice/__init__.py b/discord/voice/__init__.py index feedaa9f52..d6d13ab0df 100644 --- a/discord/voice/__init__.py +++ b/discord/voice/__init__.py @@ -8,10 +8,6 @@ :license: MIT, see LICENSE for more details. """ -from ._types import VoiceProtocol -from .client import VoiceClient - -__all__ = ( - "VoiceClient", - "VoiceProtocol", -) +from ._types import * +from .client import * +from .packets import * diff --git a/discord/voice/_types.py b/discord/voice/_types.py index b2e252bd39..6fc393c7b0 100644 --- a/discord/voice/_types.py +++ b/discord/voice/_types.py @@ -42,6 +42,10 @@ ClientT = TypeVar("ClientT", bound="Client", covariant=True) +__all__ = ( + "VoiceProtocol", +) + class VoiceProtocol(Generic[ClientT]): """A class that represents the Discord voice protocol. 
diff --git a/discord/voice/client.py b/discord/voice/client.py index efd0c71600..16cb32fcd3 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -29,19 +29,18 @@ import datetime import logging import struct -from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any, Literal, overload +import warnings from discord import opus from discord.errors import ClientException from discord.player import AudioPlayer, AudioSource -from discord.sinks.core import RawData, Sink, is_rtcp +from discord.sinks.core import Sink from discord.sinks.errors import RecordingException from discord.utils import MISSING from ._types import VoiceProtocol - -# from .recorder import VoiceRecorderClient +from .receive import AudioReader from .state import VoiceConnectionState if TYPE_CHECKING: @@ -50,18 +49,19 @@ from discord import abc from discord.client import Client from discord.guild import Guild, VocalGuildChannel - from discord.opus import APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Decoder, Encoder + from discord.opus import APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Encoder from discord.raw_models import ( RawVoiceServerUpdateEvent, RawVoiceStateUpdateEvent, ) from discord.state import ConnectionState from discord.types.voice import SupportedModes - from discord.user import ClientUser + from discord.user import ClientUser, User + from discord.member import Member from .gateway import VoiceWebSocket + from .receive.reader import AfterCallback - AfterCallback = Callable[[Exception | None], Any] P = ParamSpec("P") _log = logging.getLogger(__name__) @@ -76,6 +76,10 @@ except ImportError: has_nacl = False +__all__ = ( + "VoiceClient", +) + class VoiceClient(VoiceProtocol): """Represents a Discord voice connection. 
@@ -129,11 +133,15 @@ def __init__( self._player: AudioPlayer | None = None self._player_future: asyncio.Future[None] | None = None self.encoder: Encoder = MISSING - self.decoder: Decoder = MISSING self._incr_nonce: int = 0 self._connection: VoiceConnectionState = self.create_connection_state() + self._ssrc_to_id: dict[int, int] = {} + self._id_to_ssrc: dict[int, int] = {} + self._event_listeners: dict[str, list] = {} + self._reader: AudioReader = MISSING + warn_nacl: bool = not has_nacl supported_modes: tuple[SupportedModes, ...] = ( "aead_xchacha20_poly1305_rtpsize", @@ -192,14 +200,40 @@ def checked_add(self, attr: str, value: int, limit: int) -> None: setattr(self, attr, val + value) def create_connection_state(self) -> VoiceConnectionState: - return VoiceConnectionState(self) + return VoiceConnectionState(self, hook=self._recv_hook) async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: + old_channel_id = self.channel.id if self.channel else None await self._connection.voice_state_update(data) + if data.channel_id is None: + return + + if self._reader and data.channel_id != old_channel_id: + _log.debug("Destroying voice receive decoders in guild %s", self.guild.id) + self._reader.packet_router.destroy_all_decoders() + async def on_voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: await self._connection.voice_server_update(data) + def _dispatch_sink(self, event: str, /, *args: Any, **kwargs: Any) -> None: + if self._reader: + self._reader.event_router.dispatch(event, *args, **kwargs) + + def _add_ssrc(self, user_id: int, ssrc: int) -> None: + self._ssrc_to_id[ssrc] = user_id + self._id_to_ssrc[user_id] = ssrc + + if self._reader: + self._reader.packet_router.set_user_id(ssrc, user_id) + + def _remove_ssrc(self, *, user_id: int) -> None: + ssrc = self._id_to_ssrc.pop(user_id, None) + + if ssrc: + self._reader.speaking_timer.drop_ssrc(ssrc) + self._ssrc_to_id.pop(ssrc, None) + async def connect( self, *, @@ -349,135 
+383,6 @@ def _encrypt_aead_xchacha20_poly1305_rtpsize( + nonce[:4] ) - # decryption methods - - def _decrypt_rtp_xsalsa20_poly1305(self, data: bytes) -> bytes: - packet = RawData(data, self) - nonce = bytearray(24) - nonce[:12] = packet.header - - box = nacl.secret.SecretBox(bytes(self.secret_key)) - result = box.decrypt(bytes(packet.data), bytes(nonce)) - - if packet.extended: - offset = packet.update_headers(result) - result = result[offset:] - - return result - - def _decrypt_rtcp_xsalsa20_poly1305(self, data: bytes) -> bytes: - nonce = bytearray(24) - nonce[:8] = data[:8] - - box = nacl.secret.SecretBox(bytes(self.secret_key)) - result = box.decrypt(data[8:], bytes(nonce)) - - return data[:8] + result - - def _decrypt_xsalsa20_poly1305(self, data: bytes) -> bytes: - if is_rtcp(data): - func = self._decrypt_rtcp_xsalsa20_poly1305 - else: - func = self._decrypt_rtp_xsalsa20_poly1305 - return func(data) - - def _decrypt_rtp_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes: - packet = RawData(data, self) - nonce = packet.data[-24:] - voice_data = packet.data[:-24] - - box = nacl.secret.SecretBox(bytes(self.secret_key)) - result = box.decrypt(bytes(voice_data), bytes(nonce)) - - if packet.extended: - offset = packet.update_headers(result) - result = result[offset:] - - return result - - def _decrypt_rtcp_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes: - nonce = data[-24:] - header = data[:8] - - box = nacl.secret.SecretBox(bytes(self.secret_key)) - result = box.decrypt(data[8:-24], nonce) - - return header + result - - def _decrypt_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes: - if is_rtcp(data): - func = self._decrypt_rtcp_xsalsa20_poly1305_suffix - else: - func = self._decrypt_rtp_xsalsa20_poly1305_suffix - return func(data) - - def _decrypt_rtp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: - packet = RawData(data, self) - nonce = bytearray(24) - nonce[:4] = packet.data[-4:] - voice_data = packet.data[:-4] - - box = 
nacl.secret.SecretBox(bytes(self.secret_key)) - result = box.decrypt(bytes(voice_data), bytes(nonce)) - - if packet.extended: - offset = packet.update_headers(result) - result = result[offset:] - - return result - - def _decrypt_rtcp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: - nonce = bytearray(24) - nonce[:4] = data[-4:] - header = data[:8] - - box = nacl.secret.SecretBox(bytes(self.secret_key)) - result = box.decrypt(data[8:-4], bytes(nonce)) - - return header + result - - def _decrypt_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: - if is_rtcp(data): - func = self._decrypt_rtcp_xsalsa20_poly1305_lite - else: - func = self._decrypt_rtp_xsalsa20_poly1305_lite - return func(data) - - def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: - packet = RawData(data, self) - packet.adjust_rtpsize() - - nonce = bytearray(24) - nonce[:4] = packet.nonce - voice_data = packet.data - - # Blob vomit - box = nacl.secret.Aead(bytes(self.secret_key)) - result = box.decrypt(bytes(voice_data), bytes(packet.header), bytes(nonce)) - - if packet.extended: - offset = packet.update_headers(result) - result = result[offset:] - - return result - - def _decrypt_rtcp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: - nonce = bytearray(24) - nonce[:4] = data[-4:] - header = data[:8] - - box = nacl.secret.Aead(bytes(self.secret_key)) - result = box.decrypt(data[8:-4], bytes(header), bytes(nonce)) - - return header + result - - def _decrypt_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: - if is_rtcp(data): - func = self._decrypt_rtcp_aead_xchacha20_poly1305_rtpsize - else: - func = self._decrypt_rtp_aead_xchacha20_poly1305_rtpsize - return func(data) - @overload def play( self, @@ -616,6 +521,9 @@ def stop(self) -> None: for cb, _ in self._player_future._callbacks: self._player_future.remove_done_callback(cb) self._player_future.set_result(None) + if self._reader: + self._reader.stop() + self._reader = MISSING self._player = None 
self._player_future = None @@ -646,7 +554,7 @@ def source(self, value: AudioSource) -> None: if self._player is None: raise ValueError("the client is not playing anything") - self._player._set_source(value) + self._player.set_source(value) def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: """Sends an audio packet composed of the ``data``. @@ -695,7 +603,7 @@ def elapsed(self) -> datetime.timedelta: def start_recording( self, sink: Sink, - callback: Callable[..., Coroutine[Any, Any, Any]] = MISSING, + callback: AfterCallback | None = None, *args: Any, sync_start: bool = MISSING, ) -> None: @@ -709,25 +617,23 @@ def start_recording( ---------- sink: :class:`~.Sink` A Sink in which all audio packets will be processed in. - callback: :ref:`coroutine ` - A function which is called after the bot has stopped recording. + callback: Callable[[:class:`Exception` | None], Any] + A function which is called after the bot has stopped recording. This must take exactly one positonal(-only) + parameter, ``exception``, which is the exception that was raised during the recording of the Sink. .. versionchanged:: 2.7 - This parameter is now optional. + This parameter is now optional, and must take exactly one parameter, ``exception``. \*args: The arguments to pass to the callback coroutine. + + .. deprecated:: 2.7 + Passing custom arguments to the callback is now deprecated and ignored. sync_start: :class:`bool` If ``True``, the recordings of subsequent users will start with silence. This is useful for recording audio just as it was heard. - .. warning:: - - This is a global voice client variable, this means, you can't have individual - sinks with different ``sync_start`` values. If you are willing to have such - functionality, you should consider creating your own :class:`discord.SinkHandler`. - - .. versionchanged:: 2.7 - This now defaults to ``MISSING``. + .. deprecated:: 2.7 + This parameter is now ignored and deprecated. 
Raises ------ @@ -742,19 +648,18 @@ def start_recording( if not isinstance(sink, Sink): raise TypeError(f"expected a Sink object, got {sink.__class__.__name__}") + if self.is_recording(): + raise ClientException("Already recording audio") + + if len(args) > 0: + warnings.warn("'args' parameter is deprecated since 2.7 and will be removed in 3.0") if sync_start is not MISSING: - self._connection.sync_recording_start = sync_start + warnings.warn("'sync_tart' parameter is deprecated since 2.7 and will be removed in 3.0") - sink.client = self - self._connection.add_sink(sink) - if callback is not MISSING: - self._connection.recording_done_callbacks.append((callback, args)) + self._reader = AudioReader(sink, self, after=callback) + self._reader.start() - def stop_recording( - self, - *, - sink: Sink | None = None, - ) -> None: + def stop_recording(self) -> None: """Stops the recording of the provided ``sink``, or all recording sinks. .. versionadded:: 2.0 @@ -762,24 +667,29 @@ def stop_recording( Raises ------ RecordingException - The provided sink is not currently recording, or if ``None``, you are not recording. - - Paremeters - ---------- - sink: :class:`discord.Sink` - The sink to stop recording. + You are not recording. """ - - if sink is not None: - try: - self._connection.sinks.remove(sink) - except ValueError: - raise RecordingException("the provided sink is not currently recording") - - sink.stop() - return - self._connection.stop_record_socket() + if self._reader: + self._reader.stop() + self._reader = MISSING + else: + raise RecordingException("You are not recording") def is_recording(self) -> bool: """Whether the current client is recording in any sink.""" - return self._connection.is_recording() + return self._reader and self._reader.is_listening() + + def is_speaking(self, member: Member | User) -> bool | None: + """Whether a user is speaking. + + This is an approximate calculation and may have outdated or wrong data. 
+ + If the member speaking status has not been yet saved, it returns ``None``. + + .. versionadded:: 2.7 + """ + ssrc = self._id_to_ssrc.get(member.id) + if ssrc is None: + return None + if self._reader: + return self._reader.speaking_timer.get_speaking(ssrc) diff --git a/discord/voice/enums.py b/discord/voice/enums.py index 5564c690b5..76f78c768d 100644 --- a/discord/voice/enums.py +++ b/discord/voice/enums.py @@ -39,8 +39,9 @@ class OpCodes(Enum): resume = 7 hello = 8 resumed = 9 - client_connect = 10 - client_disconnect = 11 + clients_connect = 11 + client_connect = 12 + client_disconnect = 13 # dave protocol stuff dave_prepare_transition = 21 diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index 8ce80d71a3..c7f63aacb9 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -47,6 +47,7 @@ if TYPE_CHECKING: from typing_extensions import Self + from _typeshed import ConvertibleToInt from .state import VoiceConnectionState @@ -139,10 +140,6 @@ def session_id(self) -> str | None: def session_id(self, value: str | None) -> None: self.state.session_id = value - @property - def dave_session(self) -> davey.DaveSession | None: - return self.state.dave_session - @property def self_id(self) -> int: return self._connection.self_id @@ -150,13 +147,15 @@ def self_id(self) -> int: async def _hook(self, *args: Any) -> Any: pass - async def send_as_bytes(self, op: int, data: bytes) -> None: - packet = bytes(op) + data - _log.debug("Sending voice websocket binary frame: op: %s data: %s", op, str(data)) + async def send_as_bytes(self, op: ConvertibleToInt, data: bytes) -> None: + packet = bytes([int(op)]) + data + _log.debug("Sending voice websocket binary frame: op: %s size: %d", op, len(data)) await self.ws.send_bytes(packet) async def send_as_json(self, data: Any) -> None: _log.debug("Sending voice websocket frame: %s.", data) + if data.get('op', None) == OpCodes.identify: + _log.info("Identifying ourselves: %s", data) await 
self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json @@ -206,9 +205,9 @@ async def received_message(self, msg: Any, /): interval=min(interval, 5), ) self._keep_alive.start() - elif self.dave_session: + elif state.dave_session: if op == OpCodes.dave_prepare_transition: - _log.info("Preparing to upgrade to a DAVE connection for channel %s", state.channel_id) + _log.info("Preparing to upgrade to a DAVE connection for channel %s for transition %d proto version %d", state.channel_id, data["transition_id"], data["protocol_version"]) state.dave_pending_transition = data transition_id = data["transition_id"] @@ -216,8 +215,8 @@ async def received_message(self, msg: Any, /): if transition_id == 0: await state.execute_dave_transition(data["transition_id"]) else: - if data["protocol_version"] == 0: - self.dave_session.set_passthrough_mode(True, 120) + if data["protocol_version"] == 0 and state.dave_session: + state.dave_session.set_passthrough_mode(True, 120) await self.send_dave_transition_ready(transition_id) elif op == OpCodes.dave_execute_transition: _log.info("Upgrading to DAVE connection for channel %s", state.channel_id) @@ -241,33 +240,36 @@ async def received_binary_message(self, msg: bytes) -> None: state = self.state - if not self.dave_session: + if not state.dave_session: return if op == OpCodes.mls_external_sender_package: - self.dave_session.set_external_sender(msg[3:]) + state.dave_session.set_external_sender(msg[3:]) elif op == OpCodes.mls_proposals: op_type = msg[3] - result = self.dave_session.process_proposals( + result = state.dave_session.process_proposals( davey.ProposalsOperationType.append if op_type == 0 else davey.ProposalsOperationType.revoke, msg[4:], ) if isinstance(result, davey.CommitWelcome): + data = (result.commit + result.welcome) if result.welcome else result.commit + _log.debug("Sending MLS key package with data: %s", data) await self.send_as_bytes( - OpCodes.mls_key_package.value, - (result.commit + result.welcome) if 
result.welcome else result.commit, + OpCodes.mls_commit_welcome, + data, ) - _log.debug("Processed MLS proposals for current dave session") + _log.debug("Processed MLS proposals for current dave session: %r", result) elif op == OpCodes.mls_commit_transition: transt_id = struct.unpack_from(">H", msg, 3)[0] try: - self.dave_session.process_commit(msg[5:]) + state.dave_session.process_commit(msg[5:]) if transt_id != 0: state.dave_pending_transition = { "transition_id": transt_id, "protocol_version": state.dave_protocol_version, } + _log.debug("Sending DAVE transition ready from MLS commit transition with data: %s", state.dave_pending_transition) await self.send_dave_transition_ready(transt_id) _log.debug("Processed MLS commit for transition %s", transt_id) except Exception as exc: @@ -276,12 +278,13 @@ async def received_binary_message(self, msg: bytes) -> None: elif op == OpCodes.mls_welcome: transt_id = struct.unpack_from(">H", msg, 3)[0] try: - self.dave_session.process_welcome(msg[5:]) + state.dave_session.process_welcome(msg[5:]) if transt_id != 0: state.dave_pending_transition = { "transition_id": transt_id, "protocol_version": state.dave_protocol_version, } + _log.debug("Sending DAVE transition ready from MLS welcome with data: %s", state.dave_pending_transition) await self.send_dave_transition_ready(transt_id) _log.debug("Processed MLS welcome for transition %s", transt_id) except Exception as exc: @@ -329,7 +332,6 @@ async def select_protocol(self, ip: str, port: int, mode: str) -> None: "port": port, "mode": mode, }, - "dave_protocol_version": self.state.dave_protocol_version, }, } await self.send_as_json(payload) @@ -389,8 +391,10 @@ async def poll_event(self) -> None: msg = await asyncio.wait_for(self.ws.receive(), timeout=30) if msg.type is aiohttp.WSMsgType.TEXT: + _log.debug("Received text payload: %s", msg.data) await self.received_message(utils._from_json(msg.data)) elif msg.type is aiohttp.WSMsgType.BINARY: + _log.debug("Received binary payload: 
size: %d", len(msg.data)) await self.received_binary_message(msg.data) elif msg.type is aiohttp.WSMsgType.ERROR: _log.debug("Received %s", msg) @@ -455,7 +459,7 @@ async def identify(self) -> None: "user_id": str(state.user.id), "session_id": self.session_id, "token": self.token, - "max_dave_protocol_version": self.state.max_dave_proto_version, + "max_dave_protocol_version": state.max_dave_proto_version, }, } await self.send_as_json(payload) diff --git a/discord/voice/packets/__init__.py b/discord/voice/packets/__init__.py new file mode 100644 index 0000000000..d55c2a3312 --- /dev/null +++ b/discord/voice/packets/__init__.py @@ -0,0 +1,47 @@ +""" +discord.voice.packets +~~~~~~~~~~~~~~~~~~~~~ + +Sink packet handlers. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from .core import Packet +from .rtp import RTPPacket, RTCPPacket, FakePacket, ReceiverReportPacket, SenderReportPacket, SilencePacket + +if TYPE_CHECKING: + from discord import User, Member + +__all__ = ( + "Packet", + "RTPPacket", + "RTCPPacket", + "FakePacket", + "ReceiverReportPacket", + "SenderReportPacket", + "SilencePacket", + "VoiceData", +) + + +class VoiceData: + """Represents an audio data from a source. + + .. versionadded:: 2.7 + + Attributes + ---------- + packet: :class:`~discord.sinks.Packet` + The packet this source data contains. + source: :class:`~discord.User` | :class:`~discord.Member` | None + The user that emitted this audio source. + pcm: :class:`bytes` + The PCM bytes of this source. 
+ """ + + def __init__(self, packet: Packet, source: User | Member | None, *, pcm: bytes | None = None) -> None: + self.packet: Packet = packet + self.source: User | Member | None = source + self.pcm: bytes = pcm if pcm else b'' diff --git a/discord/voice/packets/core.py b/discord/voice/packets/core.py new file mode 100644 index 0000000000..6a2ca96a89 --- /dev/null +++ b/discord/voice/packets/core.py @@ -0,0 +1,89 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations +from typing import TYPE_CHECKING + +from discord.opus import Decoder, _lib + +if TYPE_CHECKING: + from typing_extensions import Final + +if _lib is None: + DECODER = None +else: + DECODER = Decoder() + +OPUS_SILENCE: Final = b'\xf8\xff\xfe' + + +class Packet: + """Represents an audio stream bytes packet. 
+ + Attributes + ---------- + data: :class:`bytes` + The bytes data of this packet. This has not been decoded. + """ + + if TYPE_CHECKING: + ssrc: int + sequence: int + timestamp: int + type: int + decrypted_data: bytes + + def __init__(self, data: bytes) -> None: + self.data: bytes = data + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}> data={len(self.data)} bytes>" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + if self.ssrc != other.ssrc: + raise TypeError(f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})") + return self.sequence == other.sequence and self.timestamp == other.timestamp + + def __gt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + if self.ssrc != other.ssrc: + raise TypeError(f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})") + return self.sequence > other.sequence and self.timestamp > other.timestamp + + def __lt__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + if self.ssrc != other.ssrc: + raise TypeError(f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})") + return self.sequence < other.sequence and self.timestamp < other.timestamp + + def is_silence(self) -> bool: + data = getattr(self, 'decrypted_data', None) + return data == OPUS_SILENCE + + def __hash__(self) -> int: + return hash(self.data) diff --git a/discord/voice/packets/rtp.py b/discord/voice/packets/rtp.py new file mode 100644 index 0000000000..8ed822f92e --- /dev/null +++ b/discord/voice/packets/rtp.py @@ -0,0 +1,300 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without 
restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from collections import namedtuple +import struct +from typing import TYPE_CHECKING, Any, Literal + +from .core import OPUS_SILENCE, Packet + +if TYPE_CHECKING: + from typing_extensions import Final + +MAX_UINT_32 = 0xffffffff +MAX_UINT_16 = 0xffff + +RTP_PACKET_TYPE_VOICE = 120 + + +def decode(data: bytes) -> Packet: + if not data[0] >> 6 == 2: + raise ValueError(f"Invalid packet header 0b{data[0]:0>8b}") + return _rtcp_map.get(data[1], RTPPacket)(data) + + +class FakePacket(Packet): + data = b'' + decrypted_data: bytes = b'' + extension_data: dict = {} + + def __init__( + self, + ssrc: int, + sequence: int, + timestamp: int, + ) -> None: + self.ssrc = ssrc + self.sequence = sequence + self.timestamp = timestamp + + def __bool__(self) -> Literal[False]: + return False + + +class SilencePacket(Packet): + decrypted_data: Final = OPUS_SILENCE + extension_data: Final[dict[int, Any]] = {} + sequence: int = -1 + + def __init__(self, ssrc: int, timestamp: int) -> None: + self.ssrc = ssrc + self.timestamp = timestamp + + def is_silence(self) -> bool: + return True + + +class 
RTPPacket(Packet): + """Represents an RTP packet. + + .. versionadded:: 2.7 + + Attributes + ---------- + data: :class:`bytes` + The raw data of the packet. + """ + + _hstruct = struct.Struct(">xxHII") + _ext_header = namedtuple("Extension", "profile length values") + _ext_magic = b"\xbe\xde" + + def __init__(self, data: bytes) -> None: + super().__init__(data) + + self.version: int = data[0] >> 6 + self.padding: bool = bool(data[0] & 0b00100000) + self.extended: bool = bool(data[0] & 0b00010000) + self.cc: int = data[0] & 0b00001111 + + self.marker: bool = bool(data[1] & 0b10000000) + self.payload: int = data[1] & 0b01111111 + + sequence, timestamp, ssrc = self._hstruct.unpack_from(data) + self.sequence = sequence + self.timestamp = timestamp + self.ssrc = ssrc + + self.csrcs: tuple[int, ...] = () + self.extension = None + self.extension_data: dict[int, bytes] = {} + + self.header = data[:12] + self.data = data[12:] + self.decrypted_data: bytes | None = None + + self.nonce: bytes = b'' + self._rtpsize: bool = False + + if self.cc: + fmt = '>%sI' % self.cc + offset = struct.calcsize(fmt) + 12 + self.csrcs = struct.unpack(fmt, data[12:offset]) + self.data = data[offset:] + + def adjust_rtpsize(self) -> None: + """Automatically adjusts this packet header and data based on the rtpsize format.""" + + self._rtpsize = True + self.nonce = self.data[-4:] + + if not self.extended: + self.data = self.data[:-4] + return + + self.header += self.data[:4] + self.data = self.data[4:-4] + + def update_extended_header(self, data: bytes) -> int: + """Updates the extended header using ``data`` and returns the pd offset.""" + + if not self.extended: + return 0 + + if self._rtpsize: + data = self.header[-4:] + data + + profile, length = struct.unpack_from(">2sH", data) + + if profile == self._ext_magic: + self._parse_bede_header(data, length) + + values = struct.unpack(">%sI" % length, data[4: 4 + length * 4]) + self.extension = self._ext_header(profile, length, values) + + offset = 4 
+ length * 4 + if self._rtpsize: + offset -= 4 + + return offset + + def _parse_bede_header(self, data: bytes, length: int) -> None: + offset = 4 + n = 0 + + while n < length: + next_byte = data[offset : offset + 1] + + if next_byte == b'\x00': + offset += 1 + continue + + header = struct.unpack(">B", next_byte)[0] + el_id = header >> 4 + el_len = 1 + (header & 0b0000_1111) + + self.extension_data[el_id] = data[offset + 1 : offset + 1 + el_len] + offset += 1 + el_len + n += 1 + + def __repr__(self) -> str: + return ( + "" + ) + + +class RTCPPacket(Packet): + _header = struct.Struct(">BBH") + _ssrc_fmt = struct.Struct(">I") + type = None + + def __init__(self, data: bytes) -> None: + super().__init__(data) + self.length: int + head, _, self.length = self._header.unpack_from(data) + + self.version: int = head >> 6 + self.padding: bool = bool(head & 0b00100000) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} version={self.version} padding={self.padding} length={self.length}>" + + @classmethod + def from_data(cls, data: bytes) -> Packet: + _, ptype, _ = cls._header.unpack_from(data) + return _rtcp_map[ptype](data) + + +def _parse_low(x: int, bitlen: int = 32) -> float: + return x / 2.0 ** bitlen + + +def _to_low(x: float, bitlen: int = 32) -> int: + return int(x * 2.0 ** bitlen) + + +class SenderReportPacket(RTCPPacket): + _info_fmt = struct.Struct(">5I") + _report_fmt = struct.Struct(">IB3x4I") + _24bit_int_fmt = struct.Struct(">4xI") + _info = namedtuple("RRSenderInfo", "ntp_ts rtp_ts packet_count octet_count") + _report = namedtuple("RReport", "ssrc perc_loss total_lost last_seq jitter lsr dlsr") + type = 200 + + if TYPE_CHECKING: + report_count: int + + def __init__(self, data: bytes) -> None: + super().__init__(data) + + self.ssrc = self._ssrc_fmt.unpack_from(data, 4)[0] + self.info = self._read_sender_info(data, 8) + + _report = self._report + reports: list[_report] = [] + for x in range(self.report_count): + offset = 28 + 24 * x + 
reports.append(self._read_report(data, offset)) + + self.reports: tuple[_report, ...] = tuple(reports) + self.extension = None + if len(data) > 28 + 24 * self.report_count: + self.extension = data[28 + 24 * self.report_count :] + + def _read_sender_info(self, data: bytes, offset: int) -> _info: + nhigh, nlow, rtp_ts, pcount, ocount = self._info_fmt.unpack_from(data, offset) + ntotal = nhigh + _parse_low(nlow) + return self._info(ntotal, rtp_ts, pcount, ocount) + + def _read_report(self, data: bytes, offset: int) -> _report: + ssrc, flost, seq, jit, lsr, dlsr = self._report_fmt.unpack_from(data, offset) + clost = self._24bit_int_fmt.unpack_from(data, offset)[0] & 0xFFFFFF + return self._report(ssrc, flost, clost, seq, jit, lsr, dlsr) + + +class ReceiverReportPacket(RTCPPacket): + _report_fmt = struct.Struct(">IB3x4I") + _24bit_int_fmt = struct.Struct(">4xI") + _report = namedtuple("RReport", "ssrc perc_loss total_loss last_seq jitter lsr dlsr") + type = 201 + + reports: tuple[_report, ...] 
+ + if TYPE_CHECKING: + report_count: int + + def __init__(self, data: bytes) -> None: + super().__init__(data) + self.ssrc: int = self._ssrc_fmt.unpack_from(data, 4)[0] + + _report = self._report + reports: list[_report] = [] + for x in range(self.report_count): + offset = 8 + 24 * x + reports.append(self._read_report(data, offset)) + + self.reports = tuple(reports) + + self.extension: bytes | None = None + if len(data) > 8 + 24 * self.report_count: + self.extension = data[8 + 24 * self.report_count :] + + def _read_report(self, data: bytes, offset: int) -> _report: + ssrc, flost, seq, jit, lsr, dlsr = self._report_fmt.unpack_from(data, offset) + clost = self._24bit_int_fmt.unpack_from(data, offset)[0] & 0xFFFFFF + return self._report(ssrc, flost, clost, seq, jit, lsr, dlsr) + + +_rtcp_map = { + 200: SenderReportPacket, + 201: ReceiverReportPacket, +} diff --git a/discord/voice/receive/__init__.py b/discord/voice/receive/__init__.py new file mode 100644 index 0000000000..b8d4662310 --- /dev/null +++ b/discord/voice/receive/__init__.py @@ -0,0 +1,2 @@ +from .reader import AudioReader +from .router import PacketRouter, SinkEventRouter diff --git a/discord/voice/receive/reader.py b/discord/voice/receive/reader.py new file mode 100644 index 0000000000..7861a97291 --- /dev/null +++ b/discord/voice/receive/reader.py @@ -0,0 +1,466 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or 
substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from collections.abc import Callable +import logging +from operator import itemgetter +import threading +import time +from typing import TYPE_CHECKING, Any, Literal + +import davey +from discord.opus import PacketDecoder + +from .router import PacketRouter, SinkEventRouter +from ..packets.rtp import decode, ReceiverReportPacket + +try: + import nacl.secret + from nacl.exceptions import CryptoError +except ImportError as exc: + raise RuntimeError("can't use voice receiver without PyNaCl installed, please install it with the 'py-cord[voice]' extra.") from exc + + +if TYPE_CHECKING: + from discord.member import Member + from discord.sinks import Sink + from discord.types.voice import SupportedModes + + from ..client import VoiceClient + from ..packets import RTPPacket + + AfterCallback = Callable[[Exception | None], Any] + DecryptRTP = Callable[[RTPPacket], bytes] + DecryptRTCP = Callable[[bytes], bytes] + SpeakingEvent = Literal["member_speaking_start", "member_speaking_stop"] + EncryptionBox = nacl.secret.SecretBox | nacl.secret.Aead + +_log = logging.getLogger(__name__) + +__all__ = ( + "AudioReader", +) + + +def is_rtcp(data: bytes) -> bool: + return 200 <= data[1] <= 204 + + +class AudioReader: + def __init__(self, sink: Sink, client: VoiceClient, *, after: AfterCallback | None = None) -> None: + if after is not None and not callable(after): + raise TypeError(f"expected a callable for the 'after' parameter, got 
{after.__class__.__name__!r} instead") + + self.sink: Sink = sink + self.client: VoiceClient = client + self.after: AfterCallback | None = after + + self.sink._client = client + + self.active: bool = False + self.error: Exception | None = None + self.packet_router: PacketRouter = PacketRouter(self.sink, self) + self.event_router: SinkEventRouter = SinkEventRouter(self.sink, self) + self.decryptor: PacketDecryptor = PacketDecryptor(client.mode, bytes(client.secret_key), client) + self.speaking_timer: SpeakingTimer = SpeakingTimer(self) + self.keep_alive: UDPKeepAlive = UDPKeepAlive(client) + + def is_listening(self) -> bool: + return self.active + + def update_secret_key(self, secret_key: bytes) -> None: + self.decryptor.update_secret_key(secret_key) + + def start(self) -> None: + if self.active: + _log.debug("Reader is already running", exc_info=True) + return + + self.speaking_timer.start() + self.event_router.start() + self.packet_router.start() + self.client._connection.add_socket_listener(self.callback) + self.keep_alive.start() + self.active = True + + def stop(self) -> None: + if not self.active: + _log.debug("Reader is not active", exc_info=True) + return + + self.client._connection.remove_socket_listener(self.callback) + self.active = False + self.speaking_timer.notify() + + threading.Thread(target=self._stop, name=f"voice-receiver-audio-reader-stop:{id(self):#x}").start() + + def _stop(self) -> None: + try: + self.packet_router.stop() + except Exception as exc: + self.error = exc + _log.exception("An error ocurred while stopping packet router.") + + try: + self.event_router.stop() + except Exception as exc: + self.error = exc + _log.exception("An error ocurred while stopping event router.") + + self.speaking_timer.stop() + self.keep_alive.stop() + + if self.after: + try: + self.after(self.error) + except Exception: + _log.exception("An error ocurred while calling the after callback on audio reader") + + for sink in 
self.sink.root.walk_children(with_self=True): + try: + sink.cleanup() + except Exception as exc: + _log.exception("Error calling cleanup() for %s", sink, exc_info=exc) + + def set_sink(self, sink: Sink) -> Sink: + old_sink = self.sink + old_sink._client = None + sink._client = self.client + self.packet_router.set_sink(sink) + self.sink = sink + return old_sink + + def _is_ip_discovery_packet(self, data: bytes) -> bool: + return len(data) == 74 and data[1] == 0x02 + + def callback(self, packet_data: bytes) -> None: + + packet = rtp_packet = rtcp_packet = None + + try: + if not is_rtcp(packet_data): + packet = rtp_packet = decode(packet_data) + packet.decrypted_data = self.decryptor.decrypt_rtp(packet) # type: ignore + else: + packet = rtcp_packet = decode(packet_data) + + if not isinstance(packet, ReceiverReportPacket): + _log.info("Received unexpected rtcp packet type=%s, %s", packet.type, type(packet)) + except CryptoError as exc: + _log.error("CryptoError while decoding a voice packet", exc_info=exc) + return + except Exception as exc: + if self._is_ip_discovery_packet(packet_data): + _log.debug("Received an IP Discovery Packet, ignoring...") + return + _log.exception("An exception ocurred while decoding voice packets", exc_info=exc) + finally: + if self.error: + self.stop() + return + if not packet: + return + + if rtcp_packet: + self.packet_router.feed_rtcp(rtcp_packet) # type: ignore + elif rtp_packet: + ssrc = rtp_packet.ssrc + + if ssrc not in self.client._connection.user_ssrc_map: + if rtp_packet.is_silence(): + return + else: + _log.info("Received a packet for unknown SSRC %s: %s", ssrc, rtp_packet) + + self.speaking_timer.notify(ssrc) + + try: + self.packet_router.feed_rtp(rtp_packet) # type: ignore + except Exception as exc: + _log.exception("An error ocurred while processing RTP packet %s", rtp_packet) + self.error = exc + self.stop() + + +class PacketDecryptor: + supported_modes: list[SupportedModes] = [ + "aead_xchacha20_poly1305_rtpsize", + 
"xsalsa20_poly1305", + "xsalsa20_poly1305_lite", + "xsalsa20_poly1305_suffix", + ] + + def __init__(self, mode: SupportedModes, secret_key: bytes, client: VoiceClient) -> None: + self.mode: SupportedModes = mode + self.client: VoiceClient = client + + try: + self._decryptor_rtp: DecryptRTP = getattr(self, '_decrypt_rtp_' + mode) + self._decryptor_rtcp: DecryptRTCP = getattr(self, '_decrypt_rtcp_' + mode) + except AttributeError as exc: + raise NotImplementedError(mode) from exc + + self.box: EncryptionBox = self._make_box(secret_key) + + def _make_box(self, secret_key: bytes) -> EncryptionBox: + if self.mode.startswith("aead"): + return nacl.secret.Aead(secret_key) + else: + return nacl.secret.SecretBox(secret_key) + + def decrypt_rtp(self, packet: RTPPacket) -> bytes: + state = self.client._connection + dave = state.dave_session + data = self._decryptor_rtp(packet) + + if dave is not None and dave.ready and packet.ssrc in state.user_ssrc_map: + return dave.decrypt(state.user_ssrc_map[packet.ssrc], davey.MediaType.audio, data) + return data + + def decrypt_rtcp(self, packet: bytes) -> bytes: + data = self._decryptor_rtcp(packet) + # TODO: guess how to get the SSRC so we can use dave + return data + + def update_secret_key(self, secret_key: bytes) -> None: + self.box = self._make_box(secret_key) + + def _decrypt_rtp_xsalsa20_poly1305(self, packet: RTPPacket) -> bytes: + nonce = bytearray(24) + nonce[:12] = packet.header + result = self.box.decrypt(bytes(packet.data), bytes(nonce)) + + if packet.extended: + offset = packet.update_extended_header(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_xsalsa20_poly1305(self, data: bytes) -> bytes: + nonce = bytearray(24) + nonce[:8] = data[:8] + result = self.box.decrypt(data[8:], bytes(nonce)) + + return data[:8] + result + + def _decrypt_rtp_xsalsa20_poly1305_suffix(self, packet: RTPPacket) -> bytes: + nonce = packet.data[-24:] + voice_data = packet.data[:-24] + result = 
self.box.decrypt(bytes(voice_data), bytes(nonce)) + + if packet.extended: + offset = packet.update_extended_header(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_xsalsa20_poly1305_suffix(self, data: bytes) -> bytes: + nonce = data[-24:] + header = data[:8] + result = self.box.decrypt(data[8:-24], nonce) + + return header + result + + def _decrypt_rtp_xsalsa20_poly1305_lite(self, packet: RTPPacket) -> bytes: + nonce = bytearray(24) + nonce[:4] = packet.data[-4:] + voice_data = packet.data[:-4] + result = self.box.decrypt(bytes(voice_data), bytes(nonce)) + + if packet.extended: + offset = packet.update_extended_header(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes: + nonce = bytearray(24) + nonce[:4] = data[-4:] + header = data[:8] + result = self.box.decrypt(data[8:-4], bytes(nonce)) + + return header + result + + def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, packet: RTPPacket) -> bytes: + packet.adjust_rtpsize() + + nonce = bytearray(24) + nonce[:4] = packet.nonce + voice_data = packet.data + + # Blob vomit + assert isinstance(self.box, nacl.secret.Aead) + result = self.box.decrypt(bytes(voice_data), bytes(packet.header), bytes(nonce)) + + if packet.extended: + offset = packet.update_extended_header(result) + result = result[offset:] + + return result + + def _decrypt_rtcp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: + nonce = bytearray(24) + nonce[:4] = data[-4:] + header = data[:8] + + assert isinstance(self.box, nacl.secret.Aead) + result = self.box.decrypt(data[8:-4], bytes(header), bytes(nonce)) + + return header + result + + +class SpeakingTimer(threading.Thread): + def __init__(self, reader: AudioReader) -> None: + super().__init__( + daemon=True, + name=f"voice-receiver-speaking-timer:{id(self):#x}", + ) + + self.reader: AudioReader = reader + self.client: VoiceClient = reader.client + self.speaking_timeout_delay: float = 0.2 + 
self.last_speaking_state: dict[int, bool] = {} + self.speaking_cache: dict[int, float] = {} + self.speaking_timer_event: threading.Event = threading.Event() + self._end_thread: threading.Event = threading.Event() + + def _lookup_member(self, ssrc: int) -> Member | None: + id = self.client._connection.user_ssrc_map.get(ssrc) + if not self.client.guild: + return None + return self.client.guild.get_member(id) if id else None + + def maybe_dispatch_speaking_start(self, ssrc: int) -> None: + tlast = self.speaking_cache.get(ssrc) + if tlast is None or tlast + self.speaking_timeout_delay < time.perf_counter(): + self.dispatch("member_speaking_start", ssrc) + + def dispatch(self, event: SpeakingEvent, ssrc: int) -> None: + member = self._lookup_member(ssrc) + if not member: + return None + self.client._dispatch_sink(event, member) + + def notify(self, ssrc: int | None = None) -> None: + if ssrc is not None: + self.last_speaking_state[ssrc] = True + self.maybe_dispatch_speaking_start(ssrc) + self.speaking_cache[ssrc] = time.perf_counter() + + self.speaking_timer_event.set() + self.speaking_timer_event.clear() + + def drop_ssrc(self, ssrc: int) -> None: + self.speaking_cache.pop(ssrc, None) + state = self.last_speaking_state.pop(ssrc, None) + if state: + self.dispatch("member_speaking_stop", ssrc) + self.notify() + + def get_speaking(self, ssrc: int) -> bool | None: + return self.last_speaking_state.get(ssrc) + + def stop(self) -> None: + self._end_thread.set() + self.notify() + + def run(self) -> None: + _i1 = itemgetter(1) + + def get_next_entry(): + cache = sorted(self.speaking_cache.items(), key=_i1) + for ssrc, tlast in cache: + if self.last_speaking_state.get(ssrc): + return ssrc, tlast + return None, None + + self.speaking_timer_event.wait() + while not self._end_thread.is_set(): + if not self.speaking_cache: + self.speaking_timer_event.wait() + + tnow = time.perf_counter() + ssrc, tlast = get_next_entry() + + if ssrc is None or tlast is None: + 
self.speaking_timer_event.wait() + continue + + self.speaking_timer_event.wait(tlast + self.speaking_timeout_delay - tnow) + + if time.perf_counter() < tlast + self.speaking_timeout_delay: + continue + + self.dispatch("member_speaking_stop", ssrc) + self.last_speaking_state[ssrc] = False + + +class UDPKeepAlive(threading.Thread): + delay: int = 5000 + + def __init__(self, client: VoiceClient) -> None: + super().__init__( + daemon=True, + name=f"voice-receiver-udp-keep-alive:{id(self):#x}", + ) + + self.client: VoiceClient = client + self.last_time: float = 0 + self.counter: int = 0 + self._end_thread: threading.Event = threading.Event() + + def run(self) -> None: + self.client.wait_until_connected() + + while not self._end_thread.is_set(): + vc = self.client + + try: + packet = self.counter.to_bytes(8, "big") + except OverflowError: + self.counter = 0 + continue + + try: + vc._connection.socket.sendto(packet, (vc._connection.endpoint_ip, vc._connection.voice_port)) + except Exception as exc: + _log.debug( + "Error while sending udp keep alive to socket %s at %s:%s", + vc._connection.socket, + vc._connection.endpoint_ip, + vc._connection.voice_port, + exc_info=exc, + ) + vc.wait_until_connected() + if vc.is_connected(): + continue + break + else: + self.counter += 1 + time.sleep(self.delay) + + def stop(self) -> None: + self._end_thread.set() diff --git a/discord/voice/receive/router.py b/discord/voice/receive/router.py new file mode 100644 index 0000000000..94e9f5bb72 --- /dev/null +++ b/discord/voice/receive/router.py @@ -0,0 +1,222 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the 
Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +from collections import deque +from collections.abc import Callable +import threading +import logging +from typing import TYPE_CHECKING, Any +import queue + +from discord.opus import PacketDecoder + +from ..utils.multidataevent import MultiDataEvent + +if TYPE_CHECKING: + from discord.sinks import Sink + + from .reader import AudioReader + from ..packets import RTPPacket, RTCPPacket + + EventCB = Callable[..., Any] + EventData = tuple[str, tuple[Any, ...], dict[str, Any]] + +_log = logging.getLogger(__name__) + + +class PacketRouter(threading.Thread): + def __init__(self, sink: Sink, reader: AudioReader) -> None: + super().__init__( + daemon=True, + name=f'voice-receiver-packet-router:{id(self):#x}', + ) + + self.sink: Sink = sink + self.decoders: dict[int, PacketDecoder] = {} + self.reader: AudioReader = reader + self.waiter: MultiDataEvent[PacketDecoder] = MultiDataEvent() + + self._lock: threading.RLock = threading.RLock() + self._end_thread: threading.Event = threading.Event() + self._dropped_ssrcs: deque[int] = deque(maxlen=16) + + def feed_rtp(self, packet: RTPPacket) -> None: + if packet.ssrc in self._dropped_ssrcs: + _log.debug("Ignoring packet from dropped ssrc %s", packet.ssrc) + + with self._lock: + decoder = 
self.get_decoder(packet.ssrc) + if decoder is not None: + decoder.push_packet(packet) + + def feed_rtcp(self, packet: RTCPPacket) -> None: + guild = self.sink.client.guild if self.sink.client else None + event_router = self.reader.event_router + event_router.dispatch('rtcp_packet', packet, guild) + + def get_decoder(self, ssrc: int) -> PacketDecoder | None: + with self._lock: + decoder = self.decoders.get(ssrc) + if decoder is None: + decoder = self.decoders[ssrc] = PacketDecoder(self, ssrc) + return decoder + + def set_sink(self, sink: Sink) -> None: + with self._lock: + self.sink = sink + + def set_user_id(self, ssrc: int, user_id: int) -> None: + with self._lock: + if ssrc in self._dropped_ssrcs: + self._dropped_ssrcs.remove(ssrc) + + decoder = self.decoders.get(ssrc) + if decoder is not None: + decoder.set_user_id(user_id) + + def destroy_decoder(self, ssrc: int) -> None: + with self._lock: + decoder = self.decoders.pop(ssrc, None) + if decoder is not None: + self._dropped_ssrcs.append(ssrc) + decoder.destroy() + + def destroy_all_decoders(self) -> None: + with self._lock: + for ssrc in self.decoders.keys(): + self.destroy_decoder(ssrc) + + def stop(self) -> None: + self._end_thread.set() + self.waiter.notify() + + def run(self) -> None: + try: + self._do_run() + except Exception as exc: + _log.exception("Error in %s loop", self) + self.reader.error = exc + finally: + self.reader.client.stop_recording() + self.waiter.clear() + + def _do_run(self) -> None: + while not self._end_thread.is_set(): + self.waiter.wait() + + with self._lock: + for decoder in self.waiter.items: + data = decoder.pop_data() + if data is not None: + self.sink.write(data.source, data) + + +class SinkEventRouter(threading.Thread): + def __init__(self, sink: Sink, reader: AudioReader) -> None: + super().__init__(daemon=True, name=f"voice-receiver-sink-event-router:{id(self):#x}") + + self.sink: Sink = sink + self.reader: AudioReader = reader + + self._event_listeners: dict[str, 
list[EventCB]] = {} + self._buffer: queue.SimpleQueue[EventData] = queue.SimpleQueue() + self._lock = threading.RLock() + self._end_thread = threading.Event() + + self.register_events() + + def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None: + _log.debug("Dispatch voice event %s", event) + self._buffer.put_nowait((event, args, kwargs)) + + def set_sink(self, sink: Sink) -> None: + with self._lock: + self.unregister_events() + self.sink = sink + self.register_events() + + def register_events(self) -> None: + with self._lock: + self._register_listeners(self.sink) + for child in self.sink.walk_children(): + self._register_listeners(child) + + def unregister_events(self) -> None: + with self._lock: + self._unregister_listeners(self.sink) + for child in self.sink.walk_children(): + self._unregister_listeners(child) + + def _register_listeners(self, sink: Sink) -> None: + _log.debug("Registering events for %s: %s", sink, sink.__sink_listeners__) + + for name, method_name in sink.__sink_listeners__: + func = getattr(sink, method_name) + _log.debug("Registering event: %r (callback at %r)", name, method_name) + + if name in self._event_listeners: + self._event_listeners[name].append(func) + else: + self._event_listeners[name] = [func] + + def _unregister_listeners(self, sink: Sink) -> None: + for name, method_name in sink.__sink_listeners__: + func = getattr(sink, method_name) + + if name in self._event_listeners: + try: + self._event_listeners[name].remove(func) + except ValueError: + pass + + def _dispatch_to_listeners(self, event: str, *args: Any, **kwargs: Any) -> None: + for listener in self._event_listeners.get(f"on_{event}", []): + try: + listener(*args, **kwargs) + except Exception as exc: + _log.exception("Unhandled exception while dispatching event %s (args: %s; kwargs: %s)", event, args, kwargs, exc_info=exc) + + def stop(self) -> None: + self._end_thread.set() + + def run(self) -> None: + try: + self._do_run() + except Exception as exc: + 
_log.exception("Error in sink event router", exc_info=exc) + self.reader.error = exc + self.reader.client.stop_listening() + + def _do_run(self) -> None: + while not self._end_thread.is_set(): + try: + event, args, kwargs = self._buffer.get(timeout=0.5) + except queue.Empty: + continue + else: + with self._lock: + with self.reader.packet_router._lock: + self._dispatch_to_listeners(event, *args, **kwargs) diff --git a/discord/voice/state.py b/discord/voice/state.py index aa717f0445..deba6ee9e3 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -42,7 +42,12 @@ from discord.object import Object from discord.sinks import RawData, Sink -import davey +try: + import davey +except ImportError: + import warnings + + warnings.warn_explicit() from .enums import ConnectionFlowState, OpCodes from .gateway import VoiceWebSocket @@ -193,12 +198,15 @@ def _do_run(self) -> None: else: for cb in self._callbacks: try: - task = self.state.loop.create_task( - utils.maybe_coroutine(cb, data) + task = asyncio.ensure_future( + self.state.loop.create_task( + utils.maybe_coroutine(cb, data) + ), + loop=self.state.loop, ) - self.state._sink_dispatch_task_set.add(task) + self.state._dispatch_task_set.add(task) task.add_done_callback( - self.state._sink_dispatch_task_set.discard + self.state._dispatch_task_set.discard ) except Exception: _log.exception( @@ -208,21 +216,6 @@ def _do_run(self) -> None: ) -class SocketVoiceRecvReader(SocketReader): - def __init__( - self, - state: VoiceConnectionState, - *, - start_paused: bool = True, - ) -> None: - super().__init__( - state, - f"voice-recv-socket-reader:{id(self):#x}", - 4096, - start_paused=start_paused, - ) - - class SocketEventReader(SocketReader): def __init__( self, state: VoiceConnectionState, *, start_paused: bool = True @@ -235,11 +228,6 @@ def __init__( ) -class SSRC(TypedDict): - user_id: int - speaking: SpeakingState - - class VoiceConnectionState: def __init__( self, @@ -281,18 +269,10 @@ def __init__( 
self._connector: asyncio.Task[None] | None = None self._socket_reader = SocketEventReader(self) self._socket_reader.start() - self._voice_recv_socket = SocketVoiceRecvReader(self) - self._voice_recv_socket.register(self.handle_voice_recv_packet) - self.start_record_socket() - self.user_ssrc_map: dict[int, SSRC] = {} - self.user_voice_timestamps: dict[int, tuple[int, float]] = {} - self.sync_recording_start: bool = False - self.first_received_packet_ts: float = MISSING - self._sinks: dict[int, Sink] = {} self.recording_done_callbacks: list[ tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]] ] = [] - self._sink_dispatch_task_set: set[asyncio.Task[Any]] = set() + self._dispatch_task_set: set[asyncio.Task] = set() if not self._connection.self_id: raise RuntimeError("client self ID is not set") @@ -305,284 +285,13 @@ def __init__( self.downgraded_dave = False @property - def sinks(self) -> list[Sink]: - return list(self._sinks.values()) + def user_ssrc_map(self) -> dict[int, int]: + return self.client._id_to_ssrc @property def max_dave_proto_version(self) -> int: return davey.DAVE_PROTOCOL_VERSION - def start_record_socket(self) -> None: - try: - self._voice_recv_socket.start() - except RuntimeError: - self._voice_recv_socket.resume() - - def stop_record_socket(self) -> None: - self._voice_recv_socket.stop() - - for cb, args in self.recording_done_callbacks: - task = self.loop.create_task(cb(*args)) - self._sink_dispatch_task_set.add(task) - task.add_done_callback(self._sink_dispatch_task_set.remove) - - for sink in self.sinks: - sink.stop() - - self.recording_done_callbacks.clear() - self.sinks.clear() - - async def handle_voice_recv_packet(self, packet: bytes) -> None: - _recv_log.debug("Handling voice packet %s", packet) - if packet[1] != 0x78: - # We should ignore any payload types we do not understand - # Ref: RFC 3550 5.1 payload type - # At some point we noted that we should ignore only types 200 - 204 inclusive. 
- # They were marked as RTCP: provides information about the connection - # this was too broad of a whitelist, it is unclear if this is too narrow of a whitelist - return - - if self.paused_recording(): - _recv_log.debug("Ignoring packet %s because recording is stopped", packet) - return - - data = RawData(packet, self.client) - - if data.decrypted_data == opus.OPUS_SILENCE: - _recv_log.debug( - "Ignoring packet %s because it is an opus silence frame", data - ) - return - - await data.decode() - - def is_first_packet(self) -> bool: - return not self.user_voice_timestamps or not self.sync_recording_start - - def dispatch_packet_sinks(self, data: RawData) -> None: - _log.debug("Dispatching packet %s in all sinks", data) - if data.ssrc not in self.user_ssrc_map: - if self.is_first_packet(): - self.first_received_packet_ts = data.receive_time - silence = 0 - else: - silence = (data.receive_time - self.first_received_packet_ts) * 48000 - else: - stored_timestamp, stored_recv_time = self.user_voice_timestamps[data.ssrc] - dRT = data.receive_time - stored_recv_time * 48000 - dT = data.timestamp - stored_timestamp - diff = abs(100 - dT * 100 / dRT) - - if diff > 60 and dT != 960: - silence = dRT - 960 - else: - silence = dT - 960 - - self.user_voice_timestamps[data.ssrc] = (data.timestamp, data.receive_time) - - data.decoded_data = ( - struct.pack(" None: - user = self.get_user_by_ssrc(data.ssrc) - if not user: - _log.debug( - "Ignoring received packet %s because the SSRC was waited for but was not found", - data, - ) - return - - data.user_id = user.id - - for sink in self.sinks: - if sink.is_paused(): - continue - - sink.dispatch("unfiltered_voice_packet_receive", user, data) - - if sink._filters: - futures = [ - self.loop.create_task( - utils.maybe_coroutine(fil.filter_packet, sink, user, data) - ) - for fil in sink._filters - ] - strat = sink._filter_strat - - done, pending = await asyncio.wait(futures) - - if pending: - for task in pending: - task.set_result(False) - 
- done = (*done, *pending) - - result = strat([f.result() for f in done]) - else: - result = True - - if result: - sink.dispatch("voice_packet_receive", user, data) - sink._call_voice_packet_handlers(user, data) - - def is_recording(self) -> bool: - return self._voice_recv_socket.is_running() - - def paused_recording(self) -> bool: - return self._voice_recv_socket.is_paused() - - def add_sink(self, sink: Sink) -> None: - self._sinks[id(sink)] = sink - self.start_record_socket() - - def remove_sink(self, sink: Sink) -> None: - try: - self._sinks.pop(id(sink)) - except KeyError: - pass - - def get_user_by_ssrc(self, ssrc: int) -> abc.Snowflake | None: - data = self.user_ssrc_map.get(ssrc) - if data is None: - return None - - user = int(data["user_id"]) - return self.get_user(user) - - def get_user(self, id: int) -> abc.Snowflake: - state = self._connection - return self.guild.get_member(id) or state.get_user(id) or Object(id=id) - - def ws_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: - op = msg["op"] - data = msg.get("d", {}) - - if op == OpCodes.speaking: - ssrc = data["ssrc"] - user = int(data["user_id"]) - raw_speaking = data["speaking"] - speaking = try_enum(SpeakingState, raw_speaking) - old_data = self.user_ssrc_map.get(ssrc) - old_speaking = (old_data or {}).get("speaking", SpeakingState.none) - - self.dispatch_speaking_state(old_speaking, speaking, user) - - if old_data is None: - self.user_ssrc_map[ssrc]["speaking"] = speaking - else: - self.user_ssrc_map[ssrc] = { - "user_id": user, - "speaking": speaking, - } - elif op == OpCodes.client_connect: - user_ids = [int(uid) for uid in data["user_ids"]] - - for uid in user_ids: - user = self.get_user(uid) - self.dispatch_user_connect(user) - - def dispatch_speaking_state( - self, before: SpeakingState, after: SpeakingState, user_id: int - ) -> None: - task = self.loop.create_task( - self._dispatch_speaking_state(before, after, user_id), - ) - self._sink_dispatch_task_set.add(task) - 
task.add_done_callback(self._sink_dispatch_task_set.discard) - - def dispatch_user_connect(self, user: abc.Snowflake) -> None: - task = self.loop.create_task( - self._dispatch_user_connect(self.channel_id, user), - ) - self._sink_dispatch_task_set.add(task) - task.add_done_callback(self._sink_dispatch_task_set.discard) - - async def _dispatch_user_connect( - self, chid: int | None, user: abc.Snowflake - ) -> None: - channel = self.guild._resolve_channel(chid) or Object(id=chid or 0) - - for sink in self.sinks: - if sink.is_paused(): - continue - - sink.dispatch("unfiltered_user_connect", user, channel) - - if sink._filters: - futures = [ - self.loop.create_task( - utils.maybe_coroutine(fil.filter_user_connect, user, channel) - ) - for fil in sink._filters - ] - strat = sink._filter_strat - - done, pending = await asyncio.wait(futures) - - if pending: - for task in pending: - task.cancel() - - result = strat([f.result() for f in done]) - else: - result = True - - if result: - sink.dispatch("user_connect", user, channel) - sink._call_user_connect_handlers(user, channel) - - async def _dispatch_speaking_state( - self, before: SpeakingState, after: SpeakingState, uid: int - ) -> None: - resolved = self.get_user(uid) - - for sink in self.sinks: - if sink.is_paused(): - continue - - sink.dispatch("unfiltered_speaking_state_update", resolved, before, after) - - if sink._filters: - futures = [ - self.loop.create_task( - utils.maybe_coroutine( - fil.filter_packet, sink, resolved, before, after - ) - ) - for fil in sink._filters - ] - strat = sink._filter_strat - - done, pending = await asyncio.wait(futures) - - if pending: - # there should not be any pending futures - # but if there are, simply discard them - for task in pending: - task.cancel() - - result = strat([f.result() for f in done]) - else: - result = True - - if result: - sink.dispatch("speaking_state_update", resolved, before, after) - sink._call_speaking_state_handlers(resolved, before, after) - @property def 
state(self) -> ConnectionFlowState: return self._state @@ -881,7 +590,6 @@ async def disconnect( if cleanup: self._socket_reader.stop() - self.stop_record_socket() self.client.stop() self._connected.set() @@ -927,7 +635,6 @@ async def soft_disconnect( finally: self.state = with_state self._socket_reader.pause() - self._voice_recv_socket.pause() if self.socket: self.socket.close() @@ -1028,7 +735,6 @@ async def _voice_disconnect(self) -> None: await self.client.channel.guild.change_voice_state(channel=None) self._expecting_disconnect = True self._disconnected.clear() - self.ws._identified = False async def _connect_websocket(self, resume: bool) -> VoiceWebSocket: seq_ack = -1 @@ -1207,25 +913,23 @@ def _update_voice_channel(self, channel_id: int | None) -> None: self.client.channel = channel_id and self.guild.get_channel(channel_id) # type: ignore async def reinit_dave_session(self) -> None: - session = self.dave_session - if self.dave_protocol_version > 0: - if session: - session.reinit(self.dave_protocol_version, self.user.id, self.channel_id) + if self.dave_session: + self.dave_session.reinit(self.dave_protocol_version, self.user.id, self.channel_id) else: - session = self.dave_session = davey.DaveSession( + self.dave_session = davey.DaveSession( self.dave_protocol_version, self.user.id, self.channel_id, ) await self.ws.send_as_bytes( - int(OpCodes.mls_key_package), - session.get_serialized_key_package(), + OpCodes.mls_key_package, + self.dave_session.get_serialized_key_package(), ) - elif session: - session.reset() - session.set_passthrough_mode(True, 10) + elif self.dave_session: + self.dave_session.reset() + self.dave_session.set_passthrough_mode(True, 10) async def recover_dave_from_invalid_commit(self, transition: int) -> None: payload = { diff --git a/discord/voice/utils/__init__.py b/discord/voice/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/discord/voice/utils/buffer.py b/discord/voice/utils/buffer.py new file mode 
100644 index 0000000000..91c83dfe80 --- /dev/null +++ b/discord/voice/utils/buffer.py @@ -0,0 +1,204 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +import heapq +import logging +import threading +from typing import Protocol, TypeVar + +from .wrapped import gap_wrapped, add_wrapped + +from ..packets import Packet + + +__all__ = ( + "Buffer", + "JitterBuffer", +) + + +T = TypeVar("T") +PacketT = TypeVar("PacketT", bound=Packet) +_log = logging.getLogger(__name__) + + +class Buffer(Protocol[T]): + def __len__(self) -> int: ... + def push(self, item: T) -> None: ... + def pop(self) -> T | None: ... + def peek(self) -> T | None: ... + def flush(self) -> list[T]: ... + def reset(self) -> None: ... 
+ + +class BaseBuff(Buffer[PacketT]): + def __init__(self) -> None: + self._buffer: list[PacketT] = [] + + def __len__(self) -> int: + return len(self._buffer) + + def push(self, item: PacketT) -> None: + self._buffer.append(item) + + def pop(self) -> PacketT | None: + return self._buffer.pop() + + def peek(self) -> PacketT | None: + return self._buffer[-1] if self._buffer else None + + def flush(self) -> list[PacketT]: + buf = self._buffer.copy() + self._buffer.clear() + return buf + + def reset(self) -> None: + self._buffer.clear() + + +class JitterBuffer(BaseBuff[PacketT]): + _threshold: int = 10000 + + def __init__(self, max_size: int = 10, *, pref_size: int = 1, prefill: int = 1) -> None: + if max_size < 1: + raise ValueError(f"max_size must be greater than 1, not {max_size}") + + if not 0 <= pref_size <= max_size: + raise ValueError(f"pref_size must be between 0 and max_size ({max_size})") + + self.max_size: int = max_size + self.pref_size: int = pref_size + self.prefill: int = prefill + self._prefill: int = prefill + self._last_tx_seq: int = -1 + self._has_item: threading.Event = threading.Event() + #self._lock: threading.Lock = threading.Lock() + self._buffer: list[Packet] = [] + + def _push(self, packet: Packet) -> None: + heapq.heappush(self._buffer, packet) + + def _pop(self) -> Packet: + return heapq.heappop(self._buffer) + + def _get_packet_if_ready(self) -> Packet | None: + return self._buffer[0] if len(self._buffer) > self.pref_size else None + + def _pop_if_ready(self) -> Packet | None: + return self._pop() if len(self._buffer) > self.pref_size else None + + def _update_has_item(self) -> None: + prefilled = self._prefill == 0 + packet_ready = len(self._buffer) > self.pref_size + + if not prefilled or not packet_ready: + self._has_item.clear() + return + + next_packet = self._buffer[0] + sequential = add_wrapped(self._last_tx_seq, 1) == next_packet.sequence + positive_seq = self._last_tx_seq >= 0 + + if (sequential and positive_seq) or not 
positive_seq or len(self._buffer) >= self.max_size: + self._has_item.set() + else: + self._has_item.clear() + + def _cleanup(self) -> None: + while len(self._buffer) > self.max_size: + packet = heapq.heappop(self._buffer) + + def push(self, packet: Packet) -> bool: + seq = packet.sequence + + if gap_wrapped(self._last_tx_seq, seq) > self._threshold and self._last_tx_seq != -1: + _log.debug("Dropping old packet %s", packet) + return False + + self._push(packet) + + if self._prefill > 0: + self._prefill -= 1 + + self._cleanup() + self._update_has_item() + return True + + def pop(self, *, timeout: float | None = 0) -> Packet | None: + ok = self._has_item.wait(timeout) + if not ok: + return None + + if self._prefill > 0: + return None + + packet = self._pop_if_ready() + + if packet is not None: + self._last_tx_seq = packet.sequence + + self._update_has_item() + return packet + + def peek(self, *, all: bool = False) -> Packet | None: + if not self._buffer: + return None + + if all: + return self._buffer[0] + else: + return self._get_packet_if_ready() + + def peek_next(self) -> Packet | None: + packet = self.peek(all=True) + + if packet is None: + return None + + if packet.sequence == add_wrapped(self._last_tx_seq, 1) or self._last_tx_seq < 0: + return packet + + def gap(self) -> int: + if self._buffer and self._last_tx_seq > 0: + return gap_wrapped(self._last_tx_seq, self._buffer[0].sequence) + return 0 + + def flush(self) -> list[Packet]: + packets = sorted(self._buffer) + self._buffer.clear() + + if packets: + self._last_tx_seq = packets[-1].sequence + + self._prefill = self.prefill + self._has_item.clear() + return packets + + def reset(self) -> None: + self._buffer.clear() + self._has_item.clear() + self._prefill = self.prefill + self._last_tx_seq = -1 diff --git a/discord/voice/utils/multidataevent.py b/discord/voice/utils/multidataevent.py new file mode 100644 index 0000000000..ea175f332b --- /dev/null +++ b/discord/voice/utils/multidataevent.py @@ -0,0 +1,78 @@ 
+""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" +from __future__ import annotations + +import threading +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class MultiDataEvent(Generic[T]): + """ + Something like the inverse of a Condition. A 1-waiting-on-N type of object, + with accompanying data object for convenience. 
+ """ + + def __init__(self): + self._items: list[T] = [] + self._ready: threading.Event = threading.Event() + + @property + def items(self) -> list[T]: + """A shallow copy of the currently ready objects.""" + return self._items.copy() + + def is_ready(self) -> bool: + return self._ready.is_set() + + def _check_ready(self) -> None: + if self._items: + self._ready.set() + else: + self._ready.clear() + + def notify(self) -> None: + self._ready.set() + self._check_ready() + + def wait(self, timeout: float | None = None) -> bool: + self._check_ready() + return self._ready.wait(timeout) + + def register(self, item: T) -> None: + self._items.append(item) + self._ready.set() + + def unregister(self, item: T) -> None: + try: + self._items.remove(item) + except ValueError: + pass + self._check_ready() + + def clear(self) -> None: + self._items.clear() + self._ready.clear() \ No newline at end of file diff --git a/discord/sinks/enums.py b/discord/voice/utils/wrapped.py similarity index 85% rename from discord/sinks/enums.py rename to discord/voice/utils/wrapped.py index f09daf8b56..8d096a0713 100644 --- a/discord/sinks/enums.py +++ b/discord/voice/utils/wrapped.py @@ -23,13 +23,9 @@ DEALINGS IN THE SOFTWARE. """ -from __future__ import annotations +def gap_wrapped(a: int, b: int, *, wrap: int = 65536) -> int: + return (b - (a + 1) + wrap) % wrap -from discord.enums import Enum -__all__ = ("SinkFilteringMode",) - - -class SinkFilteringMode(Enum): - all = 0 - any = 1 +def add_wrapped(a: int, b: int, *, wrap: int = 65536) -> int: + return (a + b) % wrap \ No newline at end of file diff --git a/docs/api/voice.rst b/docs/api/voice.rst index 95d2b8441d..59de02185d 100644 --- a/docs/api/voice.rst +++ b/docs/api/voice.rst @@ -6,13 +6,13 @@ Voice Related Objects ------- -.. attributetable:: VoiceClient +.. attributetable:: discord.voice.VoiceClient .. autoclass:: VoiceClient() :members: :exclude-members: connect, on_voice_state_update, on_voice_server_update -.. 
attributetable:: VoiceProtocol +.. attributetable:: discord.voice.VoiceProtocol .. autoclass:: VoiceProtocol :members: From 60658bb6934639b207f3e7f611481e1a669d19c1 Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:25:50 +0200 Subject: [PATCH 36/40] more voice recv --- discord/sinks/core.py | 90 ++++++++++- discord/sinks/m4a.py | 257 +++++++----------------------- discord/sinks/mka.py | 235 +++++++-------------------- discord/sinks/mkv.py | 232 +++++++-------------------- discord/sinks/mp3.py | 238 +++++++-------------------- discord/sinks/mp4.py | 257 +++++++----------------------- discord/sinks/ogg.py | 235 +++++++-------------------- discord/sinks/pcm.py | 122 +------------- discord/sinks/wave.py | 158 ++++++------------ discord/voice/packets/__init__.py | 4 + 10 files changed, 514 insertions(+), 1314 deletions(-) diff --git a/discord/sinks/core.py b/discord/sinks/core.py index 9c87517c11..949fefc25f 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -32,11 +32,14 @@ import subprocess import shlex import threading -from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar +from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, overload +from discord.file import File from discord.utils import MISSING, SequenceProxy from discord.player import FFmpegAudio +from .errors import FFmpegNotFound + if TYPE_CHECKING: from typing_extensions import ParamSpec, Self @@ -52,6 +55,9 @@ __all__ = ( "Sink", "RawData", + "FFmpegSink", + "FilterSink", + "MultiSink", ) @@ -309,13 +315,67 @@ def __init__(self, **kwargs: Any) -> None: raise DeprecationWarning("RawData has been deprecated in favour of VoiceData") -class _FFmpegSink(Sink): +class FFmpegSink(Sink): + """A :class:`Sink` built to use ffmpeg executables. + + You can find default implementations of this sink in: + + - :class:`M4ASink` + - :class:`MKASink` + + .. 
versionadded:: 2.7 + + Parameters + ---------- + filename: :class:`str` + The file in which the ffmpeg buffer should be saved to. + Can not be mixed with ``buffer``. + buffer: IO[:class:`bytes`] + The buffer in which the ffmpeg result would be written to. + Can not be mixed with ``filename``. + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer that will be written to. Defaults to ``None``. + before_options: :class:`str` | :data:`None` + The options to append **before** the default ones. + options: :class:`str` | :data:`None` + The options to append **after** the default ones. You can override the + default ones with this. + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error occurs with this sink. + """ + + @overload + def __init__( + self, + *, + filename: str, + executable: str = ..., + stderr: IO[bytes] = ..., + before_options: str | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... + + @overload + def __init__( + self, + *, + buffer: IO[bytes], + executable: str = ..., + stderr: IO[bytes] = ..., + before_options: str | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ...
+ def __init__( self, *, filename: str = MISSING, buffer: IO[bytes] = MISSING, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", stderr: IO[bytes] | None = None, before_options: str | None = None, options: str | None = None, @@ -323,6 +383,9 @@ def __init__( ) -> None: super().__init__() + if filename is not MISSING and buffer is not MISSING: + raise TypeError("can't mix filename and buffer parameters") + self.filename: str = filename or "pipe:1" self.buffer: IO[bytes] = buffer @@ -345,7 +408,7 @@ def __init__( args.extend(shlex.split(before_options)) args.extend({ - "-f": "s161e", + "-f": "s16le", "-ar": "48000", "-ac": "2", "-i": "pipe:0", @@ -382,7 +445,7 @@ def __init__( self._stderr_reader_thread.start() @staticmethod - def _on_error(_self: _FFmpegSink, error: Exception, data: VoiceData | None) -> None: + def _on_error(_self: FFmpegSink, error: Exception, data: VoiceData | None) -> None: _self.client.stop_recording() # type: ignore def is_opus(self) -> bool: @@ -404,6 +467,21 @@ def write(self, user: User | Member | None, data: VoiceData) -> None: self._kill_processes() self.on_error(self, exc, data) + + def to_file(self, filename: str, /, *, description: str | None = None, spoiler: bool = False) -> File | None: + """Returns the :class:`discord.File` of this sink. + + This is only applicable if this sink uses a ``buffer`` instead of a ``filename``. + + .. warning:: + + This should be used only after the sink has stopped recording. 
+ """ + if self.buffer is not MISSING: + fp = File(self.buffer.read(), filename=filename, description=description, spoiler=spoiler) + return fp + return None + def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: _log.debug("Spawning ffmpeg process with command %s and kwargs %s", args, subprocess_kwargs) process = None @@ -412,7 +490,7 @@ def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Pope process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) except FileNotFoundError: executable = args.partition(' ')[0] if isinstance(args, str) else args[0] - raise Exception(f"{executable!r} executable was not found") from None + raise FFmpegNotFound(f"{executable!r} executable was not found") from None except subprocess.SubprocessError as exc: raise Exception(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc else: diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index e817cfc540..2541b058b6 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -24,230 +24,95 @@ from __future__ import annotations -import io -import logging -import os -import subprocess -import time -from collections import deque -from typing import TYPE_CHECKING, Literal, overload - -from discord import utils -from discord.file import File +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload + from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler -from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, M4ASinkError, MaxProcessesCountReached, NoUserAudio +from .core import FFmpegSink if TYPE_CHECKING: from typing_extensions import Self - from discord import abc - -_log = logging.getLogger(__name__) + from ..voice.packets import VoiceData __all__ = ("M4ASink",) -class M4ASink(Sink): +class M4ASink(FFmpegSink): """A special sink for .m4a files. .. 
versionadded:: 2.0 Parameters ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. - max_audio_processes_count: :class:`int` - The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then - when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. versionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error occurs with this sink. + + ..
versionadded:: 2.7 """ - def __init__( - self, - *, - filters: list[SinkFilter[Self]] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler[Self]] = MISSING, - max_audio_processes_count: int = 10, - ) -> None: - self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque( - maxlen=max_audio_processes_count - ) - super().__init__( - filters=filters, - filtering_mode=filtering_mode, - handlers=handlers, - ) - - def get_user_audio(self, user_id: int) -> io.BytesIO | None: - """Gets a user's saved audio data, or ``None``.""" - return self.__audio_data.get(user_id) - - def _create_audio_packet_for(self, uid: int) -> io.BytesIO: - data = self.__audio_data[uid] = io.BytesIO() - return data - @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str, executable: str = ..., - as_file: Literal[True], - ) -> File: ... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + buffer: IO[bytes], executable: str = ..., - as_file: Literal[False] = ..., - ) -> io.BytesIO: ... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, executable: str = "ffmpeg", - as_file: bool = False, - ) -> io.BytesIO | File: - """Formats a user's saved audio data. - - This should be called after the bot has stopped recording. - - If this is called during recording, there could be missing audio - packets. - - After this, the user's audio data will be resetted to 0 bytes and - seeked to 0. 
- - Parameters - ---------- - user_id: :class:`int` - The user ID of which format the audio data into a file. - executable: :class:`str` - The FFmpeg executable path to use for this formatting. It defaults - to ``ffmpeg``. - as_file: :class:`bool` - Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. - - Returns - ------- - Union[:class:`io.BytesIO`, :class:`~discord.File`] - The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` - object with the buffer set as the audio bytes. - - Raises - ------ - NoUserAudio - You tried to format the audio of a user that was not stored in this sink. - FFmpegNotFound - The provided FFmpeg executable was not found. - MaxProcessesCountReached - You tried to go over the maximum processes count threshold. - M4ASinkError - Any error raised while formatting, wrapped around M4ASinkError. - """ - - if len(self.__process_queue) >= 10: - raise MaxProcessesCountReached - - try: - data = self.__audio_data.pop(user_id) - except KeyError: - _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAudio - - temp_path = f"{user_id}-{time.time()}-recording.m4a.tmp" - args = [ - executable, - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "ipod", - temp_path, - ] - - if os.path.exists(temp_path): - found = utils.find(lambda d: d[0] == temp_path, self.__process_queue) - if found: - _, old_process = found - old_process.kill() - _log.info( - "Killing old process (%s) to write in %s", old_process, temp_path - ) - - os.remove( - temp_path - ) # process would get stuck asking whether to overwrite, if file already exists. 
- - try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE - ) - self.__process_queue.append((temp_path, process)) - except FileNotFoundError as exc: - raise FFmpegNotFound from exc - except subprocess.SubprocessError as exc: - raise M4ASinkError(f"Audio formatting for user {user_id} failed") from exc - - process.communicate(data.read()) - - with open(temp_path, "rb") as file: - buffer = io.BytesIO(file.read()) - buffer.seek(0) - - try: - self.__process_queue.remove((temp_path, process)) - except ValueError: - pass - - if as_file: - return File(buffer, filename=f"{user_id}-{time.time()}-recording.m4a") - return buffer - - def _clean_process(self, path: str, process: subprocess.Popen) -> None: - _log.debug( - "Cleaning process %s for sink %s (with temporary file at %s)", - process, - self, - path, - ) - process.kill() - if os.path.exists(path): - os.remove(path) - - def cleanup(self) -> None: - for path, process in self.__process_queue: - self._clean_process(path, process) - self.__process_queue.clear() - - for _, buffer in self.__audio_data.items(): - if not buffer.closed: - buffer.close() - - self.__audio_data.clear() - super().cleanup() - - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) - buffer.write(data.decoded_data) + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f ipod -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index 3018bf84e4..04e13e82b6 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -24,208 +24,95 @@ from __future__ import annotations -import io -import logging 
-import subprocess -import time -from collections import deque -from typing import TYPE_CHECKING, Literal, overload - -from discord.file import File +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload + from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler -from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, MKASinkError, NoUserAudio +from .core import FFmpegSink if TYPE_CHECKING: from typing_extensions import Self - from discord import abc - -_log = logging.getLogger(__name__) + from discord.voice import VoiceData __all__ = ("MKASink",) -class MKASink(Sink): +class MKASink(FFmpegSink): """A special sink for .mka files. .. versionadded:: 2.0 Parameters ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. - max_audio_processes_count: :class:`int` - The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then - when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. 
+ This can't be mixed with ``filename``. + + .. versionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error occurs with this sink. + + .. versionadded:: 2.7 """ - def __init__( - self, - *, - filters: list[SinkFilter[Self]] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler[Self]] = MISSING, - max_audio_processes_count: int = 10, - ) -> None: - self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[subprocess.Popen] = deque( - maxlen=max_audio_processes_count - ) - super().__init__( - filters=filters, - filtering_mode=filtering_mode, - handlers=handlers, - ) - - def get_user_audio(self, user_id: int) -> io.BytesIO | None: - """Gets a user's saved audio data, or ``None``.""" - return self.__audio_data.get(user_id) - - def _create_audio_packet_for(self, uid: int) -> io.BytesIO: - data = self.__audio_data[uid] = io.BytesIO() - return data - @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str, executable: str = ..., - as_file: Literal[True], - ) -> File: ... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + buffer: IO[bytes], executable: str = ..., - as_file: Literal[False] = ..., - ) -> io.BytesIO: ...
+ stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, executable: str = "ffmpeg", - as_file: bool = False, - ) -> io.BytesIO | File: - """Formats a user's saved audio data. - - This should be called after the bot has stopped recording. - - If this is called during recording, there could be missing audio - packets. - - After this, the user's audio data will be resetted to 0 bytes and - seeked to 0. - - Parameters - ---------- - user_id: :class:`int` - The user ID of which format the audio data into a file. - executable: :class:`str` - The FFmpeg executable path to use for this formatting. It defaults - to ``ffmpeg``. - as_file: :class:`bool` - Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. - - Returns - ------- - Union[:class:`io.BytesIO`, :class:`~discord.File`] - The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` - object with the buffer set as the audio bytes. - - Raises - ------ - NoUserAudio - You tried to format the audio of a user that was not stored in this sink. - FFmpegNotFound - The provided FFmpeg executable was not found. - MaxProcessesCountReached - You tried to go over the maximum processes count threshold. - MKASinkError - Any error raised while formatting, wrapped around MKASinkError. 
- """ - - if len(self.__process_queue) >= 10: - raise MaxProcessesCountReached - - try: - data = self.__audio_data.pop(user_id) - except KeyError: - _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAudio - - args = [ - executable, - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "matroska", - "pipe:1", - ] - - try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - self.__process_queue.append(process) - except FileNotFoundError as exc: - raise FFmpegNotFound from exc - except subprocess.SubprocessError as exc: - raise MKASinkError(f"Audio formatting for user {user_id} failed") from exc - - out = process.communicate(data.read())[0] - buffer = io.BytesIO(out) - buffer.seek(0) - - try: - self.__process_queue.remove(process) - except ValueError: - pass - - if as_file: - return File(buffer, filename=f"{user_id}-{time.time()}-recording.mka") - return buffer - - def _clean_process(self, process: subprocess.Popen) -> None: - _log.debug("Cleaning process %s for sink %s", process, self) - process.kill() - - def cleanup(self) -> None: - for process in self.__process_queue: - self._clean_process(process) - self.__process_queue.clear() - - for _, buffer in self.__audio_data.items(): - if not buffer.closed: - buffer.close() - - self.__audio_data.clear() - super().cleanup() - - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) - buffer.write(data.decoded_data) + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f matroska -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + 
) # type: ignore diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index 5dda65f16c..e458ef29c4 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -24,205 +24,95 @@ from __future__ import annotations -import io -import logging -import subprocess -import time -from collections import deque -from typing import TYPE_CHECKING, Literal, overload - -from discord.file import File +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload + from discord.utils import MISSING -from .core import RawData, Sink, SinkFilter, SinkHandler -from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, MKVSinkError, NoUserAudio +from .core import FFmpegSink if TYPE_CHECKING: from typing_extensions import Self - from discord import abc - -_log = logging.getLogger(__name__) + from discord.voice import VoiceData __all__ = ("MKVSink",) -class MKVSink(Sink): +class MKVSink(FFmpegSink): """A special sink for .mkv files. .. versionadded:: 2.0 Parameters ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. - max_audio_processes_count: :class:`int` - The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then - when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. 
+ filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. versionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error occurs with this sink. + + .. versionadded:: 2.7 """ - def __init__( - self, - *, - filters: list[SinkFilter[Self]] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler[Self]] = MISSING, - max_audio_processes_count: int = 10, - ) -> None: - self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[subprocess.Popen] = deque( - maxlen=max_audio_processes_count - ) - super().__init__( - filters=filters, - filtering_mode=filtering_mode, - handlers=handlers, - ) - - def get_user_audio(self, user_id: int) -> io.BytesIO | None: - """Gets a user's saved audio data, or ``None``.""" - return self.__audio_data.get(user_id) - - def _create_audio_packet_for(self, uid: int) -> io.BytesIO: - data = self.__audio_data[uid] = io.BytesIO() - return data - @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str, executable: str = ..., - as_file: Literal[True], - ) -> File: ...
+ stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + buffer: IO[bytes], executable: str = ..., - as_file: Literal[False] = ..., - ) -> io.BytesIO: ... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, executable: str = "ffmpeg", - as_file: bool = False, - ) -> io.BytesIO | File: - """Formats a user's saved audio data. - - This should be called after the bot has stopped recording. - - If this is called during recording, there could be missing audio - packets. - - After this, the user's audio data will be resetted to 0 bytes and - seeked to 0. - - Parameters - ---------- - user_id: :class:`int` - The user ID of which format the audio data into a file. - executable: :class:`str` - The FFmpeg executable path to use for this formatting. It defaults - to ``ffmpeg``. - as_file: :class:`bool` - Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. - - Returns - ------- - Union[:class:`io.BytesIO`, :class:`~discord.File`] - The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` - object with the buffer set as the audio bytes. - - Raises - ------ - NoUserAudio - You tried to format the audio of a user that was not stored in this sink. - FFmpegNotFound - The provided FFmpeg executable was not found. - MaxProcessesCountReached - You tried to go over the maximum processes count threshold. - MKVSinkError - Any error raised while formatting, wrapped around MKVSinkError. 
- """ - - if len(self.__process_queue) >= 10: - raise MaxProcessesCountReached - - try: - data = self.__audio_data.pop(user_id) - except KeyError: - _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAudio - - args = [ - executable, - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "matroska", - "pipe:1", - ] - - try: - process = subprocess.Popen( - args, stdin=subprocess.PIPE, stdout=subprocess.PIPE - ) - self.__process_queue.append(process) - except FileNotFoundError as exc: - raise FFmpegNotFound from exc - except subprocess.SubprocessError as exc: - raise MKVSinkError(f"Audio formatting for user {user_id} failed") from exc - - out = process.communicate(data.read())[0] - buffer = io.BytesIO(out) - buffer.seek(0) - - try: - self.__process_queue.remove(process) - except ValueError: - pass - - if as_file: - return File(buffer, filename=f"{user_id}-{time.time()}-recording.mkv") - return buffer - - def _clean_process(self, process: subprocess.Popen) -> None: - _log.debug("Cleaning process %s for sink %s", process, self) - process.kill() - - def cleanup(self) -> None: - for process in self.__process_queue: - self._clean_process(process) - self.__process_queue.clear() - - for _, buffer in self.__audio_data.items(): - if not buffer.closed: - buffer.close() - - self.__audio_data.clear() - super().cleanup() - - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) - buffer.write(data.decoded_data) + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f matroska -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git 
a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 3ce023cbd7..9cc930e07e 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -24,211 +24,95 @@ from __future__ import annotations -import io -import logging -import subprocess -import time -from collections import deque -from typing import TYPE_CHECKING, Literal, overload - -from discord.file import File +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload + from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler -from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, MP3SinkError, NoUserAudio +from .core import FFmpegSink if TYPE_CHECKING: from typing_extensions import Self - from discord import abc - -_log = logging.getLogger(__name__) + from discord.voice import VoiceData __all__ = ("MP3Sink",) -class MP3Sink(Sink): +class MP3Sink(FFmpegSink): """A special sink for .mp3 files. .. versionadded:: 2.0 Parameters ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. - max_audio_processes_count: :class:`int` - The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then - when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. 
+ filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. versionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error occurs with this sink. + + .. versionadded:: 2.7 """ - def __init__( - self, - *, - filters: list[SinkFilter[Self]] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler[Self]] = MISSING, - max_audio_processes_count: int = 10, - ) -> None: - self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[subprocess.Popen] = deque( - maxlen=max_audio_processes_count - ) - super().__init__( - filters=filters, - filtering_mode=filtering_mode, - handlers=handlers, - ) - - def get_user_audio(self, user_id: int) -> io.BytesIO | None: - """Gets a user's saved audio data, or ``None``.""" - ret = self.__audio_data.get(user_id) - _log.debug("Found stored user ID %s with buffer %s", user_id, ret) - return ret - - def _create_audio_packet_for(self, uid: int) -> io.BytesIO: - data = self.__audio_data[uid] = io.BytesIO() - _log.debug("Created user ID %s buffer", uid) - return data - @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str, executable: str = ..., - as_file: Literal[True], - ) -> File:
... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + buffer: IO[bytes], executable: str = ..., - as_file: Literal[False] = ..., - ) -> io.BytesIO: ... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, executable: str = "ffmpeg", - as_file: bool = False, - ) -> io.BytesIO | File: - """Formats a user's saved audio data. - - This should be called after the bot has stopped recording. - - If this is called during recording, there could be missing audio - packets. - - After this, the user's audio data will be resetted to 0 bytes and - seeked to 0. - - Parameters - ---------- - user_id: :class:`int` - The user ID of which format the audio data into a file. - executable: :class:`str` - The FFmpeg executable path to use for this formatting. It defaults - to ``ffmpeg``. - as_file: :class:`bool` - Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. - - Returns - ------- - Union[:class:`io.BytesIO`, :class:`~discord.File`] - The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` - object with the buffer set as the audio bytes. - - Raises - ------ - NoUserAudio - You tried to format the audio of a user that was not stored in this sink. - FFmpegNotFound - The provided FFmpeg executable was not found. - MaxProcessesCountReached - You tried to go over the maximum processes count threshold. - MP3SinkError - Any error raised while formatting, wrapped around MP3SinkError. 
- """ - - if len(self.__process_queue) >= 10: - raise MaxProcessesCountReached - - try: - data = self.__audio_data.pop(user_id) - except KeyError: - _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAudio - - args = [ - executable, - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "mp3", - "pipe:1", - ] - - try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - self.__process_queue.append(process) - except FileNotFoundError as exc: - raise FFmpegNotFound from exc - except subprocess.SubprocessError as exc: - raise MP3SinkError(f"Audio formatting for user {user_id} failed") from exc - - out = process.communicate(data.read())[0] - buffer = io.BytesIO(out) - buffer.seek(0) - - try: - self.__process_queue.remove(process) - except ValueError: - pass - - if as_file: - return File(buffer, filename=f"{user_id}-{time.time()}-recording.mp3") - return buffer - - def _clean_process(self, process: subprocess.Popen) -> None: - _log.debug("Cleaning process %s for sink %s", process, self) - process.kill() - - def cleanup(self) -> None: - for process in self.__process_queue: - self._clean_process(process) - self.__process_queue.clear() - - for _, buffer in self.__audio_data.items(): - if not buffer.closed: - buffer.close() - - self.__audio_data.clear() - super().cleanup() - - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) - buffer.write(data.decoded_data) + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f mp3 -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: 
ignore diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index 29a47ddece..e3ca273d63 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -24,230 +24,95 @@ from __future__ import annotations -import io -import logging -import os -import subprocess -import time -from collections import deque -from typing import TYPE_CHECKING, Literal, overload - -from discord import utils -from discord.file import File +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload + from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler -from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, MP4SinkError, NoUserAudio +from .core import FFmpegSink if TYPE_CHECKING: from typing_extensions import Self - from discord import abc - -_log = logging.getLogger(__name__) + from discord.voice import VoiceData __all__ = ("MP4Sink",) -class MP4Sink(Sink): +class MP4Sink(FFmpegSink): """A special sink for .mp4 files. .. versionadded:: 2.0 Parameters ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. - max_audio_processes_count: :class:`int` - The maximum of audio conversion processes that can be active concurrently. 
If this limit is exceeded, then - when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. + This can't be mixed with ``filename``. + + .. versionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error occurs with this sink. + + ..
versionadded:: 2.7 """ - def __init__( - self, - *, - filters: list[SinkFilter[Self]] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler[Self]] = MISSING, - max_audio_processes_count: int = 10, - ) -> None: - self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[tuple[str, subprocess.Popen]] = deque( - maxlen=max_audio_processes_count - ) - super().__init__( - filters=filters, - filtering_mode=filtering_mode, - handlers=handlers, - ) - - def get_user_audio(self, user_id: int) -> io.BytesIO | None: - """Gets a user's saved audio data, or ``None``.""" - return self.__audio_data.get(user_id) - - def _create_audio_packet_for(self, uid: int) -> io.BytesIO: - data = self.__audio_data[uid] = io.BytesIO() - return data - @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str, executable: str = ..., - as_file: Literal[True], - ) -> File: ... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + buffer: IO[bytes], executable: str = ..., - as_file: Literal[False] = ..., - ) -> io.BytesIO: ... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, executable: str = "ffmpeg", - as_file: bool = False, - ) -> io.BytesIO | File: - """Formats a user's saved audio data. - - This should be called after the bot has stopped recording. - - If this is called during recording, there could be missing audio - packets. - - After this, the user's audio data will be resetted to 0 bytes and - seeked to 0. 
- - Parameters - ---------- - user_id: :class:`int` - The user ID of which format the audio data into a file. - executable: :class:`str` - The FFmpeg executable path to use for this formatting. It defaults - to ``ffmpeg``. - as_file: :class:`bool` - Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. - - Returns - ------- - Union[:class:`io.BytesIO`, :class:`~discord.File`] - The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` - object with the buffer set as the audio bytes. - - Raises - ------ - NoUserAudio - You tried to format the audio of a user that was not stored in this sink. - FFmpegNotFound - The provided FFmpeg executable was not found. - MaxProcessesCountReached - You tried to go over the maximum processes count threshold. - MP4SinkError - Any error raised while formatting, wrapped around MP4SinkError. - """ - - if len(self.__process_queue) >= 10: - raise MaxProcessesCountReached - - try: - data = self.__audio_data.pop(user_id) - except KeyError: - _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAudio - - temp_path = f"{user_id}-{time.time()}-recording.mp4.tmp" - args = [ - executable, - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "mp4", - temp_path, - ] - - if os.path.exists(temp_path): - found = utils.find(lambda d: d[0] == temp_path, self.__process_queue) - if found: - _, old_process = found - old_process.kill() - _log.info( - "Killing old process (%s) to write in %s", old_process, temp_path - ) - - os.remove( - temp_path - ) # process would get stuck asking whether to overwrite, if file already exists. 
- - try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE - ) - self.__process_queue.append((temp_path, process)) - except FileNotFoundError as exc: - raise FFmpegNotFound from exc - except subprocess.SubprocessError as exc: - raise MP4SinkError(f"Audio formatting for user {user_id} failed") from exc - - process.communicate(data.read()) - - with open(temp_path, "rb") as file: - buffer = io.BytesIO(file.read()) - buffer.seek(0) - - try: - self.__process_queue.remove((temp_path, process)) - except ValueError: - pass - - if as_file: - return File(buffer, filename=f"{user_id}-{time.time()}-recording.mp4") - return buffer - - def _clean_process(self, path: str, process: subprocess.Popen) -> None: - _log.debug( - "Cleaning process %s for sink %s (with temporary file at %s)", - process, - self, - path, - ) - process.kill() - if os.path.exists(path): - os.remove(path) - - def cleanup(self) -> None: - for path, process in self.__process_queue: - self._clean_process(path, process) - self.__process_queue.clear() - - for _, buffer in self.__audio_data.items(): - if not buffer.closed: - buffer.close() - - self.__audio_data.clear() - super().cleanup() - - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) - buffer.write(data.decoded_data) + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f mp4 -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: ignore diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 8075aa3a99..da279055ca 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -24,208 +24,95 @@ from __future__ import annotations -import io -import logging 
-import subprocess -import time -from collections import deque -from typing import TYPE_CHECKING, Literal, overload - -from discord.file import File +from collections.abc import Callable +from typing import IO, TYPE_CHECKING, Any, overload + from discord.utils import MISSING -from .core import CREATE_NO_WINDOW, RawData, Sink, SinkFilter, SinkHandler -from .enums import SinkFilteringMode -from .errors import FFmpegNotFound, MaxProcessesCountReached, NoUserAudio, OGGSinkError +from .core import FFmpegSink if TYPE_CHECKING: from typing_extensions import Self - from discord import abc - -_log = logging.getLogger(__name__) + from discord.voice import VoiceData __all__ = ("OGGSink",) -class OGGSink(Sink): +class OGGSink(FFmpegSink): """A special sink for .ogg files. .. versionadded:: 2.0 Parameters ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. - max_audio_processes_count: :class:`int` - The maximum of audio conversion processes that can be active concurrently. If this limit is exceeded, then - when calling methods like :meth:`.format_user_audio` they will raise :exc:`MaxProcessesCountReached`. + filename: :class:`str` + The file in which the recording will be saved into. + This can't be mixed with ``buffer``. + + .. versionadded:: 2.7 + buffer: IO[:class:`bytes`] + The buffer in which the recording will be saved into. 
+ This can't be mixed with ``filename``. + + .. versionadded:: 2.7 + executable: :class:`str` + The executable in which ``ffmpeg`` is in. + + .. versionadded:: 2.7 + stderr: IO[:class:`bytes`] | :data:`None` + The stderr buffer in which will be written. Defaults to ``None``. + + .. versionadded:: 2.7 + options: :class:`str` | :data:`None` + The options to append to the ffmpeg executable flags. You should not + use this because you may override any already-provided flag. + + .. versionadded:: 2.7 + error_hook: Callable[[:class:`FFmpegSink`, :class:`Exception`, :class:`discord.voice.VoiceData` | :data:`None`], Any] | :data:`None` + The callback to call when an error occurs with this sink. + + .. versionadded:: 2.7 """ - def __init__( - self, - *, - filters: list[SinkFilter[Self]] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler[Self]] = MISSING, - max_audio_processes_count: int = 10, - ) -> None: - self.__audio_data: dict[int, io.BytesIO] = {} - self.__process_queue: deque[subprocess.Popen] = deque( - maxlen=max_audio_processes_count - ) - super().__init__( - filters=filters, - filtering_mode=filtering_mode, - handlers=handlers, - ) - - def get_user_audio(self, user_id: int) -> io.BytesIO | None: - """Gets a user's saved audio data, or ``None``.""" - return self.__audio_data.get(user_id) - - def _create_audio_packet_for(self, uid: int) -> io.BytesIO: - data = self.__audio_data[uid] = io.BytesIO() - return data - @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str, executable: str = ..., - as_file: Literal[True], - ) -> File: ... + stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... @overload - def format_user_audio( + def __init__( self, - user_id: int, *, + buffer: IO[bytes], executable: str = ..., - as_file: Literal[False] = ..., - ) -> io.BytesIO: ...
+ stderr: IO[bytes] | None = ..., + options: str | None = ..., + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = ..., + ) -> None: ... - def format_user_audio( + def __init__( self, - user_id: int, *, + filename: str = MISSING, + buffer: IO[bytes] = MISSING, executable: str = "ffmpeg", - as_file: bool = False, - ) -> io.BytesIO | File: - """Formats a user's saved audio data. - - This should be called after the bot has stopped recording. - - If this is called during recording, there could be missing audio - packets. - - After this, the user's audio data will be resetted to 0 bytes and - seeked to 0. - - Parameters - ---------- - user_id: :class:`int` - The user ID of which format the audio data into a file. - executable: :class:`str` - The FFmpeg executable path to use for this formatting. It defaults - to ``ffmpeg``. - as_file: :class:`bool` - Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. - - Returns - ------- - Union[:class:`io.BytesIO`, :class:`~discord.File`] - The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` - object with the buffer set as the audio bytes. - - Raises - ------ - NoUserAudio - You tried to format the audio of a user that was not stored in this sink. - FFmpegNotFound - The provided FFmpeg executable was not found. - MaxProcessesCountReached - You tried to go over the maximum processes count threshold. - OGGSinkError - Any error raised while formatting, wrapped around OGGSinkError. 
- """ - - if len(self.__process_queue) >= 10: - raise MaxProcessesCountReached - - try: - data = self.__audio_data.pop(user_id) - except KeyError: - _log.info("There is no audio data for %s, ignoring.", user_id) - raise NoUserAudio - - args = [ - executable, - "-f", - "s16le", - "-ar", - "48000", - "-loglevel", - "error", - "-ac", - "2", - "-i", - "-", - "-f", - "ogg", - "pipe:1", - ] - - try: - process = subprocess.Popen( - args, - creationflags=CREATE_NO_WINDOW, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - self.__process_queue.append(process) - except FileNotFoundError as exc: - raise FFmpegNotFound from exc - except subprocess.SubprocessError as exc: - raise OGGSinkError(f"Audio formatting for user {user_id} failed") from exc - - out = process.communicate(data.read())[0] - buffer = io.BytesIO(out) - buffer.seek(0) - - try: - self.__process_queue.remove(process) - except ValueError: - pass - - if as_file: - return File(buffer, filename=f"{user_id}-{time.time()}-recording.ogg") - return buffer - - def _clean_process(self, process: subprocess.Popen) -> None: - _log.debug("Cleaning process %s for sink %s", process, self) - process.kill() - - def cleanup(self) -> None: - for process in self.__process_queue: - self._clean_process(process) - self.__process_queue.clear() - - for _, buffer in self.__audio_data.items(): - if not buffer.closed: - buffer.close() - - self.__audio_data.clear() - super().cleanup() - - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) - buffer.write(data.decoded_data) + stderr: IO[bytes] | None = None, + options: str | None = None, + error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, + ) -> None: + super().__init__( + executable=executable, + before_options="-f ogg -loglevel error", + filename=filename, + buffer=buffer, + stderr=stderr, + options=options, + error_hook=error_hook, + ) # type: 
ignore diff --git a/discord/sinks/pcm.py b/discord/sinks/pcm.py index ddc156b173..5ae17b2df4 100644 --- a/discord/sinks/pcm.py +++ b/discord/sinks/pcm.py @@ -25,19 +25,13 @@ from __future__ import annotations import io -from typing import TYPE_CHECKING, Literal, overload +from typing import TYPE_CHECKING -from discord.file import File -from discord.utils import MISSING - -from .core import RawData, Sink, SinkFilter, SinkHandler -from .enums import SinkFilteringMode -from .errors import NoUserAudio +from .core import Sink if TYPE_CHECKING: - from typing_extensions import Self - from discord import abc + from discord.voice import VoiceData __all__ = ("PCMSink",) @@ -46,112 +40,12 @@ class PCMSink(Sink): """A special sink for .pcm files. .. versionadded:: 2.0 - - Parameters - ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. 
""" - def __init__( - self, - *, - filters: list[SinkFilter[Self]] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler[Self]] = MISSING, - ) -> None: - self.__audio_data: dict[int, io.BytesIO] = {} - super().__init__( - filters=filters, - filtering_mode=filtering_mode, - handlers=handlers, - ) - - def get_user_audio(self, user_id: int) -> io.BytesIO | None: - """Gets a user's saved audiop data, or ``None``.""" - return self.__audio_data.get(user_id) - - def _create_audio_packet_for(self, uid: int) -> io.BytesIO: - data = self.__audio_data[uid] = io.BytesIO() - return data - - @overload - def format_user_audio( - self, - user_id: int, - *, - as_file: Literal[True], - ) -> File: ... - - @overload - def format_user_audio( - self, - user_id: int, - *, - as_file: Literal[False] = ..., - ) -> io.BytesIO: ... - - def format_user_audio( - self, - user_id: int, - *, - as_file: bool = False, - ) -> io.BytesIO | File: - """Formats a user's saved audio data. - - This should be called after the bot has stopped recording. - - If this is called during recording, there could be missing audio - packets. - - After this, the user's audio data will be resetted to 0 bytes and - seeked to 0. - - Parameters - ---------- - user_id: :class:`int` - The user ID of which format the audio data into a file. - as_file: :class:`bool` - Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. - - Returns - ------- - Union[:class:`io.BytesIO`, :class:`~discord.File`] - The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` - object with the buffer set as the audio bytes. - - Raises - ------ - NoUserAudio - You tried to format the audio of a user that was not stored in this sink. 
- """ - - try: - data = self.__audio_data.pop(user_id) - except KeyError: - raise NoUserAudio - - data.seek(0) - - if as_file: - return File(data, filename=f"{user_id}-recording.pcm") - return data - - def cleanup(self) -> None: - for _, buffer in self.__audio_data.items(): - if not buffer.closed: - buffer.close() + def __init__(self) -> None: + super().__init__(dest=None) - self.__audio_data.clear() - super().cleanup() + self.buffer: io.BytesIO = io.BytesIO() - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) - buffer.write(data.decoded_data) + def write(self, user: abc.User | None, data: VoiceData) -> None: + self.buffer.write(data.pcm) diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index ea588a6fce..c2a1ee3299 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -25,20 +25,19 @@ from __future__ import annotations import io +import logging import wave -from typing import TYPE_CHECKING, Literal, overload +from typing import TYPE_CHECKING +from discord.opus import Decoder from discord.file import File -from discord.utils import MISSING - -from .core import RawData, Sink, SinkFilter, SinkHandler -from .enums import SinkFilteringMode -from .errors import NoUserAudio +from .core import Sink if TYPE_CHECKING: - from typing_extensions import Self - from discord import abc + from discord.voice import VoiceData + +_log = logging.getLogger(__name__) __all__ = ( "WaveSink", @@ -53,119 +52,66 @@ class WaveSink(Sink): Parameters ---------- - filters: List[:class:`~.SinkFilter`] - The filters to apply to this sink recorder. - filtering_mode: :class:`~.SinkFilteringMode` - How the filters should work. 
If set to :attr:`~.SinkFilteringMode.all`, all filters must go through - in order for an audio packet to be stored in this sink, else if it is set to :attr:`~.SinkFilteringMode.any`, - only one filter is required to return ``True`` in order for an audio packet to be stored in this sink. - handlers: List[:class:`~.SinkHandler`] - The sink handlers. Handlers are objects that are called after filtering, and that can be used to, for example - store a certain packet data in a file, or local mapping. + destination: :class:`str` | :term:`py:bytes-like object` + The destination in which the data should be saved into. + + If this is a filename, then it is saved into that. Else, treats + it like a buffer. + + .. versionadded:: 2.7 + channels: :class:`int` + The amount of channels. + + .. versionadded:: 2.7 + sample_width: :class:`int` + The sample width to "n" bytes. + + .. versionadded:: 2.7 + sampling_rate: :class:`int` + The frame rate. A non-integral input is rounded to the nearest int. + + .. versionadded:: 2.7 """ def __init__( self, + destination: wave._File, *, - filters: list[SinkFilter[Self]] = MISSING, - filtering_mode: SinkFilteringMode = SinkFilteringMode.all, - handlers: list[SinkHandler[Self]] = MISSING, + channels: int = Decoder.CHANNELS, + sample_width: int = Decoder.SAMPLE_SIZE // Decoder.CHANNELS, + sampling_rate: int = Decoder.SAMPLING_RATE, ) -> None: - self.__audio_data: dict[int, io.BytesIO] = {} - super().__init__( - filters=filters, - filtering_mode=filtering_mode, - handlers=handlers, - ) - - def get_user_audio(self, user_id: int) -> io.BytesIO | None: - """Gets a user's saved audiop data, or ``None``.""" - return self.__audio_data.get(user_id) - - def _create_audio_packet_for(self, uid: int) -> io.BytesIO: - data = self.__audio_data[uid] = io.BytesIO() - return data - - @overload - def format_user_audio( - self, - user_id: int, - *, - as_file: Literal[True], - ) -> File: ... 
+ super().__init__() - @overload - def format_user_audio( - self, - user_id: int, - *, - as_file: Literal[False] = ..., - ) -> io.BytesIO: ... + self._destination: wave._File = destination + self._file: wave.Wave_write = wave.open(destination, "wb") + self._file.setnchannels(channels) + self._file.setsampwidth(sample_width) + self._file.setframerate(sampling_rate) - def format_user_audio( - self, - user_id: int, - *, - as_file: bool = False, - ) -> io.BytesIO | File: - """Formats a user's saved audio data. - - This should be called after the bot has stopped recording. - - If this is called during recording, there could be missing audio - packets. - - After this, the user's audio data will be resetted to 0 bytes and - seeked to 0. - - Parameters - ---------- - user_id: :class:`int` - The user ID of which format the audio data into a file. - as_file: :class:`bool` - Whether to return a :class:`~discord.File` object instead of a :class:`io.BytesIO`. - - Returns - ------- - Union[:class:`io.BytesIO`, :class:`~discord.File`] - The user's audio saved bytes, if ``as_file`` is ``False``, else a :class:`~discord.File` - object with the buffer set as the audio bytes. - - Raises - ------ - NoUserAudio - You tried to format the audio of a user that was not stored in this sink. - """ + def is_opus(self) -> bool: + return False - try: - data = self.__audio_data.pop(user_id) - except KeyError: - raise NoUserAudio + def write(self, user: abc.User | None, data: VoiceData) -> None: + self._file.writeframes(data.pcm) - decoder = self.client.decoder + def to_file(self, filename: str, /, *, description: str | None = None, spoiler: bool = False) -> File | None: + """Returns the :class:`discord.File` of this sink. - with wave.open(data, "wb") as f: - f.setnchannels(decoder.CHANNELS) - f.setsampwidth(decoder.SAMPLE_SIZE // decoder.CHANNELS) - f.setframerate(decoder.SAMPLING_RATE) + .. warning:: - data.seek(0) + This should be used only after the sink has stopped recording. 
+ """ - if as_file: - return File(data, filename=f"{user_id}-recording.pcm") - return data + f = wave.open(self._destination, "rb") + data = f.readframes(f.getnframes()) + return File(io.BytesIO(data), filename, description=description, spoiler=spoiler) def cleanup(self) -> None: - for _, buffer in self.__audio_data.items(): - if not buffer.closed: - buffer.close() - - self.__audio_data.clear() - super().cleanup() - - async def on_voice_packet_receive(self, user: abc.Snowflake, data: RawData) -> None: - buffer = self.get_user_audio(user.id) or self._create_audio_packet_for(user.id) - buffer.write(data.decoded_data) + try: + self._file.close() + except Exception as exc: + _log.warning("An error ocurred while closing the wave writing file on cleanup", exc_info=exc) WavSink = WaveSink diff --git a/discord/voice/packets/__init__.py b/discord/voice/packets/__init__.py index d55c2a3312..454339d0bd 100644 --- a/discord/voice/packets/__init__.py +++ b/discord/voice/packets/__init__.py @@ -45,3 +45,7 @@ def __init__(self, packet: Packet, source: User | Member | None, *, pcm: bytes | self.packet: Packet = packet self.source: User | Member | None = source self.pcm: bytes = pcm if pcm else b'' + + @property + def opus(self) -> bytes | None: + self.packet.decrypted_data From d44f192befd8b6d3e0e3aabeb2f571d449538b4e Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:27:17 +0200 Subject: [PATCH 37/40] docs actions --- .github/workflows/docs-json-export.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-json-export.yml b/.github/workflows/docs-json-export.yml index eb46d1b1b3..618df46fc7 100644 --- a/.github/workflows/docs-json-export.yml +++ b/.github/workflows/docs-json-export.yml @@ -26,7 +26,7 @@ jobs: id: install-deps run: | python -m pip install -U pip - pip install ".[docs]" + pip install ".[docs,voice]" pip install beautifulsoup4 - name: Build Sphinx HTML docs id: 
build-sphinx From 135550fd59f95858268cb85de95fb315f380d55f Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:53:06 +0200 Subject: [PATCH 38/40] errors yay --- discord/_voice_aliases.py | 6 ++-- discord/opus.py | 4 +-- discord/sinks/core.py | 22 ++++++------ discord/sinks/mp3.py | 1 + discord/voice/client.py | 64 +++++++++++++++++++++++++++++++++++ discord/voice/packets/core.py | 7 ---- 6 files changed, 81 insertions(+), 23 deletions(-) diff --git a/discord/_voice_aliases.py b/discord/_voice_aliases.py index 034ae1849a..663407d70c 100644 --- a/discord/_voice_aliases.py +++ b/discord/_voice_aliases.py @@ -26,7 +26,7 @@ from typing import TYPE_CHECKING -from .utils import warn_deprecated +from .utils import deprecated """ since discord.voice raises an error when importing it without having the @@ -60,12 +60,12 @@ def VoiceClient(client, channel) -> VoiceClientC: def VoiceProtocol(client, channel) -> VoiceProtocolC: ... else: - @warn_deprecated("discord.VoiceClient", "discord.voice.VoiceClient", "2.7", "3.0") + @deprecated("discord.VoiceClient", "discord.voice.VoiceClient", "2.7", "3.0") def VoiceClient(client, channel): from discord.voice import VoiceClient return VoiceClient(client, channel) - @warn_deprecated("discord.VoiceProtocol", "discord.voice.VoiceProtocol", "2.7", "3.0") + @deprecated("discord.VoiceProtocol", "discord.voice.VoiceProtocol", "2.7", "3.0") def VoiceProtocol(client, channel): from discord.voice import VoiceProtocol return VoiceProtocol(client, channel) diff --git a/discord/opus.py b/discord/opus.py index f3a3ed8986..345449ad51 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -28,7 +28,6 @@ import array import ctypes import ctypes.util -import gc import logging import math import os.path @@ -37,13 +36,12 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar from discord.voice.packets.rtp import FakePacket -from discord.voice.utils.wrapped import gap_wrapped, 
add_wrapped +from discord.voice.utils.wrapped import add_wrapped from discord.voice.utils.buffer import JitterBuffer import davey from .errors import DiscordException -from .sinks import RawData if TYPE_CHECKING: from discord.user import User diff --git a/discord/sinks/core.py b/discord/sinks/core.py index 949fefc25f..2676200e57 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -71,15 +71,9 @@ _log = logging.getLogger(__name__) -class SinkBase: - """Represents an audio sink in which user's audios are stored. - """ - - __sink_listeners__: list[tuple[str, str]] - _client: VoiceClient | None - - def __new__(cls) -> Self: +class SinkMeta(type): + def __new__(cls, *args, **kwargs): listeners = {} for base in reversed(cls.__mro__): @@ -102,7 +96,15 @@ def __new__(cls) -> Self: listeners_list.append((listener_name, listener.__name__)) cls.__sink_listeners__ = listeners_list - return super().__new__(cls) + return super().__new__(cls, *args, **kwargs) + + +class SinkBase(metaclass=SinkMeta): + """Represents an audio sink in which user's audios are stored. 
+ """ + + __sink_listeners__: list[tuple[str, str]] + _client: VoiceClient | None @property def root(self) -> Sink: @@ -381,7 +383,7 @@ def __init__( options: str | None = None, error_hook: Callable[[Self, Exception, VoiceData | None], Any] | None = None, ) -> None: - super().__init__() + super().__init__(dest=None) if filename is not MISSING and buffer is not MISSING: raise TypeError("can't mix filename and buffer parameters") diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 9cc930e07e..503b95a1f9 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -116,3 +116,4 @@ def __init__( options=options, error_hook=error_hook, ) # type: ignore + \ No newline at end of file diff --git a/discord/voice/client.py b/discord/voice/client.py index 16cb32fcd3..cce7cee3b6 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -33,6 +33,7 @@ import warnings from discord import opus +from discord.enums import SpeakingState, try_enum from discord.errors import ClientException from discord.player import AudioPlayer, AudioSource from discord.sinks.core import Sink @@ -42,6 +43,7 @@ from ._types import VoiceProtocol from .receive import AudioReader from .state import VoiceConnectionState +from .enums import OpCodes if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -202,6 +204,68 @@ def checked_add(self, attr: str, value: int, limit: int) -> None: def create_connection_state(self) -> VoiceConnectionState: return VoiceConnectionState(self, hook=self._recv_hook) + async def _recv_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: + op = msg["op"] + data = msg.get("d", {}) + + if op == OpCodes.ready: + self._add_ssrc(self.guild.me.id, data["ssrc"]) + elif op == OpCodes.speaking: + uid = int(data["user_id"]) + ssrc = data["ssrc"] + + self._add_ssrc(uid, ssrc) + + member = self.guild.get_member(uid) + state = try_enum(SpeakingState, data["speaking"]) + self.dispatch("member_speaking_state_update", member, ssrc, state) + elif op == 
OpCodes.clients_connect: + uids = list(map(int, data["user_ids"])) + + for uid in uids: + member = self.guild.get_member(uid) + if not member: + _log.warning("Skipping member referencing ID %d on member_connect", uid) + continue + self.dispatch("member_connect", member) + elif op == OpCodes.client_disconnect: + uid = int(data["user_id"]) + ssrc = self._id_to_ssrc.get(uid) + + if self._reader and ssrc is not None: + _log.debug("Destroying decoder for user %d, ssrc=%d", uid, ssrc) + self._reader.packet_router.destroy_decoder(ssrc) + + self._remove_ssrc(user_id=uid) + member = self.guild.get_member(uid) + self.dispatch("member_disconnect", member, ssrc) + + # maybe handle video and such things? + + async def _run_event(self, coro, event_name: str, *args: Any, **kwargs: Any) -> None: + try: + await coro(*args, **kwargs) + except asyncio.CancelledError: + pass + except Exception: + _log.exception("Error calling %s", event_name) + + def _schedule_event(self, coro, event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: + wrapped = self._run_event(coro, event_name, *args, **kwargs) + return self.client.loop.create_task(wrapped, name=f"voice-receiver-event-dispatch: {event_name}") + + def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None: + _log.debug("Dispatching voice_client event %s", event) + + event_name = f"on_{event}" + for coro in self._event_listeners.get(event_name, []): + task = self._schedule_event(coro, event_name, *args, **kwargs) + self._connection._dispatch_task_set.add(task) + task.add_done_callback(self._connection._dispatch_task_set.discard) + + self._dispatch_sink(event, *args, **kwargs) + self.client.dispatch(event, *args, **kwargs) + async def on_voice_state_update(self, data: RawVoiceStateUpdateEvent) -> None: old_channel_id = self.channel.id if self.channel else None await self._connection.voice_state_update(data) diff --git a/discord/voice/packets/core.py b/discord/voice/packets/core.py index 6a2ca96a89..a372e860b7 100644 --- 
a/discord/voice/packets/core.py +++ b/discord/voice/packets/core.py @@ -25,16 +25,9 @@ from __future__ import annotations from typing import TYPE_CHECKING -from discord.opus import Decoder, _lib - if TYPE_CHECKING: from typing_extensions import Final -if _lib is None: - DECODER = None -else: - DECODER = Decoder() - OPUS_SILENCE: Final = b'\xf8\xff\xfe' From a324844ea44f2de4011de1567fa2be3c0342285b Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Fri, 12 Sep 2025 15:09:45 +0200 Subject: [PATCH 39/40] okay this needs fixing --- discord/opus.py | 17 ++++++++++++----- discord/sinks/core.py | 12 ++++++++---- discord/voice/client.py | 4 ++-- discord/voice/packets/rtp.py | 1 + 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/discord/opus.py b/discord/opus.py index 345449ad51..684aa07f1e 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -653,23 +653,30 @@ def _make_fakepacket(self) -> FakePacket: def _process_packet(self, packet: Packet) -> VoiceData: from discord.object import Object + from discord.voice import VoiceData - pcm = None + assert self.sink.client - if not self.sink.is_opus(): - packet, pcm = self._decode_packet(packet) + pcm = None member = self._get_cached_member() if member is None: - self._cached_id = self.sink.client._connection._get_id_from_ssrc(self.ssrc) + self._cached_id = self.sink.client._ssrc_to_id.get(self.ssrc) member = self._get_cached_member() + else: + self._cached_id = member.id # yet still none, use Object if member is None and self._cached_id: member = Object(id=self._cached_id) - data = VoiceData(packet, member, pcm=pcm) + if not self.sink.is_opus(): + _log.debug("Decoding packet %s (type %s)", packet, type(packet)) + packet, pcm = self._decode_packet(packet) + + + data = VoiceData(packet, member, pcm=pcm) # type: ignore self._last_seq = packet.sequence self._last_ts = packet.timestamp return data diff --git a/discord/sinks/core.py b/discord/sinks/core.py index 
2676200e57..88bdf59fa6 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -73,10 +73,14 @@ class SinkMeta(type): - def __new__(cls, *args, **kwargs): + __sink_listeners__: list[tuple[str, str]] + + def __new__(cls, name, bases, attr, **kwargs): listeners = {} - for base in reversed(cls.__mro__): + inst = super().__new__(cls, name, bases, attr, **kwargs) + + for base in reversed(inst.__mro__): for elem, value in base.__dict__.items(): if elem in listeners: del listeners[elem] @@ -95,8 +99,8 @@ def __new__(cls, *args, **kwargs): for listener_name in listener.__sink_listener_names__: listeners_list.append((listener_name, listener.__name__)) - cls.__sink_listeners__ = listeners_list - return super().__new__(cls, *args, **kwargs) + inst.__sink_listeners__ = listeners_list + return inst class SinkBase(metaclass=SinkMeta): diff --git a/discord/voice/client.py b/discord/voice/client.py index cce7cee3b6..5bcb04c438 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -585,7 +585,7 @@ def stop(self) -> None: for cb, _ in self._player_future._callbacks: self._player_future.remove_done_callback(cb) self._player_future.set_result(None) - if self._reader: + if self._reader is not MISSING: self._reader.stop() self._reader = MISSING @@ -733,7 +733,7 @@ def stop_recording(self) -> None: RecordingException You are not recording. 
""" - if self._reader: + if self._reader is not MISSING: self._reader.stop() self._reader = MISSING else: diff --git a/discord/voice/packets/rtp.py b/discord/voice/packets/rtp.py index 8ed822f92e..6a05c16cf6 100644 --- a/discord/voice/packets/rtp.py +++ b/discord/voice/packets/rtp.py @@ -204,6 +204,7 @@ def __init__(self, data: bytes) -> None: self.version: int = head >> 6 self.padding: bool = bool(head & 0b00100000) + setattr(self, "report_count", head & 0b00011111) def __repr__(self) -> str: return f"<{self.__class__.__name__} version={self.version} padding={self.padding} length={self.length}>" From 805d055375219a40e6b8c7ab2c06a577adb41262 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Sep 2025 13:10:34 +0000 Subject: [PATCH 40/40] style(pre-commit): auto fixes from pre-commit.com hooks --- discord/__init__.py | 3 +- discord/_voice_aliases.py | 13 +- discord/gateway.py | 1 - discord/opus.py | 17 +- discord/player.py | 244 ++++++++++++++++++-------- discord/sinks/core.py | 114 ++++++++---- discord/sinks/m4a.py | 2 +- discord/sinks/mka.py | 2 +- discord/sinks/mkv.py | 2 +- discord/sinks/mp3.py | 3 +- discord/sinks/mp4.py | 2 +- discord/sinks/ogg.py | 2 +- discord/sinks/wave.py | 16 +- discord/voice/_types.py | 4 +- discord/voice/client.py | 34 ++-- discord/voice/gateway.py | 66 +++++-- discord/voice/packets/__init__.py | 18 +- discord/voice/packets/core.py | 18 +- discord/voice/packets/rtp.py | 31 ++-- discord/voice/receive/reader.py | 70 +++++--- discord/voice/receive/router.py | 25 ++- discord/voice/state.py | 21 ++- discord/voice/utils/buffer.py | 29 ++- discord/voice/utils/multidataevent.py | 3 +- discord/voice/utils/wrapped.py | 3 +- 25 files changed, 505 insertions(+), 238 deletions(-) diff --git a/discord/__init__.py b/discord/__init__.py index 49389bc245..92f3e7a215 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -25,6 +25,7 @@ from . 
import abc, opus, sinks, ui, utils +from ._voice_aliases import * from .activity import * from .appinfo import * from .application_role_connection import * @@ -75,7 +76,5 @@ from .webhook import * from .welcome_screen import * from .widget import * -from ._voice_aliases import * - logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/discord/_voice_aliases.py b/discord/_voice_aliases.py index 663407d70c..9054e4a04a 100644 --- a/discord/_voice_aliases.py +++ b/discord/_voice_aliases.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations from typing import TYPE_CHECKING @@ -42,30 +43,32 @@ if TYPE_CHECKING: from typing_extensions import deprecated - from discord.voice import VoiceProtocolC, VoiceClientC + from discord.voice import VoiceClientC, VoiceProtocolC @deprecated( "discord.VoiceClient is deprecated in favour " "of discord.voice.VoiceClient since 2.7 and " "will be removed in 3.0", ) - def VoiceClient(client, channel) -> VoiceClientC: - ... + def VoiceClient(client, channel) -> VoiceClientC: ... @deprecated( "discord.VoiceProtocol is deprecated in favour " "of discord.voice.VoiceProtocol since 2.7 and " "will be removed in 3.0", ) - def VoiceProtocol(client, channel) -> VoiceProtocolC: - ... + def VoiceProtocol(client, channel) -> VoiceProtocolC: ... 
+ else: + @deprecated("discord.VoiceClient", "discord.voice.VoiceClient", "2.7", "3.0") def VoiceClient(client, channel): from discord.voice import VoiceClient + return VoiceClient(client, channel) @deprecated("discord.VoiceProtocol", "discord.voice.VoiceProtocol", "2.7", "3.0") def VoiceProtocol(client, channel): from discord.voice import VoiceProtocol + return VoiceProtocol(client, channel) diff --git a/discord/gateway.py b/discord/gateway.py index b5138c0813..b9b28f887d 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -33,7 +33,6 @@ import time import traceback import zlib -from collections import deque from collections.abc import Callable from typing import TYPE_CHECKING, Any, NamedTuple diff --git a/discord/opus.py b/discord/opus.py index 684aa07f1e..45a09d0573 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -35,22 +35,22 @@ import sys from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar +import davey + from discord.voice.packets.rtp import FakePacket -from discord.voice.utils.wrapped import add_wrapped from discord.voice.utils.buffer import JitterBuffer - -import davey +from discord.voice.utils.wrapped import add_wrapped from .errors import DiscordException if TYPE_CHECKING: - from discord.user import User from discord.member import Member + from discord.sinks.core import Sink + from discord.user import User from discord.voice.client import VoiceClient - from discord.voice.receive.router import PacketRouter - from discord.voice.packets.core import Packet from discord.voice.packets import VoiceData - from discord.sinks.core import Sink + from discord.voice.packets.core import Packet + from discord.voice.receive.router import PacketRouter T = TypeVar("T") APPLICATION_CTL = Literal["audio", "voip", "lowdelay"] @@ -675,7 +675,6 @@ def _process_packet(self, packet: Packet) -> VoiceData: _log.debug("Decoding packet %s (type %s)", packet, type(packet)) packet, pcm = self._decode_packet(packet) - data = VoiceData(packet, 
member, pcm=pcm) # type: ignore self._last_seq = packet.sequence self._last_ts = packet.timestamp @@ -685,7 +684,7 @@ def _decode_packet(self, packet: Packet) -> tuple[Packet, bytes]: assert self._decoder is not None assert self.sink.client - user_id: int | None = self._cached_id + user_id: int | None = self._cached_id dave: davey.DaveSession | None = self.sink.client._connection.dave_session in_dave = dave is not None diff --git a/discord/player.py b/discord/player.py index 9d67c3f0bb..8876f8df55 100644 --- a/discord/player.py +++ b/discord/player.py @@ -36,9 +36,9 @@ import sys import threading import time +import warnings from math import floor from typing import IO, TYPE_CHECKING, Any, Callable, Generic, TypeVar -import warnings from .enums import SpeakingState from .errors import ClientException @@ -100,7 +100,7 @@ def read(self) -> bytes: per frame (20ms worth of audio). Returns - -------- + ------- :class:`bytes` A bytes like object that represents the PCM or Opus data. """ @@ -125,7 +125,7 @@ class PCMAudio(AudioSource): """Represents raw 16-bit 48KHz stereo PCM audio source. Attributes - ----------- + ---------- stream: :term:`py:file object` A file-like object that reads byte data representing raw PCM. 
""" @@ -136,7 +136,7 @@ def __init__(self, stream: io.BufferedIOBase) -> None: def read(self) -> bytes: ret = self.stream.read(OpusEncoder.FRAME_SIZE) if len(ret) != OpusEncoder.FRAME_SIZE: - return b'' + return b"" return ret @@ -155,18 +155,22 @@ def __init__( self, source: str | io.BufferedIOBase, *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", args: Any, **subprocess_kwargs: Any, ): - piping_stdin = subprocess_kwargs.get('stdin') == subprocess.PIPE + piping_stdin = subprocess_kwargs.get("stdin") == subprocess.PIPE if piping_stdin and isinstance(source, str): - raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin") + raise TypeError( + "parameter conflict: 'source' parameter cannot be a string when piping to stdin" + ) - stderr: IO[bytes] | None = subprocess_kwargs.pop('stderr', None) + stderr: IO[bytes] | None = subprocess_kwargs.pop("stderr", None) if stderr == subprocess.PIPE: - warnings.warn('Passing subprocess.PIPE does nothing', DeprecationWarning, stacklevel=3) + warnings.warn( + "Passing subprocess.PIPE does nothing", DeprecationWarning, stacklevel=3 + ) stderr = None piping_stderr = False @@ -177,7 +181,10 @@ def __init__( piping_stderr = True args = [executable, *args] - kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.PIPE if piping_stderr else stderr} + kwargs = { + "stdout": subprocess.PIPE, + "stderr": subprocess.PIPE if piping_stderr else stderr, + } kwargs.update(subprocess_kwargs) # Ensure attribute is assigned even in the case of errors @@ -192,25 +199,33 @@ def __init__( if piping_stdin: n = f"popen-stdin-writer:pid-{self._process.pid}" self._stdin = self._process.stdin - self._pipe_writer_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n) + self._pipe_writer_thread = threading.Thread( + target=self._pipe_writer, args=(source,), daemon=True, name=n + ) self._pipe_writer_thread.start() if piping_stderr: n = 
f"popen-stderr-reader:pid-{self._process.pid}" self._stderr = self._process.stderr - self._pipe_reader_thread = threading.Thread(target=self._pipe_reader, args=(stderr,), daemon=True, name=n) + self._pipe_reader_thread = threading.Thread( + target=self._pipe_reader, args=(stderr,), daemon=True, name=n + ) self._pipe_reader_thread.start() def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: _log.debug("Spawning ffmpeg process with command: %s", args) process = None try: - process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) + process = subprocess.Popen( + args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs + ) except FileNotFoundError: executable = args.partition(" ")[0] if isinstance(args, str) else args[0] raise ClientException(executable + " was not found.") from None except subprocess.SubprocessError as exc: - raise ClientException(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc + raise ClientException( + f"Popen failed: {exc.__class__.__name__}: {exc}" + ) from exc else: return process @@ -225,14 +240,27 @@ def _kill_process(self) -> None: try: proc.kill() except Exception: - _log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid) + _log.exception( + "Ignoring error attempting to kill ffmpeg process %s", proc.pid + ) if proc.poll() is None: - _log.info("ffmpeg process %s has not terminated. Waiting to terminate...", proc.pid) + _log.info( + "ffmpeg process %s has not terminated. 
Waiting to terminate...", + proc.pid, + ) proc.communicate() - _log.info("ffmpeg process %s should have terminated with a return code of %s.", proc.pid, proc.returncode) + _log.info( + "ffmpeg process %s should have terminated with a return code of %s.", + proc.pid, + proc.returncode, + ) else: - _log.info("ffmpeg process %s successfully terminated with return code of %s.", proc.pid, proc.returncode) + _log.info( + "ffmpeg process %s successfully terminated with return code of %s.", + proc.pid, + proc.returncode, + ) def _pipe_writer(self, source: io.BufferedIOBase) -> None: while self._process: @@ -245,7 +273,11 @@ def _pipe_writer(self, source: io.BufferedIOBase) -> None: if self._stdin is not None: self._stdin.write(data) except Exception: - _log.debug('Write error for %s, this is probably not a problem', self, exc_info=True) + _log.debug( + "Write error for %s, this is probably not a problem", + self, + exc_info=True, + ) # at this point the source data is either exhausted or the process is fubar self._process.terminate() return @@ -257,7 +289,11 @@ def _pipe_reader(self, dest: IO[bytes]) -> None: try: data: bytes = self._stderr.read(self.BLOCKSIZE) except Exception: - _log.debug("Read error for %s, this is probably not a problem", self, exc_info=True) + _log.debug( + "Read error for %s, this is probably not a problem", + self, + exc_info=True, + ) return if data is None: return @@ -284,7 +320,7 @@ class FFmpegPCMAudio(FFmpegAudio): variable in order for this to work. Parameters - ------------ + ---------- source: Union[:class:`str`, :class:`io.BufferedIOBase`] The input that ffmpeg will take and convert to PCM bytes. If ``pipe`` is ``True`` then this is a file-like object that is @@ -308,7 +344,7 @@ class FFmpegPCMAudio(FFmpegAudio): Extra command line arguments to pass to ffmpeg after the ``-i`` flag. Raises - -------- + ------ ClientException The subprocess failed to be created. 
""" @@ -317,14 +353,17 @@ def __init__( self, source: str | io.BufferedIOBase, *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", pipe: bool = False, stderr: IO[bytes] | None = None, before_options: str | None = None, options: str | None = None, ) -> None: args = [] - subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr} + subprocess_kwargs = { + "stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, + "stderr": stderr, + } if isinstance(before_options, str): args.extend(shlex.split(before_options)) @@ -332,11 +371,20 @@ def __init__( args.append("-i") args.append("-" if pipe else source) - args.extend(("-f", "s16le", - "-ar", "48000", - "-ac", "2", - "-loglevel", "warning", - "-blocksize", str(self.BLOCKSIZE))) + args.extend( + ( + "-f", + "s16le", + "-ar", + "48000", + "-ac", + "2", + "-loglevel", + "warning", + "-blocksize", + str(self.BLOCKSIZE), + ) + ) if isinstance(options, str): args.extend(shlex.split(options)) @@ -377,7 +425,7 @@ class FFmpegOpusAudio(FFmpegAudio): variable in order for this to work. Parameters - ------------ + ---------- source: Union[:class:`str`, :class:`io.BufferedIOBase`] The input that ffmpeg will take and convert to Opus bytes. If ``pipe`` is ``True`` then this is a file-like object that is @@ -410,7 +458,7 @@ class FFmpegOpusAudio(FFmpegAudio): Extra command line arguments to pass to ffmpeg after the ``-i`` flag. Raises - -------- + ------ ClientException The subprocess failed to be created. 
""" @@ -421,14 +469,17 @@ def __init__( *, bitrate: int | None = None, codec: str | None = None, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", pipe: bool = False, stderr: IO[bytes] | None = None, before_options: str | None = None, options: str | None = None, ) -> None: args = [] - subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr} + subprocess_kwargs = { + "stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, + "stderr": stderr, + } if isinstance(before_options, str): args.extend(shlex.split(before_options)) @@ -439,16 +490,30 @@ def __init__( codec = "copy" if codec in ("opus", "libopus", "copy") else "libopus" bitrate = bitrate if bitrate is not None else 128 - args.extend(("-map_metadata", "-1", - "-f", "opus", - "-c:a", codec, - "-ar", "48000", - "-ac", "2", - "-b:a", f"{bitrate}k", - "-loglevel", "warning", - "-fec", "true", - "-packet_loss", "15", - "-blocksize", str(self.BLOCKSIZE))) + args.extend( + ( + "-map_metadata", + "-1", + "-f", + "opus", + "-c:a", + codec, + "-ar", + "48000", + "-ac", + "2", + "-b:a", + f"{bitrate}k", + "-loglevel", + "warning", + "-fec", + "true", + "-packet_loss", + "15", + "-blocksize", + str(self.BLOCKSIZE), + ) + ) if isinstance(options, str): args.extend(shlex.split(options)) @@ -521,7 +586,7 @@ def custom_probe(source, executable): An instance of this class. """ - executable = kwargs.get('executable') + executable = kwargs.get("executable") codec, bitrate = await cls.probe(source, method=method, executable=executable) return cls(source, bitrate=bitrate, codec=codec, **kwargs) @@ -538,7 +603,7 @@ async def probe( Probes the input source for bitrate and codec information. Parameters - ------------ + ---------- source Identical to the ``source`` parameter for :class:`FFmpegOpusAudio`. method @@ -546,17 +611,17 @@ async def probe( executable: :class:`str` Identical to the ``executable`` parameter for :class:`FFmpegOpusAudio`. 
+ Returns + ------- + Optional[Tuple[Optional[:class:`str`], :class:`int`]] + A 2-tuple with the codec and bitrate of the input source. + Raises - -------- + ------ AttributeError Invalid probe method, must be ``'native'`` or ``'fallback'``. TypeError Invalid value for ``probe`` parameter, must be :class:`str` or a callable. - - Returns - --------- - Optional[Tuple[Optional[:class:`str`], :class:`int`]] - A 2-tuple with the codec and bitrate of the input source. """ method = method or "native" @@ -583,7 +648,9 @@ async def probe( codec = bitrate = None loop = asyncio.get_running_loop() try: - codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) + codec, bitrate = await loop.run_in_executor( + None, lambda: probefunc(source, executable) + ) except (KeyboardInterrupt, SystemExit): raise except BaseException: @@ -591,9 +658,13 @@ async def probe( _log.exception("Probe '%s' using '%s' failed", method, executable) return None, None - _log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) + _log.exception( + "Probe '%s' using '%s' failed, trying fallback", method, executable + ) try: - codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) + codec, bitrate = await loop.run_in_executor( + None, lambda: fallback(source, executable) + ) except (KeyboardInterrupt, SystemExit): raise except BaseException: @@ -606,35 +677,58 @@ async def probe( return codec, bitrate @staticmethod - def _probe_codec_native(source, executable: str = "ffmpeg") -> tuple[str | None, int | None]: - exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable - args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] + def _probe_codec_native( + source, executable: str = "ffmpeg" + ) -> tuple[str | None, int | None]: + exe = ( + executable[:2] + "probe" + if executable in ("ffmpeg", "avconv") + else executable + ) + args = [ + exe, + 
"-v", + "quiet", + "-print_format", + "json", + "-show_streams", + "-select_streams", + "a:0", + source, + ] output = subprocess.check_output(args, timeout=20) codec = bitrate = None if output: data = json.loads(output) - streamdata = data['streams'][0] + streamdata = data["streams"][0] - codec = streamdata.get('codec_name') - bitrate = int(streamdata.get('bit_rate', 0)) + codec = streamdata.get("codec_name") + bitrate = int(streamdata.get("bit_rate", 0)) bitrate = max(round(bitrate / 1000), 512) return codec, bitrate @staticmethod - def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - args = [executable, '-hide_banner', '-i', source] - proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + def _probe_codec_fallback( + source, executable: str = "ffmpeg" + ) -> Tuple[Optional[str], Optional[int]]: + args = [executable, "-hide_banner", "-i", source] + proc = subprocess.Popen( + args, + creationflags=CREATE_NO_WINDOW, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) out, _ = proc.communicate(timeout=20) - output = out.decode('utf8') + output = out.decode("utf8") codec = bitrate = None - codec_match = re.search(r'Stream #0.*?Audio: (\w+)', output) + codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output) if codec_match: codec = codec_match.group(1) - br_match = re.search(r'(\d+) [kK]b/s', output) + br_match = re.search(r"(\d+) [kK]b/s", output) if br_match: bitrate = max(int(br_match.group(1)), 512) @@ -654,7 +748,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): set to ``True``. Parameters - ------------ + ---------- original: :class:`AudioSource` The original AudioSource to transform. volume: :class:`float` @@ -662,7 +756,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): See :attr:`volume` for more info. Raises - ------- + ------ TypeError Not an audio source. 
ClientException @@ -671,10 +765,10 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): def __init__(self, original: AT, volume: float = 1.0): if not isinstance(original, AudioSource): - raise TypeError(f'expected AudioSource not {original.__class__.__name__}.') + raise TypeError(f"expected AudioSource not {original.__class__.__name__}.") if original.is_opus(): - raise ClientException('AudioSource must not be Opus encoded.') + raise ClientException("AudioSource must not be Opus encoded.") self.original: AT = original self.volume = volume @@ -715,7 +809,7 @@ def __init__( *, after: Callable[[Exception | None], Any] | None = None, ) -> None: - super().__init__(daemon=True, name=f'audio-player:{id(self):#x}') + super().__init__(daemon=True, name=f"audio-player:{id(self):#x}") self.source: AudioSource = source self.client: VoiceClient = client self.after: Callable[[Exception | None], Any] | None = after @@ -797,9 +891,9 @@ def _call_after(self) -> None: self.after(error) except Exception as exc: exc.__context__ = error - _log.exception('Calling the after function failed.', exc_info=exc) + _log.exception("Calling the after function failed.", exc_info=exc) elif error: - _log.exception('Exception in voice thread %s', self.name, exc_info=error) + _log.exception("Exception in voice thread %s", self.name, exc_info=error) def stop(self) -> None: self._end.set() @@ -832,7 +926,9 @@ def set_source(self, source: AudioSource) -> None: def _speak(self, speaking: SpeakingState) -> None: try: - asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.client.loop) + asyncio.run_coroutine_threadsafe( + self.client.ws.speak(speaking), self.client.client.loop + ) except Exception as exc: _log.exception("Speaking call in player failed", exc_info=exc) diff --git a/discord/sinks/core.py b/discord/sinks/core.py index 88bdf59fa6..ccd73c4d59 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -25,26 +25,26 @@ from __future__ import annotations -import 
logging -import sys -from collections.abc import Callable, Generator, Sequence import inspect -import subprocess +import logging import shlex +import subprocess +import sys import threading +from collections.abc import Callable, Generator, Sequence from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, overload from discord.file import File -from discord.utils import MISSING, SequenceProxy from discord.player import FFmpegAudio +from discord.utils import MISSING, SequenceProxy from .errors import FFmpegNotFound if TYPE_CHECKING: from typing_extensions import ParamSpec, Self - from discord.user import User from discord.member import Member + from discord.user import User from discord.voice.packets import VoiceData from ..voice.client import VoiceClient @@ -71,7 +71,6 @@ _log = logging.getLogger(__name__) - class SinkMeta(type): __sink_listeners__: list[tuple[str, str]] @@ -89,7 +88,7 @@ def __new__(cls, name, bases, attr, **kwargs): if is_static: value = value.__func__ - if not hasattr(value, '__sink_listener__'): + if not hasattr(value, "__sink_listener__"): continue listeners[elem] = value @@ -104,8 +103,7 @@ def __new__(cls, name, bases, attr, **kwargs): class SinkBase(metaclass=SinkMeta): - """Represents an audio sink in which user's audios are stored. 
- """ + """Represents an audio sink in which user's audios are stored.""" __sink_listeners__: list[tuple[str, str]] _client: VoiceClient | None @@ -151,7 +149,7 @@ def _register_child(self, child: Sink) -> None: """Registers a child to this sink.""" raise NotImplementedError - def walk_children(self, *, with_self: bool = False) -> Generator[Sink, None, None]: + def walk_children(self, *, with_self: bool = False) -> Generator[Sink]: """Iterates through all the children of this sink, including nested.""" if with_self: yield self # type: ignore @@ -227,7 +225,9 @@ def listener(cls, name: str = MISSING): """ if name is not MISSING and not isinstance(name, str): - raise TypeError(f"expected a str for listener name, got {name.__class__.__name__} instead") + raise TypeError( + f"expected a str for listener name, got {name.__class__.__name__} instead" + ) def decorator(func): actual = func @@ -247,6 +247,7 @@ def decorator(func): actual.__sink_listener_names__ = [to_assign] return func + return decorator @@ -315,10 +316,14 @@ def remove_destination(self, dest: Sink, /) -> None: ) def RawData(**kwargs: Any) -> Any: """Deprecated since version 2.7, use :class:`VoiceData` instead.""" + else: + class RawData: def __init__(self, **kwargs: Any) -> None: - raise DeprecationWarning("RawData has been deprecated in favour of VoiceData") + raise DeprecationWarning( + "RawData has been deprecated in favour of VoiceData" + ) class FFmpegSink(Sink): @@ -413,14 +418,16 @@ def __init__( if isinstance(before_options, str): args.extend(shlex.split(before_options)) - args.extend({ - "-f": "s16le", - "-ar": "48000", - "-ac": "2", - "-i": "pipe:0", - "-loglevel": "warning", - "-blocksize": str(FFmpegAudio.BLOCKSIZE) - }) + args.extend( + { + "-f": "s16le", + "-ar": "48000", + "-ac": "2", + "-i": "pipe:0", + "-loglevel": "warning", + "-blocksize": str(FFmpegAudio.BLOCKSIZE), + } + ) if isinstance(options, str): args.extend(shlex.split(options)) @@ -440,14 +447,18 @@ def __init__( n = 
f"popen-stdout-reader:pid-{self._process.pid}" self._stdout = self._process.stdout _args = (self._stdout, self.buffer) - self._stdout_reader_thread = threading.Thread(target=self._pipe_reader, args=_args, daemon=True, name=n) + self._stdout_reader_thread = threading.Thread( + target=self._pipe_reader, args=_args, daemon=True, name=n + ) self._stdout_reader_thread.start() if piping_stderr: n = f"popen-stderr-reader:pid-{self._process.pid}" self._stderr = self._process.stderr _args = (self._stderr, stderr) - self._stderr_reader_thread = threading.Thread(target=self._pipe_reader, args=_args, daemon=True, name=n) + self._stderr_reader_thread = threading.Thread( + target=self._pipe_reader, args=_args, daemon=True, name=n + ) self._stderr_reader_thread.start() @staticmethod @@ -473,8 +484,9 @@ def write(self, user: User | Member | None, data: VoiceData) -> None: self._kill_processes() self.on_error(self, exc, data) - - def to_file(self, filename: str, /, *, description: str | None = None, spoiler: bool = False) -> File | None: + def to_file( + self, filename: str, /, *, description: str | None = None, spoiler: bool = False + ) -> File | None: """Returns the :class:`discord.File` of this sink. This is only applicable if this sink uses a ``buffer`` instead of a ``filename``. @@ -484,18 +496,29 @@ def to_file(self, filename: str, /, *, description: str | None = None, spoiler: This should be used only after the sink has stopped recording. 
""" if self.buffer is not MISSING: - fp = File(self.buffer.read(), filename=filename, description=description, spoiler=spoiler) + fp = File( + self.buffer.read(), + filename=filename, + description=description, + spoiler=spoiler, + ) return fp return None def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: - _log.debug("Spawning ffmpeg process with command %s and kwargs %s", args, subprocess_kwargs) + _log.debug( + "Spawning ffmpeg process with command %s and kwargs %s", + args, + subprocess_kwargs, + ) process = None try: - process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) + process = subprocess.Popen( + args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs + ) except FileNotFoundError: - executable = args.partition(' ')[0] if isinstance(args, str) else args[0] + executable = args.partition(" ")[0] if isinstance(args, str) else args[0] raise FFmpegNotFound(f"{executable!r} executable was not found") from None except subprocess.SubprocessError as exc: raise Exception(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc @@ -532,11 +555,22 @@ def _kill_processes(self) -> None: ) if proc.poll() is None: - _log.info("ffmpeg process %s has not terminated. Waiting to terminate...", proc.pid) + _log.info( + "ffmpeg process %s has not terminated. 
Waiting to terminate...", + proc.pid, + ) proc.communicate() - _log.info("ffmpeg process %s should have terminated with a return code of %s", proc.pid, proc.returncode) + _log.info( + "ffmpeg process %s should have terminated with a return code of %s", + proc.pid, + proc.returncode, + ) else: - _log.info("ffmpeg process %s successfully terminated with return code of %s", proc.pid, proc.returncode) + _log.info( + "ffmpeg process %s successfully terminated with return code of %s", + proc.pid, + proc.returncode, + ) self._process = MISSING @@ -551,7 +585,9 @@ def _pipe_reader(self, source: IO[bytes], dest: IO[bytes]) -> None: _log.debug("FFmpeg stdin pipe closed with exception %s", exc) return except Exception: - _log.debug("An error ocurred in %s, this can be ignored", self, exc_info=True) + _log.debug( + "An error ocurred in %s, this can be ignored", self, exc_info=True + ) return if data is None: @@ -560,7 +596,9 @@ def _pipe_reader(self, source: IO[bytes], dest: IO[bytes]) -> None: try: dest.write(data) except Exception as exc: - _log.exception("Error while writing to destination pipe %s", self, exc_info=exc) + _log.exception( + "Error while writing to destination pipe %s", self, exc_info=exc + ) self._kill_processes() self.on_error(self, exc, None) return @@ -594,10 +632,14 @@ def __init__( raise ValueError("filters must have at least one callback") if not isinstance(destination, SinkBase): - raise TypeError(f"expected a Sink object, got {destination.__class__.__name__}") + raise TypeError( + f"expected a Sink object, got {destination.__class__.__name__}" + ) self._filter_strat = all if filtering_mode == "all" else any - self.filters: Sequence[Callable[[User | Member | None, VoiceData], bool]] = filters + self.filters: Sequence[Callable[[User | Member | None, VoiceData], bool]] = ( + filters + ) self.destination: Sink = destination super().__init__(dest=destination) diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 2541b058b6..e1a1fd5ce1 100644 --- 
a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -62,7 +62,7 @@ class M4ASink(FFmpegSink): .. versionadded:: 2.7 stderr: IO[:class:`bytes`] | :data:`None` The stderr buffer in which will be written. Defaults to ``None``. - + .. versionadded:: 2.7 options: :class:`str` | :data:`None` The options to append to the ffmpeg executable flags. You should not diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index 04e13e82b6..4c9878ebeb 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -62,7 +62,7 @@ class MKASink(FFmpegSink): .. versionadded:: 2.7 stderr: IO[:class:`bytes`] | :data:`None` The stderr buffer in which will be written. Defaults to ``None``. - + .. versionadded:: 2.7 options: :class:`str` | :data:`None` The options to append to the ffmpeg executable flags. You should not diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index e458ef29c4..c4b180bff9 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -62,7 +62,7 @@ class MKVSink(FFmpegSink): .. versionadded:: 2.7 stderr: IO[:class:`bytes`] | :data:`None` The stderr buffer in which will be written. Defaults to ``None``. - + .. versionadded:: 2.7 options: :class:`str` | :data:`None` The options to append to the ffmpeg executable flags. You should not diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 503b95a1f9..6c62e8e8ef 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -62,7 +62,7 @@ class MP3Sink(FFmpegSink): .. versionadded:: 2.7 stderr: IO[:class:`bytes`] | :data:`None` The stderr buffer in which will be written. Defaults to ``None``. - + .. versionadded:: 2.7 options: :class:`str` | :data:`None` The options to append to the ffmpeg executable flags. 
You should not @@ -116,4 +116,3 @@ def __init__( options=options, error_hook=error_hook, ) # type: ignore - \ No newline at end of file diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index e3ca273d63..fe050ec696 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -62,7 +62,7 @@ class MP4Sink(FFmpegSink): .. versionadded:: 2.7 stderr: IO[:class:`bytes`] | :data:`None` The stderr buffer in which will be written. Defaults to ``None``. - + .. versionadded:: 2.7 options: :class:`str` | :data:`None` The options to append to the ffmpeg executable flags. You should not diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index da279055ca..62b0717be7 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -62,7 +62,7 @@ class OGGSink(FFmpegSink): .. versionadded:: 2.7 stderr: IO[:class:`bytes`] | :data:`None` The stderr buffer in which will be written. Defaults to ``None``. - + .. versionadded:: 2.7 options: :class:`str` | :data:`None` The options to append to the ffmpeg executable flags. You should not diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index c2a1ee3299..b200ff3956 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -29,10 +29,11 @@ import wave from typing import TYPE_CHECKING -from discord.opus import Decoder from discord.file import File +from discord.opus import Decoder from .core import Sink + if TYPE_CHECKING: from discord import abc from discord.voice import VoiceData @@ -95,7 +96,9 @@ def is_opus(self) -> bool: def write(self, user: abc.User | None, data: VoiceData) -> None: self._file.writeframes(data.pcm) - def to_file(self, filename: str, /, *, description: str | None = None, spoiler: bool = False) -> File | None: + def to_file( + self, filename: str, /, *, description: str | None = None, spoiler: bool = False + ) -> File | None: """Returns the :class:`discord.File` of this sink. .. 
warning:: @@ -105,13 +108,18 @@ def to_file(self, filename: str, /, *, description: str | None = None, spoiler: f = wave.open(self._destination, "rb") data = f.readframes(f.getnframes()) - return File(io.BytesIO(data), filename, description=description, spoiler=spoiler) + return File( + io.BytesIO(data), filename, description=description, spoiler=spoiler + ) def cleanup(self) -> None: try: self._file.close() except Exception as exc: - _log.warning("An error ocurred while closing the wave writing file on cleanup", exc_info=exc) + _log.warning( + "An error ocurred while closing the wave writing file on cleanup", + exc_info=exc, + ) WavSink = WaveSink diff --git a/discord/voice/_types.py b/discord/voice/_types.py index 6fc393c7b0..71015900f2 100644 --- a/discord/voice/_types.py +++ b/discord/voice/_types.py @@ -42,9 +42,7 @@ ClientT = TypeVar("ClientT", bound="Client", covariant=True) -__all__ = ( - "VoiceProtocol", -) +__all__ = ("VoiceProtocol",) class VoiceProtocol(Generic[ClientT]): diff --git a/discord/voice/client.py b/discord/voice/client.py index 5bcb04c438..179aa42818 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -29,8 +29,8 @@ import datetime import logging import struct -from typing import TYPE_CHECKING, Any, Literal, overload import warnings +from typing import TYPE_CHECKING, Any, Literal, overload from discord import opus from discord.enums import SpeakingState, try_enum @@ -41,9 +41,9 @@ from discord.utils import MISSING from ._types import VoiceProtocol +from .enums import OpCodes from .receive import AudioReader from .state import VoiceConnectionState -from .enums import OpCodes if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -51,6 +51,7 @@ from discord import abc from discord.client import Client from discord.guild import Guild, VocalGuildChannel + from discord.member import Member from discord.opus import APPLICATION_CTL, BAND_CTL, SIGNAL_CTL, Encoder from discord.raw_models import ( RawVoiceServerUpdateEvent, @@ 
-59,7 +60,6 @@ from discord.state import ConnectionState from discord.types.voice import SupportedModes from discord.user import ClientUser, User - from discord.member import Member from .gateway import VoiceWebSocket from .receive.reader import AfterCallback @@ -78,9 +78,7 @@ except ImportError: has_nacl = False -__all__ = ( - "VoiceClient", -) +__all__ = ("VoiceClient",) class VoiceClient(VoiceProtocol): @@ -225,7 +223,9 @@ async def _recv_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: for uid in uids: member = self.guild.get_member(uid) if not member: - _log.warning("Skipping member referencing ID %d on member_connect", uid) + _log.warning( + "Skipping member referencing ID %d on member_connect", uid + ) continue self.dispatch("member_connect", member) elif op == OpCodes.client_disconnect: @@ -242,7 +242,9 @@ async def _recv_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: # maybe handle video and such things? - async def _run_event(self, coro, event_name: str, *args: Any, **kwargs: Any) -> None: + async def _run_event( + self, coro, event_name: str, *args: Any, **kwargs: Any + ) -> None: try: await coro(*args, **kwargs) except asyncio.CancelledError: @@ -250,9 +252,13 @@ async def _run_event(self, coro, event_name: str, *args: Any, **kwargs: Any) -> except Exception: _log.exception("Error calling %s", event_name) - def _schedule_event(self, coro, event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: + def _schedule_event( + self, coro, event_name: str, *args: Any, **kwargs: Any + ) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) - return self.client.loop.create_task(wrapped, name=f"voice-receiver-event-dispatch: {event_name}") + return self.client.loop.create_task( + wrapped, name=f"voice-receiver-event-dispatch: {event_name}" + ) def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None: _log.debug("Dispatching voice_client event %s", event) @@ -716,9 +722,13 @@ def start_recording( 
raise ClientException("Already recording audio") if len(args) > 0: - warnings.warn("'args' parameter is deprecated since 2.7 and will be removed in 3.0") + warnings.warn( + "'args' parameter is deprecated since 2.7 and will be removed in 3.0" + ) if sync_start is not MISSING: - warnings.warn("'sync_tart' parameter is deprecated since 2.7 and will be removed in 3.0") + warnings.warn( + "'sync_tart' parameter is deprecated since 2.7 and will be removed in 3.0" + ) self._reader = AudioReader(sink, self, after=callback) self._reader.start() diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index c7f63aacb9..e622b07cd7 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -35,8 +35,8 @@ from typing import TYPE_CHECKING, Any import aiohttp - import davey + from discord import utils from discord.enums import SpeakingState from discord.errors import ConnectionClosed @@ -46,8 +46,8 @@ from .enums import OpCodes if TYPE_CHECKING: - from typing_extensions import Self from _typeshed import ConvertibleToInt + from typing_extensions import Self from .state import VoiceConnectionState @@ -149,12 +149,14 @@ async def _hook(self, *args: Any) -> Any: async def send_as_bytes(self, op: ConvertibleToInt, data: bytes) -> None: packet = bytes([int(op)]) + data - _log.debug("Sending voice websocket binary frame: op: %s size: %d", op, len(data)) + _log.debug( + "Sending voice websocket binary frame: op: %s size: %d", op, len(data) + ) await self.ws.send_bytes(packet) async def send_as_json(self, data: Any) -> None: _log.debug("Sending voice websocket frame: %s.", data) - if data.get('op', None) == OpCodes.identify: + if data.get("op", None) == OpCodes.identify: _log.info("Identifying ourselves: %s", data) await self.ws.send_str(utils._to_json(data)) @@ -207,7 +209,12 @@ async def received_message(self, msg: Any, /): self._keep_alive.start() elif state.dave_session: if op == OpCodes.dave_prepare_transition: - _log.info("Preparing to upgrade to a DAVE 
connection for channel %s for transition %d proto version %d", state.channel_id, data["transition_id"], data["protocol_version"]) + _log.info( + "Preparing to upgrade to a DAVE connection for channel %s for transition %d proto version %d", + state.channel_id, + data["transition_id"], + data["protocol_version"], + ) state.dave_pending_transition = data transition_id = data["transition_id"] @@ -219,11 +226,17 @@ async def received_message(self, msg: Any, /): state.dave_session.set_passthrough_mode(True, 120) await self.send_dave_transition_ready(transition_id) elif op == OpCodes.dave_execute_transition: - _log.info("Upgrading to DAVE connection for channel %s", state.channel_id) + _log.info( + "Upgrading to DAVE connection for channel %s", state.channel_id + ) await state.execute_dave_transition(data["transition_id"]) elif op == OpCodes.dave_prepare_epoch: epoch = data["epoch"] - _log.debug("Preparing for DAVE epoch in channel %s: %s", state.channel_id, epoch) + _log.debug( + "Preparing for DAVE epoch in channel %s: %s", + state.channel_id, + epoch, + ) # if epoch is 1 then a new MLS group is to be created for the proto version if epoch == 1: state.dave_protocol_version = data["protocol_version"] @@ -236,7 +249,12 @@ async def received_message(self, msg: Any, /): async def received_binary_message(self, msg: bytes) -> None: self.seq_ack = struct.unpack_from(">H", msg, 0)[0] op = msg[2] - _log.debug("Voice websocket binary frame received: %d bytes, seq: %s, op: %s", len(msg), self.seq_ack, op) + _log.debug( + "Voice websocket binary frame received: %d bytes, seq: %s, op: %s", + len(msg), + self.seq_ack, + op, + ) state = self.state @@ -248,12 +266,20 @@ async def received_binary_message(self, msg: bytes) -> None: elif op == OpCodes.mls_proposals: op_type = msg[3] result = state.dave_session.process_proposals( - davey.ProposalsOperationType.append if op_type == 0 else davey.ProposalsOperationType.revoke, + ( + davey.ProposalsOperationType.append + if op_type == 0 + else 
davey.ProposalsOperationType.revoke + ), msg[4:], ) if isinstance(result, davey.CommitWelcome): - data = (result.commit + result.welcome) if result.welcome else result.commit + data = ( + (result.commit + result.welcome) + if result.welcome + else result.commit + ) _log.debug("Sending MLS key package with data: %s", data) await self.send_as_bytes( OpCodes.mls_commit_welcome, @@ -269,11 +295,17 @@ async def received_binary_message(self, msg: bytes) -> None: "transition_id": transt_id, "protocol_version": state.dave_protocol_version, } - _log.debug("Sending DAVE transition ready from MLS commit transition with data: %s", state.dave_pending_transition) + _log.debug( + "Sending DAVE transition ready from MLS commit transition with data: %s", + state.dave_pending_transition, + ) await self.send_dave_transition_ready(transt_id) _log.debug("Processed MLS commit for transition %s", transt_id) except Exception as exc: - _log.debug("An exception ocurred while processing a MLS commit, this should be safe to ignore: %s", exc) + _log.debug( + "An exception ocurred while processing a MLS commit, this should be safe to ignore: %s", + exc, + ) await state.recover_dave_from_invalid_commit(transt_id) elif op == OpCodes.mls_welcome: transt_id = struct.unpack_from(">H", msg, 3)[0] @@ -284,11 +316,17 @@ async def received_binary_message(self, msg: bytes) -> None: "transition_id": transt_id, "protocol_version": state.dave_protocol_version, } - _log.debug("Sending DAVE transition ready from MLS welcome with data: %s", state.dave_pending_transition) + _log.debug( + "Sending DAVE transition ready from MLS welcome with data: %s", + state.dave_pending_transition, + ) await self.send_dave_transition_ready(transt_id) _log.debug("Processed MLS welcome for transition %s", transt_id) except Exception as exc: - _log.debug("An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s", exc) + _log.debug( + "An exception ocurred while processing a MLS welcome, this should 
be safe to ignore: %s", + exc, + ) await state.recover_dave_from_invalid_commit(transt_id) async def ready(self, data: dict[str, Any]) -> None: diff --git a/discord/voice/packets/__init__.py b/discord/voice/packets/__init__.py index 454339d0bd..243189df61 100644 --- a/discord/voice/packets/__init__.py +++ b/discord/voice/packets/__init__.py @@ -8,11 +8,19 @@ from __future__ import annotations from typing import TYPE_CHECKING + from .core import Packet -from .rtp import RTPPacket, RTCPPacket, FakePacket, ReceiverReportPacket, SenderReportPacket, SilencePacket +from .rtp import ( + FakePacket, + ReceiverReportPacket, + RTCPPacket, + RTPPacket, + SenderReportPacket, + SilencePacket, +) if TYPE_CHECKING: - from discord import User, Member + from discord import Member, User __all__ = ( "Packet", @@ -41,10 +49,12 @@ class VoiceData: The PCM bytes of this source. """ - def __init__(self, packet: Packet, source: User | Member | None, *, pcm: bytes | None = None) -> None: + def __init__( + self, packet: Packet, source: User | Member | None, *, pcm: bytes | None = None + ) -> None: self.packet: Packet = packet self.source: User | Member | None = source - self.pcm: bytes = pcm if pcm else b'' + self.pcm: bytes = pcm if pcm else b"" @property def opus(self) -> bytes | None: diff --git a/discord/voice/packets/core.py b/discord/voice/packets/core.py index a372e860b7..8f7a3960ec 100644 --- a/discord/voice/packets/core.py +++ b/discord/voice/packets/core.py @@ -22,13 +22,15 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" + from __future__ import annotations + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing_extensions import Final -OPUS_SILENCE: Final = b'\xf8\xff\xfe' +OPUS_SILENCE: Final = b"\xf8\xff\xfe" class Packet: @@ -57,25 +59,31 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented if self.ssrc != other.ssrc: - raise TypeError(f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})") + raise TypeError( + f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})" + ) return self.sequence == other.sequence and self.timestamp == other.timestamp def __gt__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented if self.ssrc != other.ssrc: - raise TypeError(f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})") + raise TypeError( + f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})" + ) return self.sequence > other.sequence and self.timestamp > other.timestamp def __lt__(self, other: object) -> bool: if not isinstance(other, self.__class__): return NotImplemented if self.ssrc != other.ssrc: - raise TypeError(f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})") + raise TypeError( + f"cannot compare two packets from different ssrc ({self.ssrc=}, {other.ssrc=})" + ) return self.sequence < other.sequence and self.timestamp < other.timestamp def is_silence(self) -> bool: - data = getattr(self, 'decrypted_data', None) + data = getattr(self, "decrypted_data", None) return data == OPUS_SILENCE def __hash__(self) -> int: diff --git a/discord/voice/packets/rtp.py b/discord/voice/packets/rtp.py index 6a05c16cf6..842a69fe2a 100644 --- a/discord/voice/packets/rtp.py +++ b/discord/voice/packets/rtp.py @@ -22,10 +22,11 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" + from __future__ import annotations -from collections import namedtuple import struct +from collections import namedtuple from typing import TYPE_CHECKING, Any, Literal from .core import OPUS_SILENCE, Packet @@ -33,8 +34,8 @@ if TYPE_CHECKING: from typing_extensions import Final -MAX_UINT_32 = 0xffffffff -MAX_UINT_16 = 0xffff +MAX_UINT_32 = 0xFFFFFFFF +MAX_UINT_16 = 0xFFFF RTP_PACKET_TYPE_VOICE = 120 @@ -46,8 +47,8 @@ def decode(data: bytes) -> Packet: class FakePacket(Packet): - data = b'' - decrypted_data: bytes = b'' + data = b"" + decrypted_data: bytes = b"" extension_data: dict = {} def __init__( @@ -116,11 +117,11 @@ def __init__(self, data: bytes) -> None: self.data = data[12:] self.decrypted_data: bytes | None = None - self.nonce: bytes = b'' + self.nonce: bytes = b"" self._rtpsize: bool = False if self.cc: - fmt = '>%sI' % self.cc + fmt = ">%sI" % self.cc offset = struct.calcsize(fmt) + 12 self.csrcs = struct.unpack(fmt, data[12:offset]) self.data = data[offset:] @@ -152,7 +153,7 @@ def update_extended_header(self, data: bytes) -> int: if profile == self._ext_magic: self._parse_bede_header(data, length) - values = struct.unpack(">%sI" % length, data[4: 4 + length * 4]) + values = struct.unpack(">%sI" % length, data[4 : 4 + length * 4]) self.extension = self._ext_header(profile, length, values) offset = 4 + length * 4 @@ -168,7 +169,7 @@ def _parse_bede_header(self, data: bytes, length: int) -> None: while n < length: next_byte = data[offset : offset + 1] - if next_byte == b'\x00': + if next_byte == b"\x00": offset += 1 continue @@ -216,11 +217,11 @@ def from_data(cls, data: bytes) -> Packet: def _parse_low(x: int, bitlen: int = 32) -> float: - return x / 2.0 ** bitlen + return x / 2.0**bitlen def _to_low(x: float, bitlen: int = 32) -> int: - return int(x * 2.0 ** bitlen) + return int(x * 2.0**bitlen) class SenderReportPacket(RTCPPacket): @@ -228,7 +229,9 @@ class SenderReportPacket(RTCPPacket): _report_fmt = struct.Struct(">IB3x4I") _24bit_int_fmt = 
struct.Struct(">4xI") _info = namedtuple("RRSenderInfo", "ntp_ts rtp_ts packet_count octet_count") - _report = namedtuple("RReport", "ssrc perc_loss total_lost last_seq jitter lsr dlsr") + _report = namedtuple( + "RReport", "ssrc perc_loss total_lost last_seq jitter lsr dlsr" + ) type = 200 if TYPE_CHECKING: @@ -265,7 +268,9 @@ def _read_report(self, data: bytes, offset: int) -> _report: class ReceiverReportPacket(RTCPPacket): _report_fmt = struct.Struct(">IB3x4I") _24bit_int_fmt = struct.Struct(">4xI") - _report = namedtuple("RReport", "ssrc perc_loss total_loss last_seq jitter lsr dlsr") + _report = namedtuple( + "RReport", "ssrc perc_loss total_loss last_seq jitter lsr dlsr" + ) type = 201 reports: tuple[_report, ...] diff --git a/discord/voice/receive/reader.py b/discord/voice/receive/reader.py index 7861a97291..0329b5371a 100644 --- a/discord/voice/receive/reader.py +++ b/discord/voice/receive/reader.py @@ -22,26 +22,30 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations -from collections.abc import Callable import logging -from operator import itemgetter import threading import time +from collections.abc import Callable +from operator import itemgetter from typing import TYPE_CHECKING, Any, Literal import davey + from discord.opus import PacketDecoder +from ..packets.rtp import ReceiverReportPacket, decode from .router import PacketRouter, SinkEventRouter -from ..packets.rtp import decode, ReceiverReportPacket try: import nacl.secret from nacl.exceptions import CryptoError except ImportError as exc: - raise RuntimeError("can't use voice receiver without PyNaCl installed, please install it with the 'py-cord[voice]' extra.") from exc + raise RuntimeError( + "can't use voice receiver without PyNaCl installed, please install it with the 'py-cord[voice]' extra." 
+ ) from exc if TYPE_CHECKING: @@ -60,9 +64,7 @@ _log = logging.getLogger(__name__) -__all__ = ( - "AudioReader", -) +__all__ = ("AudioReader",) def is_rtcp(data: bytes) -> bool: @@ -70,9 +72,13 @@ def is_rtcp(data: bytes) -> bool: class AudioReader: - def __init__(self, sink: Sink, client: VoiceClient, *, after: AfterCallback | None = None) -> None: + def __init__( + self, sink: Sink, client: VoiceClient, *, after: AfterCallback | None = None + ) -> None: if after is not None and not callable(after): - raise TypeError(f"expected a callable for the 'after' parameter, got {after.__class__.__name__!r} instead") + raise TypeError( + f"expected a callable for the 'after' parameter, got {after.__class__.__name__!r} instead" + ) self.sink: Sink = sink self.client: VoiceClient = client @@ -84,7 +90,9 @@ def __init__(self, sink: Sink, client: VoiceClient, *, after: AfterCallback | No self.error: Exception | None = None self.packet_router: PacketRouter = PacketRouter(self.sink, self) self.event_router: SinkEventRouter = SinkEventRouter(self.sink, self) - self.decryptor: PacketDecryptor = PacketDecryptor(client.mode, bytes(client.secret_key), client) + self.decryptor: PacketDecryptor = PacketDecryptor( + client.mode, bytes(client.secret_key), client + ) self.speaking_timer: SpeakingTimer = SpeakingTimer(self) self.keep_alive: UDPKeepAlive = UDPKeepAlive(client) @@ -115,7 +123,9 @@ def stop(self) -> None: self.active = False self.speaking_timer.notify() - threading.Thread(target=self._stop, name=f"voice-receiver-audio-reader-stop:{id(self):#x}").start() + threading.Thread( + target=self._stop, name=f"voice-receiver-audio-reader-stop:{id(self):#x}" + ).start() def _stop(self) -> None: try: @@ -137,7 +147,9 @@ def _stop(self) -> None: try: self.after(self.error) except Exception: - _log.exception("An error ocurred while calling the after callback on audio reader") + _log.exception( + "An error ocurred while calling the after callback on audio reader" + ) for sink in 
self.sink.root.walk_children(with_self=True): try: @@ -168,7 +180,11 @@ def callback(self, packet_data: bytes) -> None: packet = rtcp_packet = decode(packet_data) if not isinstance(packet, ReceiverReportPacket): - _log.info("Received unexpected rtcp packet type=%s, %s", packet.type, type(packet)) + _log.info( + "Received unexpected rtcp packet type=%s, %s", + packet.type, + type(packet), + ) except CryptoError as exc: _log.error("CryptoError while decoding a voice packet", exc_info=exc) return @@ -176,7 +192,9 @@ def callback(self, packet_data: bytes) -> None: if self._is_ip_discovery_packet(packet_data): _log.debug("Received an IP Discovery Packet, ignoring...") return - _log.exception("An exception ocurred while decoding voice packets", exc_info=exc) + _log.exception( + "An exception ocurred while decoding voice packets", exc_info=exc + ) finally: if self.error: self.stop() @@ -193,14 +211,18 @@ def callback(self, packet_data: bytes) -> None: if rtp_packet.is_silence(): return else: - _log.info("Received a packet for unknown SSRC %s: %s", ssrc, rtp_packet) + _log.info( + "Received a packet for unknown SSRC %s: %s", ssrc, rtp_packet + ) self.speaking_timer.notify(ssrc) try: self.packet_router.feed_rtp(rtp_packet) # type: ignore except Exception as exc: - _log.exception("An error ocurred while processing RTP packet %s", rtp_packet) + _log.exception( + "An error ocurred while processing RTP packet %s", rtp_packet + ) self.error = exc self.stop() @@ -213,13 +235,15 @@ class PacketDecryptor: "xsalsa20_poly1305_suffix", ] - def __init__(self, mode: SupportedModes, secret_key: bytes, client: VoiceClient) -> None: + def __init__( + self, mode: SupportedModes, secret_key: bytes, client: VoiceClient + ) -> None: self.mode: SupportedModes = mode self.client: VoiceClient = client try: - self._decryptor_rtp: DecryptRTP = getattr(self, '_decrypt_rtp_' + mode) - self._decryptor_rtcp: DecryptRTCP = getattr(self, '_decrypt_rtcp_' + mode) + self._decryptor_rtp: DecryptRTP = 
getattr(self, "_decrypt_rtp_" + mode) + self._decryptor_rtcp: DecryptRTCP = getattr(self, "_decrypt_rtcp_" + mode) except AttributeError as exc: raise NotImplementedError(mode) from exc @@ -237,7 +261,9 @@ def decrypt_rtp(self, packet: RTPPacket) -> bytes: data = self._decryptor_rtp(packet) if dave is not None and dave.ready and packet.ssrc in state.user_ssrc_map: - return dave.decrypt(state.user_ssrc_map[packet.ssrc], davey.MediaType.audio, data) + return dave.decrypt( + state.user_ssrc_map[packet.ssrc], davey.MediaType.audio, data + ) return data def decrypt_rtcp(self, packet: bytes) -> bytes: @@ -445,7 +471,9 @@ def run(self) -> None: continue try: - vc._connection.socket.sendto(packet, (vc._connection.endpoint_ip, vc._connection.voice_port)) + vc._connection.socket.sendto( + packet, (vc._connection.endpoint_ip, vc._connection.voice_port) + ) except Exception as exc: _log.debug( "Error while sending udp keep alive to socket %s at %s:%s", diff --git a/discord/voice/receive/router.py b/discord/voice/receive/router.py index 94e9f5bb72..0d2445540b 100644 --- a/discord/voice/receive/router.py +++ b/discord/voice/receive/router.py @@ -22,14 +22,15 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" + from __future__ import annotations +import logging +import queue +import threading from collections import deque from collections.abc import Callable -import threading -import logging from typing import TYPE_CHECKING, Any -import queue from discord.opus import PacketDecoder @@ -38,8 +39,8 @@ if TYPE_CHECKING: from discord.sinks import Sink + from ..packets import RTCPPacket, RTPPacket from .reader import AudioReader - from ..packets import RTPPacket, RTCPPacket EventCB = Callable[..., Any] EventData = tuple[str, tuple[Any, ...], dict[str, Any]] @@ -51,7 +52,7 @@ class PacketRouter(threading.Thread): def __init__(self, sink: Sink, reader: AudioReader) -> None: super().__init__( daemon=True, - name=f'voice-receiver-packet-router:{id(self):#x}', + name=f"voice-receiver-packet-router:{id(self):#x}", ) self.sink: Sink = sink @@ -75,7 +76,7 @@ def feed_rtp(self, packet: RTPPacket) -> None: def feed_rtcp(self, packet: RTCPPacket) -> None: guild = self.sink.client.guild if self.sink.client else None event_router = self.reader.event_router - event_router.dispatch('rtcp_packet', packet, guild) + event_router.dispatch("rtcp_packet", packet, guild) def get_decoder(self, ssrc: int) -> PacketDecoder | None: with self._lock: @@ -136,7 +137,9 @@ def _do_run(self) -> None: class SinkEventRouter(threading.Thread): def __init__(self, sink: Sink, reader: AudioReader) -> None: - super().__init__(daemon=True, name=f"voice-receiver-sink-event-router:{id(self):#x}") + super().__init__( + daemon=True, name=f"voice-receiver-sink-event-router:{id(self):#x}" + ) self.sink: Sink = sink self.reader: AudioReader = reader @@ -197,7 +200,13 @@ def _dispatch_to_listeners(self, event: str, *args: Any, **kwargs: Any) -> None: try: listener(*args, **kwargs) except Exception as exc: - _log.exception("Unhandled exception while dispatching event %s (args: %s; kwargs: %s)", event, args, kwargs, exc_info=exc) + _log.exception( + "Unhandled exception while dispatching event %s (args: %s; kwargs: %s)", 
+ event, + args, + kwargs, + exc_info=exc, + ) def stop(self) -> None: self._end_thread.set() diff --git a/discord/voice/state.py b/discord/voice/state.py index deba6ee9e3..9b17e8e5e8 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -29,11 +29,9 @@ import logging import select import socket -import struct import threading -import time from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any, TypedDict +from typing import TYPE_CHECKING, Any from discord import opus, utils from discord.backoff import ExponentialBackoff @@ -205,9 +203,7 @@ def _do_run(self) -> None: loop=self.state.loop, ) self.state._dispatch_task_set.add(task) - task.add_done_callback( - self.state._dispatch_task_set.discard - ) + task.add_done_callback(self.state._dispatch_task_set.discard) except Exception: _log.exception( "Error while calling %s in %s", @@ -915,7 +911,9 @@ def _update_voice_channel(self, channel_id: int | None) -> None: async def reinit_dave_session(self) -> None: if self.dave_protocol_version > 0: if self.dave_session: - self.dave_session.reinit(self.dave_protocol_version, self.user.id, self.channel_id) + self.dave_session.reinit( + self.dave_protocol_version, self.user.id, self.channel_id + ) else: self.dave_session = davey.DaveSession( self.dave_protocol_version, @@ -959,8 +957,13 @@ async def execute_dave_transition(self, transition: int) -> None: old_version = self.dave_protocol_version self.dave_protocol_version = pending_proto - if old_version != self.dave_protocol_version and self.dave_protocol_version == 0: - _log.warning("DAVE was downgraded, voice client non-e2ee session has been deprecated since 2.7") + if ( + old_version != self.dave_protocol_version + and self.dave_protocol_version == 0 + ): + _log.warning( + "DAVE was downgraded, voice client non-e2ee session has been deprecated since 2.7" + ) self.downgraded_dave = True elif transition > 0 and self.downgraded_dave: self.downgraded_dave = False diff --git 
a/discord/voice/utils/buffer.py b/discord/voice/utils/buffer.py index 91c83dfe80..f12fb0ae6e 100644 --- a/discord/voice/utils/buffer.py +++ b/discord/voice/utils/buffer.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import heapq @@ -29,10 +30,8 @@ import threading from typing import Protocol, TypeVar -from .wrapped import gap_wrapped, add_wrapped - from ..packets import Packet - +from .wrapped import add_wrapped, gap_wrapped __all__ = ( "Buffer", @@ -82,7 +81,9 @@ def reset(self) -> None: class JitterBuffer(BaseBuff[PacketT]): _threshold: int = 10000 - def __init__(self, max_size: int = 10, *, pref_size: int = 1, prefill: int = 1) -> None: + def __init__( + self, max_size: int = 10, *, pref_size: int = 1, prefill: int = 1 + ) -> None: if max_size < 1: raise ValueError(f"max_size must be greater than 1, not {max_size}") @@ -95,7 +96,7 @@ def __init__(self, max_size: int = 10, *, pref_size: int = 1, prefill: int = 1) self._prefill: int = prefill self._last_tx_seq: int = -1 self._has_item: threading.Event = threading.Event() - #self._lock: threading.Lock = threading.Lock() + # self._lock: threading.Lock = threading.Lock() self._buffer: list[Packet] = [] def _push(self, packet: Packet) -> None: @@ -122,19 +123,26 @@ def _update_has_item(self) -> None: sequential = add_wrapped(self._last_tx_seq, 1) == next_packet.sequence positive_seq = self._last_tx_seq >= 0 - if (sequential and positive_seq) or not positive_seq or len(self._buffer) >= self.max_size: + if ( + (sequential and positive_seq) + or not positive_seq + or len(self._buffer) >= self.max_size + ): self._has_item.set() else: self._has_item.clear() def _cleanup(self) -> None: while len(self._buffer) > self.max_size: - packet = heapq.heappop(self._buffer) + heapq.heappop(self._buffer) def push(self, packet: Packet) -> bool: seq = packet.sequence - if gap_wrapped(self._last_tx_seq, seq) > self._threshold and 
self._last_tx_seq != -1: + if ( + gap_wrapped(self._last_tx_seq, seq) > self._threshold + and self._last_tx_seq != -1 + ): _log.debug("Dropping old packet %s", packet) return False @@ -178,7 +186,10 @@ def peek_next(self) -> Packet | None: if packet is None: return None - if packet.sequence == add_wrapped(self._last_tx_seq, 1) or self._last_tx_seq < 0: + if ( + packet.sequence == add_wrapped(self._last_tx_seq, 1) + or self._last_tx_seq < 0 + ): return packet def gap(self) -> int: diff --git a/discord/voice/utils/multidataevent.py b/discord/voice/utils/multidataevent.py index ea175f332b..e0079b5a54 100644 --- a/discord/voice/utils/multidataevent.py +++ b/discord/voice/utils/multidataevent.py @@ -22,6 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import threading @@ -75,4 +76,4 @@ def unregister(self, item: T) -> None: def clear(self) -> None: self._items.clear() - self._ready.clear() \ No newline at end of file + self._ready.clear() diff --git a/discord/voice/utils/wrapped.py b/discord/voice/utils/wrapped.py index 8d096a0713..fb2aafc630 100644 --- a/discord/voice/utils/wrapped.py +++ b/discord/voice/utils/wrapped.py @@ -23,9 +23,10 @@ DEALINGS IN THE SOFTWARE. """ + def gap_wrapped(a: int, b: int, *, wrap: int = 65536) -> int: return (b - (a + 1) + wrap) % wrap def add_wrapped(a: int, b: int, *, wrap: int = 65536) -> int: - return (a + b) % wrap \ No newline at end of file + return (a + b) % wrap