Skip to content

Commit 8304dba

Browse files
committed
add first pass to implement dave
1 parent a0f2a99 commit 8304dba

File tree

7 files changed

+223
-31
lines changed

7 files changed

+223
-31
lines changed

discord/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from .invite import Invite
5151
from .iterators import EntitlementIterator, GuildIterator
5252
from .mentions import AllowedMentions
53-
from .monetization import SKU, Entitlement
53+
from .monetization import SKU
5454
from .object import Object
5555
from .soundboard import SoundboardSound
5656
from .stage_instance import StageInstance

discord/gateway.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141

4242
from . import utils
4343
from .activity import BaseActivity
44-
from .enums import SpeakingState
4544
from .errors import ConnectionClosed, InvalidArgument
4645

4746
if TYPE_CHECKING:
@@ -55,8 +54,6 @@
5554
__all__ = (
5655
"DiscordWebSocket",
5756
"KeepAliveHandler",
58-
"VoiceKeepAliveHandler",
59-
"DiscordVoiceWebSocket",
6057
"ReconnectWebSocket",
6158
)
6259

@@ -228,31 +225,6 @@ def ack(self) -> None:
228225
_log.warning(self.behind_msg, self.shard_id, self.latency)
229226

230227

231-
class VoiceKeepAliveHandler(KeepAliveHandler):
232-
if TYPE_CHECKING:
233-
ws: DiscordVoiceWebSocket
234-
235-
def __init__(self, *args, **kwargs):
236-
super().__init__(*args, **kwargs)
237-
self.recent_ack_latencies = deque(maxlen=20)
238-
self.msg = "Keeping shard ID %s voice websocket alive with timestamp %s."
239-
self.block_msg = "Shard ID %s voice heartbeat blocked for more than %s seconds"
240-
self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind"
241-
242-
def get_payload(self):
243-
return {
244-
"op": self.ws.HEARTBEAT,
245-
"d": {"t": int(time.time() * 1000), "seq_ack": self.ws.seq_ack},
246-
}
247-
248-
def ack(self):
249-
ack_time = time.perf_counter()
250-
self._last_ack = ack_time
251-
self._last_recv = ack_time
252-
self.latency = ack_time - self._last_send
253-
self.recent_ack_latencies.append(self.latency)
254-
255-
256228
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
257229
async def close(self, *, code: int = 4000, message: bytes = b"") -> bool:
258230
return await super().close(code=code, message=message)

discord/voice/client.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,14 @@ def average_latency(self) -> float:
241241
ws = self.ws
242242
return float("inf") if not ws else ws.average_latency
243243

244+
@property
245+
def privacy_code(self) -> str | None:
246+
"""Returns the current voice session's privacy code, only available if the call has upgraded to use the
247+
DAVE protocol
248+
"""
249+
session = self._connection.dave_session
250+
return session and session.voice_privacy_code
251+
244252
async def disconnect(self, *, force: bool = False) -> None:
245253
"""|coro|
246254
@@ -288,6 +296,10 @@ def is_paused(self) -> bool:
288296
# audio related
289297

290298
def _get_voice_packet(self, data: Any) -> bytes:
299+
300+
session = self._connection.dave_session
301+
packet = session.encrypt_opus(data) if session and session.ready else data
302+
291303
header = bytearray(12)
292304

293305
# formulate rtp header
@@ -298,7 +310,7 @@ def _get_voice_packet(self, data: Any) -> bytes:
298310
struct.pack_into(">I", header, 8, self.ssrc)
299311

300312
encrypt_packet = getattr(self, f"_encrypt_{self.mode}")
301-
return encrypt_packet(header, data)
313+
return encrypt_packet(header, packet)
302314

303315
# encryption methods
304316

discord/voice/enums.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,19 @@ class OpCodes(Enum):
4242
client_connect = 10
4343
client_disconnect = 11
4444

45+
# dave protocol stuff
46+
dave_prepare_transition = 21
47+
dave_execute_transition = 22
48+
dave_transition_ready = 23
49+
dave_prepare_epoch = 24
50+
mls_external_sender_package = 25
51+
mls_key_package = 26
52+
mls_proposals = 27
53+
mls_commit_welcome = 28
54+
mls_commit_transition = 29
55+
mls_welcome = 30
56+
mls_invalid_commit_welcome = 31
57+
4558
def __eq__(self, other: object) -> bool:
4659
if isinstance(other, int):
4760
return self.value == other

discord/voice/gateway.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import aiohttp
3838

39+
import davey
3940
from discord import utils
4041
from discord.enums import SpeakingState
4142
from discord.errors import ConnectionClosed
@@ -117,6 +118,7 @@ def __init__(
117118
self.seq_ack: int = -1
118119
self.state: VoiceConnectionState = state
119120
self.ssrc_map: dict[str, dict[str, Any]] = {}
121+
self.known_users: dict[int, Any] = {}
120122

121123
if hook:
122124
self._hook = hook or state.ws_hook # type: ignore
@@ -137,9 +139,22 @@ def session_id(self) -> str | None:
137139
def session_id(self, value: str | None) -> None:
138140
self.state.session_id = value
139141

142+
@property
143+
def dave_session(self) -> davey.DaveSession | None:
144+
return self.state.dave_session
145+
146+
@property
147+
def self_id(self) -> int:
148+
return self._connection.self_id
149+
140150
async def _hook(self, *args: Any) -> Any:
141151
pass
142152

153+
async def send_as_bytes(self, op: int, data: bytes) -> None:
154+
packet = bytes(op) + data
155+
_log.debug("Sending voice websocket binary frame: op: %s data: %s", op, str(data))
156+
await self.ws.send_bytes(packet)
157+
143158
async def send_as_json(self, data: Any) -> None:
144159
_log.debug("Sending voice websocket frame: %s.", data)
145160
await self.ws.send_str(utils._to_json(data))
@@ -163,6 +178,7 @@ async def received_message(self, msg: Any, /):
163178
op = msg["op"]
164179
data = msg.get("d", {}) # this key should ALWAYS be given, but guard anyways
165180
self.seq_ack = msg.get("seq", self.seq_ack) # keep the seq_ack updated
181+
state = self.state
166182

167183
if op == OpCodes.ready:
168184
await self.ready(data)
@@ -179,18 +195,99 @@ async def received_message(self, msg: Any, /):
179195
"successfully RESUMED.",
180196
)
181197
elif op == OpCodes.session_description:
182-
self.state.mode = data["mode"]
198+
state.mode = data["mode"]
199+
state.dave_protocol_version = data["dave_protocol_version"]
183200
await self.load_secret_key(data)
201+
await state.reinit_dave_session()
184202
elif op == OpCodes.hello:
185203
interval = data["heartbeat_interval"] / 1000.0
186204
self._keep_alive = KeepAliveHandler(
187205
ws=self,
188206
interval=min(interval, 5),
189207
)
190208
self._keep_alive.start()
209+
elif self.dave_session:
210+
if op == OpCodes.dave_prepare_transition:
211+
_log.info("Preparing to upgrade to a DAVE connection for channel %s", state.channel_id)
212+
state.dave_pending_transition = data
213+
214+
transition_id = data["transition_id"]
215+
216+
if transition_id == 0:
217+
await state.execute_dave_transition(data["transition_id"])
218+
else:
219+
if data["protocol_version"] == 0:
220+
self.dave_session.set_passthrough_mode(True, 120)
221+
await self.send_dave_transition_ready(transition_id)
222+
elif op == OpCodes.dave_execute_transition:
223+
_log.info("Upgrading to DAVE connection for channel %s", state.channel_id)
224+
await state.execute_dave_transition(data["transition_id"])
225+
elif op == OpCodes.dave_prepare_epoch:
226+
epoch = data["epoch"]
227+
_log.debug("Preparing for DAVE epoch in channel %s: %s", state.channel_id, epoch)
228+
# if epoch is 1 then a new MLS group is to be created for the proto version
229+
if epoch == 1:
230+
state.dave_protocol_version = data["protocol_version"]
231+
await state.reinit_dave_session()
232+
else:
233+
_log.debug("Unhandled op code: %s with data %s", op, data)
191234

192235
await utils.maybe_coroutine(self._hook, self, msg)
193236

237+
async def received_binary_message(self, msg: bytes) -> None:
238+
self.seq_ack = struct.unpack_from(">H", msg, 0)[0]
239+
op = msg[2]
240+
_log.debug("Voice websocket binary frame received: %d bytes, seq: %s, op: %s", len(msg), self.seq_ack, op)
241+
242+
state = self.state
243+
244+
if not self.dave_session:
245+
return
246+
247+
if op == OpCodes.mls_external_sender_package:
248+
self.dave_session.set_external_sender(msg[3:])
249+
elif op == OpCodes.mls_proposals:
250+
op_type = msg[3]
251+
result = self.dave_session.process_proposals(
252+
davey.ProposalsOperationType.append if op_type == 0 else davey.ProposalsOperationType.revoke,
253+
msg[4:],
254+
)
255+
256+
if isinstance(result, davey.CommitWelcome):
257+
await self.send_as_bytes(
258+
OpCodes.mls_key_package.value,
259+
(result.commit + result.welcome) if result.welcome else result.commit,
260+
)
261+
_log.debug("Processed MLS proposals for current dave session")
262+
elif op == OpCodes.mls_commit_transition:
263+
transt_id = struct.unpack_from(">H", msg, 3)[0]
264+
try:
265+
self.dave_session.process_commit(msg[5:])
266+
if transt_id != 0:
267+
state.dave_pending_transition = {
268+
"transition_id": transt_id,
269+
"protocol_version": state.dave_protocol_version,
270+
}
271+
await self.send_dave_transition_ready(transt_id)
272+
_log.debug("Processed MLS commit for transition %s", transt_id)
273+
except Exception as exc:
274+
_log.debug("An exception ocurred while processing a MLS commit, this should be safe to ignore: %s", exc)
275+
await state.recover_dave_from_invalid_commit(transt_id)
276+
elif op == OpCodes.mls_welcome:
277+
transt_id = struct.unpack_from(">H", msg, 3)[0]
278+
try:
279+
self.dave_session.process_welcome(msg[5:])
280+
if transt_id != 0:
281+
state.dave_pending_transition = {
282+
"transition_id": transt_id,
283+
"protocol_version": state.dave_protocol_version,
284+
}
285+
await self.send_dave_transition_ready(transt_id)
286+
_log.debug("Processed MLS welcome for transition %s", transt_id)
287+
except Exception as exc:
288+
_log.debug("An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s", exc)
289+
await state.recover_dave_from_invalid_commit(transt_id)
290+
194291
async def ready(self, data: dict[str, Any]) -> None:
195292
state = self.state
196293

@@ -232,6 +329,7 @@ async def select_protocol(self, ip: str, port: int, mode: str) -> None:
232329
"port": port,
233330
"mode": mode,
234331
},
332+
"dave_protocol_version": self.state.dave_protocol_version,
235333
},
236334
}
237335
await self.send_as_json(payload)
@@ -292,6 +390,8 @@ async def poll_event(self) -> None:
292390

293391
if msg.type is aiohttp.WSMsgType.TEXT:
294392
await self.received_message(utils._from_json(msg.data))
393+
elif msg.type is aiohttp.WSMsgType.BINARY:
394+
await self.received_binary_message(msg.data)
295395
elif msg.type is aiohttp.WSMsgType.ERROR:
296396
_log.debug("Received %s", msg)
297397
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
@@ -355,6 +455,16 @@ async def identify(self) -> None:
355455
"user_id": str(state.user.id),
356456
"session_id": self.session_id,
357457
"token": self.token,
458+
"max_dave_protocol_version": self.state.max_dave_proto_version,
459+
},
460+
}
461+
await self.send_as_json(payload)
462+
463+
async def send_dave_transition_ready(self, transition_id: int) -> None:
464+
payload = {
465+
"op": int(OpCodes.dave_transition_ready),
466+
"d": {
467+
"transition_id": transition_id,
358468
},
359469
}
360470
await self.send_as_json(payload)

0 commit comments

Comments
 (0)