Skip to content

Commit 54f3b5f

Browse files
committed
feat(cache): Let the redecryptor listen to room key withheld updates
1 parent bab9d9b commit 54f3b5f

File tree

1 file changed

+136
-34
lines changed

1 file changed

+136
-34
lines changed

crates/matrix-sdk/src/event_cache/redecryptor.rs

Lines changed: 136 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@
1515
//! The Redecryptor (Rd) is a layer and long-running background task which
1616
//! handles redecryption of events in case we couldn't decrypt them imediatelly.
1717
//!
18-
//! Rd listens to the OlmMachine for received room keys. If a new room key has
19-
//! been received it attempts to find any UTDs in the [`EventCache`]. If Rd
20-
//! decrypts any UTDs from the event cache it will replace the events in the
21-
//! cache and send out new [`RoomEventCacheUpdates`] to any of its listeners.
18+
//! Rd listens to the OlmMachine for received room keys and new
19+
//! m.room_key.withheld events.
20+
//!
21+
//! If a new room key has been received it attempts to find any UTDs in the
22+
//! [`EventCache`]. If Rd decrypts any UTDs from the event cache it will replace
23+
//! the events in the cache and send out new [`RoomEventCacheUpdates`] to any of
24+
//! its listeners.
25+
//!
26+
//! If a new withheld info has been received it attempts to find any relevant
27+
//! events and updates the [`EncryptionInfo`] of an event.
2228
//!
2329
//! There's an additional gotcha, the [`OlmMachine`] might get recreated by
2430
//! calls to [`BaseClient::regenerate_olm()`]. When this happens we will receive
@@ -70,12 +76,15 @@ use futures_util::{StreamExt, pin_mut};
7076
#[cfg(doc)]
7177
use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
7278
use matrix_sdk_base::{
73-
crypto::{store::types::RoomKeyInfo, types::events::room::encrypted::EncryptedEvent},
79+
crypto::{
80+
store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
81+
types::events::room::encrypted::EncryptedEvent,
82+
},
7483
deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
7584
locks::Mutex,
7685
};
7786
use matrix_sdk_common::executor::spawn;
78-
use ruma::{OwnedEventId, OwnedRoomId, RoomId, events::AnySyncTimelineEvent, serde::Raw};
87+
use ruma::{OwnedEventId, OwnedRoomId, RoomId, serde::Raw};
7988
use tokio::{
8089
sync::{
8190
broadcast,
@@ -151,36 +160,21 @@ impl EventCache {
151160
/// * `room_id` - The ID of the room where the events were sent to.
152161
/// * `session_id` - The unique ID of the room key that was used to encrypt
153162
/// the event.
154-
async fn get_utds(
163+
async fn get_utds<B>(
155164
&self,
156165
room_id: &RoomId,
157-
session_id: SessionId<'_>,
158-
) -> Result<Vec<(OwnedEventId, Raw<AnySyncTimelineEvent>)>, EventCacheError> {
159-
let filter_non_utds = |event: TimelineEvent| {
160-
let event_id = event.event_id();
161-
162-
// We only care about events fort his particular room key, identified by the
163-
// session ID.
164-
if event.kind.session_id() == Some(session_id) {
165-
// Only pick out events that are UTDs, get just the Raw event as this is what
166-
// the OlmMachine needs.
167-
let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
168-
// Zip the event ID and event together so we don't have to pick out the event ID
169-
// again. We need the event ID to replace the event in the cache.
170-
event_id.zip(event)
171-
} else {
172-
None
173-
}
174-
};
175-
166+
_session_id: SessionId<'_>,
167+
filter: impl Fn(TimelineEvent) -> Option<B>,
168+
) -> Result<Vec<B>, EventCacheError> {
176169
// Load the relevant events from the event cache store and attempt to redecrypt
177170
// things.
178171
//
179172
// TODO: We can't load **all** events all the time.
173+
// TODO: Use the session ID to filter things.
180174
let store = self.inner.store.lock().await?;
181175
let events = store.get_room_events(&room_id).await?;
182176

183-
Ok(events.into_iter().filter_map(filter_non_utds).collect())
177+
Ok(events.into_iter().filter_map(filter).collect())
184178
}
185179

186180
/// Handle a chunk of events that we were previously unable to decrypt but
@@ -272,8 +266,25 @@ impl EventCache {
272266
) -> Result<(), EventCacheError> {
273267
trace!("Retrying to decrypt");
274268

269+
let filter_non_utds = |event: TimelineEvent| {
270+
let event_id = event.event_id();
271+
272+
// We only care about events fort his particular room key, identified by the
273+
// session ID.
274+
if event.kind.session_id() == Some(session_id) {
275+
// Only pick out events that are UTDs, get just the Raw event as this is what
276+
// the OlmMachine needs.
277+
let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
278+
// Zip the event ID and event together so we don't have to pick out the event ID
279+
// again. We need the event ID to replace the event in the cache.
280+
event_id.zip(event)
281+
} else {
282+
None
283+
}
284+
};
285+
275286
// Get all the relevant UTDs.
276-
let events = self.get_utds(room_id, session_id).await?;
287+
let events = self.get_utds(room_id, session_id, filter_non_utds).await?;
277288

278289
// Let's attempt to decrypt them them.
279290
let mut decrypted_events = Vec::with_capacity(events.len());
@@ -293,6 +304,54 @@ impl EventCache {
293304
Ok(())
294305
}
295306

307+
async fn update_encryption_info(
308+
&self,
309+
room_id: &RoomId,
310+
session_id: SessionId<'_>,
311+
) -> Result<(), EventCacheError> {
312+
let filter_non_utds = |event: TimelineEvent| {
313+
let event_id = event.event_id();
314+
315+
// We only care about events fort his particular room key, identified by the
316+
// session ID.
317+
if event.kind.session_id() == Some(session_id) {
318+
let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
319+
// Zip the event ID and event together so we don't have to pick out the event ID
320+
// again. We need the event ID to replace the event in the cache.
321+
event_id.zip(event)
322+
} else {
323+
None
324+
}
325+
};
326+
327+
let client = self.inner.client().ok().unwrap();
328+
let room = client.get_room(room_id).unwrap();
329+
330+
// Get all the relevant events.
331+
let events = self.get_utds(room_id, session_id, filter_non_utds).await?;
332+
333+
// Let's attempt to update their encryption info.
334+
let mut updated_events = Vec::with_capacity(events.len());
335+
336+
for (event_id, mut event) in events {
337+
let new_encryption_info =
338+
room.get_encryption_info(session_id, &event.encryption_info.sender).await;
339+
340+
// Only create a replacement if the encryption info actually changed.
341+
if let Some(new_encryption_info) = new_encryption_info {
342+
if event.encryption_info != new_encryption_info {
343+
event.encryption_info = new_encryption_info;
344+
345+
updated_events.push((event_id, event));
346+
}
347+
}
348+
}
349+
350+
self.on_resolved_utds(room_id, updated_events).await?;
351+
352+
Ok(())
353+
}
354+
296355
/// Explicitly request the redecryption of a set of events.
297356
///
298357
/// TODO: Explain when and why this might be useful.
@@ -347,12 +406,17 @@ impl Redecryptor {
347406
/// the sending part of the stream has been dropped.
348407
async fn subscribe_to_room_key_stream(
349408
cache: &Weak<EventCacheInner>,
350-
) -> Option<impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>> {
409+
) -> Option<(
410+
impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
411+
impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
412+
)> {
351413
let event_cache = cache.upgrade()?;
352414
let client = event_cache.client().ok()?;
353415
let machine = client.olm_machine().await;
354416

355-
machine.as_ref().map(|m| m.store().room_keys_received_stream())
417+
machine.as_ref().map(|m| {
418+
(m.store().room_keys_received_stream(), m.store().room_keys_withheld_received_stream())
419+
})
356420
}
357421

358422
fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
@@ -363,11 +427,14 @@ impl Redecryptor {
363427
cache: &Weak<EventCacheInner>,
364428
decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
365429
) -> bool {
366-
let Some(room_key_stream) = Self::subscribe_to_room_key_stream(cache).await else {
430+
let Some((room_key_stream, withheld_stream)) =
431+
Self::subscribe_to_room_key_stream(cache).await
432+
else {
367433
return false;
368434
};
369435

370436
pin_mut!(room_key_stream);
437+
pin_mut!(withheld_stream);
371438

372439
loop {
373440
tokio::select! {
@@ -385,7 +452,14 @@ impl Redecryptor {
385452
.inspect_err(|e| warn!("Error redecrypting {e:?}"));
386453
}
387454

388-
// TODO: Deal with encryption info updating as well.
455+
for session_id in request.refresh_info_session_ids {
456+
let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
457+
warn!(
458+
room_id = %request.room_id,
459+
session_id = session_id,
460+
"Unable to update the encryption info {e:?}",
461+
));
462+
}
389463
}
390464
// The room key stream from the OlmMachine. Needs to be recreated every time we
391465
// receive a `None` from the stream.
@@ -399,14 +473,21 @@ impl Redecryptor {
399473
break false;
400474
};
401475

402-
for key in room_keys {
476+
for key in &room_keys {
403477
let _ = cache
404478
.retry_decryption(&key.room_id, &key.session_id)
405479
.await
406480
.inspect_err(|e| warn!("Error redecrypting {e:?}"));
407481
}
408482

409-
// TODO: Deal with encryption info updating as well.
483+
for key in room_keys {
484+
let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
485+
warn!(
486+
room_id = %key.room_id,
487+
session_id = key.session_id,
488+
"Unable to update the encryption info {e:?}",
489+
));
490+
}
410491
},
411492
Some(Err(_)) => {
412493
// We missed some room keys, we need to report this in case a listener
@@ -428,6 +509,27 @@ impl Redecryptor {
428509
}
429510
}
430511
}
512+
withheld_info = withheld_stream.next() => {
513+
match withheld_info {
514+
Some(infos) => {
515+
let Some(cache) = Self::upgrade_event_cache(cache) else {
516+
break false;
517+
};
518+
519+
for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
520+
let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
521+
warn!(
522+
room_id = %room_id,
523+
session_id = session_id,
524+
"Unable to update the encryption info {e:?}",
525+
));
526+
}
527+
}
528+
// The stream got closed, same as for the room key stream, we'll try to
529+
// recreate the streams.
530+
None => break true
531+
}
532+
}
431533
else => break false,
432534
}
433535
}

0 commit comments

Comments
 (0)