3636
3737import aiohttp
3838
39+ import davey
3940from discord import utils
4041from discord .enums import SpeakingState
4142from discord .errors import ConnectionClosed
@@ -117,6 +118,7 @@ def __init__(
117118 self .seq_ack : int = - 1
118119 self .state : VoiceConnectionState = state
119120 self .ssrc_map : dict [str , dict [str , Any ]] = {}
121+ self .known_users : dict [int , Any ] = {}
120122
121123 if hook :
122124 self ._hook = hook or state .ws_hook # type: ignore
@@ -137,9 +139,22 @@ def session_id(self) -> str | None:
137139 def session_id (self , value : str | None ) -> None :
138140 self .state .session_id = value
139141
142+ @property
143+ def dave_session (self ) -> davey .DaveSession | None :
144+ return self .state .dave_session
145+
146+ @property
147+ def self_id (self ) -> int :
148+ return self ._connection .self_id
149+
140150 async def _hook (self , * args : Any ) -> Any :
141151 pass
142152
153+ async def send_as_bytes (self , op : int , data : bytes ) -> None :
154+ packet = bytes (op ) + data
155+ _log .debug ("Sending voice websocket binary frame: op: %s data: %s" , op , str (data ))
156+ await self .ws .send_bytes (packet )
157+
143158 async def send_as_json (self , data : Any ) -> None :
144159 _log .debug ("Sending voice websocket frame: %s." , data )
145160 await self .ws .send_str (utils ._to_json (data ))
@@ -163,6 +178,7 @@ async def received_message(self, msg: Any, /):
163178 op = msg ["op" ]
164179 data = msg .get ("d" , {}) # this key should ALWAYS be given, but guard anyways
165180 self .seq_ack = msg .get ("seq" , self .seq_ack ) # keep the seq_ack updated
181+ state = self .state
166182
167183 if op == OpCodes .ready :
168184 await self .ready (data )
@@ -179,18 +195,99 @@ async def received_message(self, msg: Any, /):
179195 "successfully RESUMED." ,
180196 )
181197 elif op == OpCodes .session_description :
182- self .state .mode = data ["mode" ]
198+ state .mode = data ["mode" ]
199+ state .dave_protocol_version = data ["dave_protocol_version" ]
183200 await self .load_secret_key (data )
201+ await state .reinit_dave_session ()
184202 elif op == OpCodes .hello :
185203 interval = data ["heartbeat_interval" ] / 1000.0
186204 self ._keep_alive = KeepAliveHandler (
187205 ws = self ,
188206 interval = min (interval , 5 ),
189207 )
190208 self ._keep_alive .start ()
209+ elif self .dave_session :
210+ if op == OpCodes .dave_prepare_transition :
211+ _log .info ("Preparing to upgrade to a DAVE connection for channel %s" , state .channel_id )
212+ state .dave_pending_transition = data
213+
214+ transition_id = data ["transition_id" ]
215+
216+ if transition_id == 0 :
217+ await state .execute_dave_transition (data ["transition_id" ])
218+ else :
219+ if data ["protocol_version" ] == 0 :
220+ self .dave_session .set_passthrough_mode (True , 120 )
221+ await self .send_dave_transition_ready (transition_id )
222+ elif op == OpCodes .dave_execute_transition :
223+ _log .info ("Upgrading to DAVE connection for channel %s" , state .channel_id )
224+ await state .execute_dave_transition (data ["transition_id" ])
225+ elif op == OpCodes .dave_prepare_epoch :
226+ epoch = data ["epoch" ]
227+ _log .debug ("Preparing for DAVE epoch in channel %s: %s" , state .channel_id , epoch )
228+ # if epoch is 1 then a new MLS group is to be created for the proto version
229+ if epoch == 1 :
230+ state .dave_protocol_version = data ["protocol_version" ]
231+ await state .reinit_dave_session ()
232+ else :
233+ _log .debug ("Unhandled op code: %s with data %s" , op , data )
191234
192235 await utils .maybe_coroutine (self ._hook , self , msg )
193236
237+ async def received_binary_message (self , msg : bytes ) -> None :
238+ self .seq_ack = struct .unpack_from (">H" , msg , 0 )[0 ]
239+ op = msg [2 ]
240+ _log .debug ("Voice websocket binary frame received: %d bytes, seq: %s, op: %s" , len (msg ), self .seq_ack , op )
241+
242+ state = self .state
243+
244+ if not self .dave_session :
245+ return
246+
247+ if op == OpCodes .mls_external_sender_package :
248+ self .dave_session .set_external_sender (msg [3 :])
249+ elif op == OpCodes .mls_proposals :
250+ op_type = msg [3 ]
251+ result = self .dave_session .process_proposals (
252+ davey .ProposalsOperationType .append if op_type == 0 else davey .ProposalsOperationType .revoke ,
253+ msg [4 :],
254+ )
255+
256+ if isinstance (result , davey .CommitWelcome ):
257+ await self .send_as_bytes (
258+ OpCodes .mls_key_package .value ,
259+ (result .commit + result .welcome ) if result .welcome else result .commit ,
260+ )
261+ _log .debug ("Processed MLS proposals for current dave session" )
262+ elif op == OpCodes .mls_commit_transition :
263+ transt_id = struct .unpack_from (">H" , msg , 3 )[0 ]
264+ try :
265+ self .dave_session .process_commit (msg [5 :])
266+ if transt_id != 0 :
267+ state .dave_pending_transition = {
268+ "transition_id" : transt_id ,
269+ "protocol_version" : state .dave_protocol_version ,
270+ }
271+ await self .send_dave_transition_ready (transt_id )
272+ _log .debug ("Processed MLS commit for transition %s" , transt_id )
273+ except Exception as exc :
274+ _log .debug ("An exception ocurred while processing a MLS commit, this should be safe to ignore: %s" , exc )
275+ await state .recover_dave_from_invalid_commit (transt_id )
276+ elif op == OpCodes .mls_welcome :
277+ transt_id = struct .unpack_from (">H" , msg , 3 )[0 ]
278+ try :
279+ self .dave_session .process_welcome (msg [5 :])
280+ if transt_id != 0 :
281+ state .dave_pending_transition = {
282+ "transition_id" : transt_id ,
283+ "protocol_version" : state .dave_protocol_version ,
284+ }
285+ await self .send_dave_transition_ready (transt_id )
286+ _log .debug ("Processed MLS welcome for transition %s" , transt_id )
287+ except Exception as exc :
288+ _log .debug ("An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s" , exc )
289+ await state .recover_dave_from_invalid_commit (transt_id )
290+
194291 async def ready (self , data : dict [str , Any ]) -> None :
195292 state = self .state
196293
@@ -232,6 +329,7 @@ async def select_protocol(self, ip: str, port: int, mode: str) -> None:
232329 "port" : port ,
233330 "mode" : mode ,
234331 },
332+ "dave_protocol_version" : self .state .dave_protocol_version ,
235333 },
236334 }
237335 await self .send_as_json (payload )
@@ -292,6 +390,8 @@ async def poll_event(self) -> None:
292390
293391 if msg .type is aiohttp .WSMsgType .TEXT :
294392 await self .received_message (utils ._from_json (msg .data ))
393+ elif msg .type is aiohttp .WSMsgType .BINARY :
394+ await self .received_binary_message (msg .data )
295395 elif msg .type is aiohttp .WSMsgType .ERROR :
296396 _log .debug ("Received %s" , msg )
297397 raise ConnectionClosed (self .ws , shard_id = None ) from msg .data
@@ -355,6 +455,16 @@ async def identify(self) -> None:
355455 "user_id" : str (state .user .id ),
356456 "session_id" : self .session_id ,
357457 "token" : self .token ,
458+ "max_dave_protocol_version" : self .state .max_dave_proto_version ,
459+ },
460+ }
461+ await self .send_as_json (payload )
462+
463+ async def send_dave_transition_ready (self , transition_id : int ) -> None :
464+ payload = {
465+ "op" : int (OpCodes .dave_transition_ready ),
466+ "d" : {
467+ "transition_id" : transition_id ,
358468 },
359469 }
360470 await self .send_as_json (payload )
0 commit comments