Skip to content

Commit a62fea4

Browse files
committed
Rejigger things so we can relisten to the room key stream
1 parent e2ab634 commit a62fea4

File tree

1 file changed

+107
-52
lines changed

1 file changed

+107
-52
lines changed

crates/matrix-sdk/src/event_cache/redecryptor.rs

Lines changed: 107 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
//! The REDECRYPTOR is a layer that handles redecryption of events in case we
1616
//! couldn't decrypt them imediatelly
1717
18-
use std::sync::Weak;
18+
use std::{collections::BTreeSet, pin::Pin, sync::Weak};
1919

2020
use as_variant::as_variant;
2121
use futures_core::Stream;
@@ -25,30 +25,36 @@ use matrix_sdk_base::{
2525
deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
2626
};
2727
use matrix_sdk_common::executor::spawn;
28-
use ruma::{OwnedEventId, RoomId, events::AnySyncTimelineEvent, serde::Raw};
29-
use tokio::task::JoinHandle;
30-
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
31-
use tracing::{info, instrument, trace, warn};
32-
33-
use crate::{
34-
Client,
35-
event_cache::{
36-
EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate,
37-
},
28+
use ruma::{OwnedEventId, OwnedRoomId, RoomId, events::AnySyncTimelineEvent, serde::Raw};
29+
use tokio::{sync::mpsc::UnboundedSender, task::JoinHandle};
30+
use tokio_stream::wrappers::{UnboundedReceiverStream, errors::BroadcastStreamRecvError};
31+
use tracing::{instrument, trace, warn};
32+
33+
use crate::event_cache::{
34+
EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate,
3835
};
3936

37+
/// The information sent across the channel to the long-running task requesting
38+
/// that the supplied set of sessions be retried.
39+
pub struct DecryptionRetryRequest {
40+
room_id: OwnedRoomId,
41+
session_ids: BTreeSet<String>,
42+
}
43+
44+
type SessionId<'a> = &'a str;
45+
4046
impl EventCache {
4147
async fn get_utds(
4248
&self,
43-
room_key_info: &RoomKeyInfo,
49+
room_id: &RoomId,
50+
session_id: SessionId<'_>,
4451
) -> Result<Vec<(OwnedEventId, Raw<AnySyncTimelineEvent>)>, EventCacheError> {
4552
let filter_non_utds = |event: TimelineEvent| {
4653
let event_id = event.event_id();
54+
4755
// We only care about events fort his particular room key, identified by the
4856
// session ID.
49-
let session_id = event.kind.session_id();
50-
51-
if session_id == Some(&room_key_info.session_id) {
57+
if event.kind.session_id() == Some(session_id) {
5258
// Only pick out events that are UTDs, get just the Raw event as this is what
5359
// the OlmMachine needs.
5460
let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
@@ -65,7 +71,7 @@ impl EventCache {
6571
//
6672
// TODO: We can't load **all** events all the time.
6773
let store = self.inner.store.lock().await?;
68-
let events = store.get_room_events(&room_key_info.room_id).await?;
74+
let events = store.get_room_events(&room_id).await?;
6975

7076
Ok(events.into_iter().filter_map(filter_non_utds).collect())
7177
}
@@ -128,29 +134,32 @@ impl EventCache {
128134
/// Attempt to redecrypt events after a room key with the given session ID
129135
/// has been received.
130136
#[instrument(skip_all, fields(room_key_info))]
131-
async fn retry_decryption(&self, room_key_info: RoomKeyInfo) -> Result<(), EventCacheError> {
137+
async fn retry_decryption(
138+
&self,
139+
room_id: &RoomId,
140+
session_id: SessionId<'_>,
141+
) -> Result<(), EventCacheError> {
132142
trace!("Retrying to decrypt");
133143

134-
let events = self.get_utds(&room_key_info).await?;
144+
let events = self.get_utds(room_id, session_id).await?;
135145
let mut decrypted_events = Vec::with_capacity(events.len());
136146

137147
for (event_id, event) in events {
138148
// If we managed to decrypt the event, and we should have to since we received
139149
// the room key for this specific event, then replace the event.
140-
if let Some(decrypted) =
141-
self.decrypt_event(&room_key_info.room_id, event.cast_ref_unchecked()).await
142-
{
150+
if let Some(decrypted) = self.decrypt_event(room_id, event.cast_ref_unchecked()).await {
143151
decrypted_events.push((event_id, decrypted));
144152
}
145153
}
146154

147-
self.on_resolved_utds(&room_key_info.room_id, decrypted_events).await?;
155+
self.on_resolved_utds(room_id, decrypted_events).await?;
148156

149157
Ok(())
150158
}
151159
}
152160

153161
pub(crate) struct Redecryptor {
162+
request_decryption_sender: UnboundedSender<DecryptionRetryRequest>,
154163
task: JoinHandle<()>,
155164
}
156165

@@ -161,49 +170,95 @@ impl Drop for Redecryptor {
161170
}
162171

163172
impl Redecryptor {
164-
pub fn new(client: Client, cache: Weak<EventCacheInner>) -> Self {
173+
pub(super) fn new(cache: Weak<EventCacheInner>) -> Self {
174+
let (request_decryption_sender, receiver) = tokio::sync::mpsc::unbounded_channel();
165175
let task = spawn(async {
166-
let stream = {
167-
let machine = client.olm_machine().await;
168-
machine.as_ref().unwrap().store().room_keys_received_stream()
169-
};
176+
let request_redecryption_stream = UnboundedReceiverStream::new(receiver);
170177

171-
drop(client);
178+
Self::listen_for_room_keys_task(cache, request_redecryption_stream).await;
179+
});
172180

173-
Self::listen_for_room_keys_task(cache, stream).await;
181+
Self { task, request_decryption_sender }
182+
}
183+
184+
#[allow(dead_code)]
185+
pub(super) fn request_decryption(&self, request: DecryptionRetryRequest) {
186+
let _ = self.request_decryption_sender.send(request).inspect_err(|_| {
187+
warn!("Requesting a decryption while the redecryption task has been shut down")
174188
});
189+
}
175190

176-
Self { task }
191+
async fn subscribe_to_room_key_stream(
192+
cache: &Weak<EventCacheInner>,
193+
) -> Option<impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>> {
194+
let event_cache = cache.upgrade()?;
195+
let client = event_cache.client().ok()?;
196+
let machine = client.olm_machine().await;
197+
198+
machine.as_ref().map(|m| m.store().room_keys_received_stream())
199+
}
200+
201+
async fn listen_loop(
202+
cache: &Weak<EventCacheInner>,
203+
decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
204+
) -> bool {
205+
let Some(room_key_stream) = Self::subscribe_to_room_key_stream(cache).await else {
206+
return false;
207+
};
208+
209+
pin_mut!(room_key_stream);
210+
211+
// TODO: Listen to notifications that the Olm machine got recreated, this means
212+
// that our room key stream is effectively dead, we need to exit this
213+
// function with a `true` return value.
214+
loop {
215+
tokio::select! {
216+
Some(request) = decryption_request_stream.next() => {
217+
let Some(event_cache) = cache.upgrade() else {
218+
break false;
219+
};
220+
221+
let cache = EventCache { inner: event_cache };
222+
223+
for session_id in request.session_ids {
224+
let _ = cache
225+
.retry_decryption(&request.room_id, &session_id)
226+
.await
227+
.inspect_err(|e| warn!("Error redecrypting {e:?}"));
228+
}
229+
}
230+
Some(room_keys) = room_key_stream.next() => {
231+
if let Ok(room_keys) = room_keys {
232+
let Some(event_cache) = cache.upgrade() else {
233+
break false;
234+
};
235+
236+
let cache = EventCache { inner: event_cache };
237+
238+
for key in room_keys {
239+
let _ = cache
240+
.retry_decryption(&key.room_id, &key.session_id)
241+
.await
242+
.inspect_err(|e| warn!("Error redecrypting {e:?}"));
243+
}
244+
} else {
245+
todo!("Decrypt all events?");
246+
}
247+
}
248+
else => break false,
249+
}
250+
}
177251
}
178252

179253
async fn listen_for_room_keys_task(
180254
cache: Weak<EventCacheInner>,
181-
received_stream: impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
255+
decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
182256
) {
183-
pin_mut!(received_stream);
257+
pin_mut!(decryption_request_stream);
184258

185259
// TODO: We need to relisten to this stream if it dies due to the cross-process
186260
// lock reloading the Olm machine.
187-
while let Some(update) = received_stream.next().await {
188-
if let Ok(room_keys) = update {
189-
let Some(event_cache) = cache.upgrade() else {
190-
break;
191-
};
192-
193-
let cache = EventCache { inner: event_cache };
194-
195-
for key in room_keys {
196-
let _ = cache
197-
.retry_decryption(key)
198-
.await
199-
.inspect_err(|e| warn!("Error redecrypting {e:?}"));
200-
}
201-
} else {
202-
todo!("Redecrypt all visible events?")
203-
}
204-
}
205-
206-
info!("Shutting down the event cache redecryptor");
261+
while Self::listen_loop(&cache, &mut decryption_request_stream).await {}
207262
}
208263
}
209264

0 commit comments

Comments
 (0)