1515//! The REDECRYPTOR is a layer that handles redecryption of events in case we
1616//! couldn't decrypt them imediatelly
1717
18- use std:: sync:: Weak ;
18+ use std:: { collections :: BTreeSet , pin :: Pin , sync:: Weak } ;
1919
2020use as_variant:: as_variant;
2121use futures_core:: Stream ;
@@ -25,30 +25,36 @@ use matrix_sdk_base::{
2525 deserialized_responses:: { DecryptedRoomEvent , TimelineEvent , TimelineEventKind } ,
2626} ;
2727use matrix_sdk_common:: executor:: spawn;
28- use ruma:: { OwnedEventId , RoomId , events:: AnySyncTimelineEvent , serde:: Raw } ;
29- use tokio:: task:: JoinHandle ;
30- use tokio_stream:: wrappers:: errors:: BroadcastStreamRecvError ;
31- use tracing:: { info, instrument, trace, warn} ;
32-
33- use crate :: {
34- Client ,
35- event_cache:: {
36- EventCache , EventCacheError , EventCacheInner , EventsOrigin , RoomEventCacheUpdate ,
37- } ,
28+ use ruma:: { OwnedEventId , OwnedRoomId , RoomId , events:: AnySyncTimelineEvent , serde:: Raw } ;
29+ use tokio:: { sync:: mpsc:: UnboundedSender , task:: JoinHandle } ;
30+ use tokio_stream:: wrappers:: { UnboundedReceiverStream , errors:: BroadcastStreamRecvError } ;
31+ use tracing:: { instrument, trace, warn} ;
32+
33+ use crate :: event_cache:: {
34+ EventCache , EventCacheError , EventCacheInner , EventsOrigin , RoomEventCacheUpdate ,
3835} ;
3936
37+ /// The information sent across the channel to the long-running task requesting
38+ /// that the supplied set of sessions be retried.
39+ pub struct DecryptionRetryRequest {
40+ room_id : OwnedRoomId ,
41+ session_ids : BTreeSet < String > ,
42+ }
43+
44+ type SessionId < ' a > = & ' a str ;
45+
4046impl EventCache {
4147 async fn get_utds (
4248 & self ,
43- room_key_info : & RoomKeyInfo ,
49+ room_id : & RoomId ,
50+ session_id : SessionId < ' _ > ,
4451 ) -> Result < Vec < ( OwnedEventId , Raw < AnySyncTimelineEvent > ) > , EventCacheError > {
4552 let filter_non_utds = |event : TimelineEvent | {
4653 let event_id = event. event_id ( ) ;
54+
4755 // We only care about events fort his particular room key, identified by the
4856 // session ID.
49- let session_id = event. kind . session_id ( ) ;
50-
51- if session_id == Some ( & room_key_info. session_id ) {
57+ if event. kind . session_id ( ) == Some ( session_id) {
5258 // Only pick out events that are UTDs, get just the Raw event as this is what
5359 // the OlmMachine needs.
5460 let event = as_variant ! ( event. kind, TimelineEventKind :: UnableToDecrypt { event, .. } => event) ;
@@ -65,7 +71,7 @@ impl EventCache {
6571 //
6672 // TODO: We can't load **all** events all the time.
6773 let store = self . inner . store . lock ( ) . await ?;
68- let events = store. get_room_events ( & room_key_info . room_id ) . await ?;
74+ let events = store. get_room_events ( & room_id) . await ?;
6975
7076 Ok ( events. into_iter ( ) . filter_map ( filter_non_utds) . collect ( ) )
7177 }
@@ -128,29 +134,32 @@ impl EventCache {
128134 /// Attempt to redecrypt events after a room key with the given session ID
129135 /// has been received.
130136 #[ instrument( skip_all, fields( room_key_info) ) ]
131- async fn retry_decryption ( & self , room_key_info : RoomKeyInfo ) -> Result < ( ) , EventCacheError > {
137+ async fn retry_decryption (
138+ & self ,
139+ room_id : & RoomId ,
140+ session_id : SessionId < ' _ > ,
141+ ) -> Result < ( ) , EventCacheError > {
132142 trace ! ( "Retrying to decrypt" ) ;
133143
134- let events = self . get_utds ( & room_key_info ) . await ?;
144+ let events = self . get_utds ( room_id , session_id ) . await ?;
135145 let mut decrypted_events = Vec :: with_capacity ( events. len ( ) ) ;
136146
137147 for ( event_id, event) in events {
138148 // If we managed to decrypt the event, and we should have to since we received
139149 // the room key for this specific event, then replace the event.
140- if let Some ( decrypted) =
141- self . decrypt_event ( & room_key_info. room_id , event. cast_ref_unchecked ( ) ) . await
142- {
150+ if let Some ( decrypted) = self . decrypt_event ( room_id, event. cast_ref_unchecked ( ) ) . await {
143151 decrypted_events. push ( ( event_id, decrypted) ) ;
144152 }
145153 }
146154
147- self . on_resolved_utds ( & room_key_info . room_id , decrypted_events) . await ?;
155+ self . on_resolved_utds ( room_id, decrypted_events) . await ?;
148156
149157 Ok ( ( ) )
150158 }
151159}
152160
153161pub ( crate ) struct Redecryptor {
162+ request_decryption_sender : UnboundedSender < DecryptionRetryRequest > ,
154163 task : JoinHandle < ( ) > ,
155164}
156165
@@ -161,49 +170,95 @@ impl Drop for Redecryptor {
161170}
162171
163172impl Redecryptor {
164- pub fn new ( client : Client , cache : Weak < EventCacheInner > ) -> Self {
173+ pub ( super ) fn new ( cache : Weak < EventCacheInner > ) -> Self {
174+ let ( request_decryption_sender, receiver) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
165175 let task = spawn ( async {
166- let stream = {
167- let machine = client. olm_machine ( ) . await ;
168- machine. as_ref ( ) . unwrap ( ) . store ( ) . room_keys_received_stream ( )
169- } ;
176+ let request_redecryption_stream = UnboundedReceiverStream :: new ( receiver) ;
170177
171- drop ( client) ;
178+ Self :: listen_for_room_keys_task ( cache, request_redecryption_stream) . await ;
179+ } ) ;
172180
173- Self :: listen_for_room_keys_task ( cache, stream) . await ;
181+ Self { task, request_decryption_sender }
182+ }
183+
184+ #[ allow( dead_code) ]
185+ pub ( super ) fn request_decryption ( & self , request : DecryptionRetryRequest ) {
186+ let _ = self . request_decryption_sender . send ( request) . inspect_err ( |_| {
187+ warn ! ( "Requesting a decryption while the redecryption task has been shut down" )
174188 } ) ;
189+ }
175190
176- Self { task }
191+ async fn subscribe_to_room_key_stream (
192+ cache : & Weak < EventCacheInner > ,
193+ ) -> Option < impl Stream < Item = Result < Vec < RoomKeyInfo > , BroadcastStreamRecvError > > > {
194+ let event_cache = cache. upgrade ( ) ?;
195+ let client = event_cache. client ( ) . ok ( ) ?;
196+ let machine = client. olm_machine ( ) . await ;
197+
198+ machine. as_ref ( ) . map ( |m| m. store ( ) . room_keys_received_stream ( ) )
199+ }
200+
201+ async fn listen_loop (
202+ cache : & Weak < EventCacheInner > ,
203+ decryption_request_stream : & mut Pin < & mut impl Stream < Item = DecryptionRetryRequest > > ,
204+ ) -> bool {
205+ let Some ( room_key_stream) = Self :: subscribe_to_room_key_stream ( cache) . await else {
206+ return false ;
207+ } ;
208+
209+ pin_mut ! ( room_key_stream) ;
210+
211+ // TODO: Listen to notifications that the Olm machine got recreated, this means
212+ // that our room key stream is effectively dead, we need to exit this
213+ // function with a `true` return value.
214+ loop {
215+ tokio:: select! {
216+ Some ( request) = decryption_request_stream. next( ) => {
217+ let Some ( event_cache) = cache. upgrade( ) else {
218+ break false ;
219+ } ;
220+
221+ let cache = EventCache { inner: event_cache } ;
222+
223+ for session_id in request. session_ids {
224+ let _ = cache
225+ . retry_decryption( & request. room_id, & session_id)
226+ . await
227+ . inspect_err( |e| warn!( "Error redecrypting {e:?}" ) ) ;
228+ }
229+ }
230+ Some ( room_keys) = room_key_stream. next( ) => {
231+ if let Ok ( room_keys) = room_keys {
232+ let Some ( event_cache) = cache. upgrade( ) else {
233+ break false ;
234+ } ;
235+
236+ let cache = EventCache { inner: event_cache } ;
237+
238+ for key in room_keys {
239+ let _ = cache
240+ . retry_decryption( & key. room_id, & key. session_id)
241+ . await
242+ . inspect_err( |e| warn!( "Error redecrypting {e:?}" ) ) ;
243+ }
244+ } else {
245+ todo!( "Decrypt all events?" ) ;
246+ }
247+ }
248+ else => break false ,
249+ }
250+ }
177251 }
178252
179253 async fn listen_for_room_keys_task (
180254 cache : Weak < EventCacheInner > ,
181- received_stream : impl Stream < Item = Result < Vec < RoomKeyInfo > , BroadcastStreamRecvError > > ,
255+ decryption_request_stream : UnboundedReceiverStream < DecryptionRetryRequest > ,
182256 ) {
183- pin_mut ! ( received_stream ) ;
257+ pin_mut ! ( decryption_request_stream ) ;
184258
185259 // TODO: We need to relisten to this stream if it dies due to the cross-process
186260 // lock reloading the Olm machine.
187- while let Some ( update) = received_stream. next ( ) . await {
188- if let Ok ( room_keys) = update {
189- let Some ( event_cache) = cache. upgrade ( ) else {
190- break ;
191- } ;
192-
193- let cache = EventCache { inner : event_cache } ;
194-
195- for key in room_keys {
196- let _ = cache
197- . retry_decryption ( key)
198- . await
199- . inspect_err ( |e| warn ! ( "Error redecrypting {e:?}" ) ) ;
200- }
201- } else {
202- todo ! ( "Redecrypt all visible events?" )
203- }
204- }
205-
206- info ! ( "Shutting down the event cache redecryptor" ) ;
261+ while Self :: listen_loop ( & cache, & mut decryption_request_stream) . await { }
207262 }
208263}
209264
0 commit comments