Skip to content

Commit 0cc3ce1

Browse files
committed
feat(redecryptor): Use the room to redecrypt events
This allows us to properly calculate the push actions.
1 parent a452ef1 commit 0cc3ce1

File tree

1 file changed

+85
-25
lines changed

1 file changed

+85
-25
lines changed

crates/matrix-sdk/src/event_cache/redecryptor.rs

Lines changed: 85 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,12 @@ use matrix_sdk_base::{
8686
#[cfg(doc)]
8787
use matrix_sdk_common::deserialized_responses::EncryptionInfo;
8888
use matrix_sdk_common::executor::spawn;
89-
use ruma::{OwnedEventId, OwnedRoomId, RoomId, events::AnySyncTimelineEvent, serde::Raw};
89+
use ruma::{
90+
OwnedEventId, OwnedRoomId, RoomId,
91+
events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
92+
push::Action,
93+
serde::Raw,
94+
};
9095
use tokio::{
9196
sync::{
9297
broadcast,
@@ -99,8 +104,12 @@ use tokio_stream::wrappers::{
99104
};
100105
use tracing::{info, instrument, trace, warn};
101106

102-
use crate::event_cache::{
103-
EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate,
107+
use crate::{
108+
Room,
109+
event_cache::{
110+
EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate,
111+
},
112+
room::PushContext,
104113
};
105114

106115
type SessionId<'a> = &'a str;
@@ -221,24 +230,29 @@ impl EventCache {
221230
async fn on_resolved_utds(
222231
&self,
223232
room_id: &RoomId,
224-
events: Vec<(OwnedEventId, DecryptedRoomEvent)>,
233+
events: Vec<(OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>)>,
225234
) -> Result<(), EventCacheError> {
226235
// Get the cache for this particular room and lock the state for the duration of
227236
// the decryption.
228237
let (room_cache, _drop_handles) = self.for_room(room_id).await?;
229238
let mut state = room_cache.inner.state.write().await;
230239

231-
let event_ids: BTreeSet<_> = events.iter().cloned().map(|(event_id, _)| event_id).collect();
240+
let event_ids: BTreeSet<_> =
241+
events.iter().cloned().map(|(event_id, _, _)| event_id).collect();
232242

233243
trace!(?event_ids, "Replacing successfully re-decrypted events");
234244

235-
for (event_id, decrypted) in events {
245+
for (event_id, decrypted, actions) in events {
236246
// The event isn't in the cache, nothing to replace. Realistically this can't
237247
// happen since we retrieved the list of events from the cache itself and
238248
// `find_event()` will look into the store as well.
239249
if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
240250
target_event.kind = TimelineEventKind::Decrypted(decrypted);
241251

252+
if let Some(actions) = actions {
253+
target_event.set_push_actions(actions);
254+
}
255+
242256
// TODO: `replace_event_at()` propagates changes to the store for every event,
243257
// we should probably have a bulk version of this?
244258
state.replace_event_at(location, target_event).await?
@@ -269,21 +283,50 @@ impl EventCache {
269283
async fn decrypt_event(
270284
&self,
271285
room_id: &RoomId,
286+
room: Option<&Room>,
287+
push_context: Option<&PushContext>,
272288
event: &Raw<EncryptedEvent>,
273-
) -> Option<DecryptedRoomEvent> {
274-
let client = self.inner.client().ok()?;
275-
// TODO: Do we need to use the `Room` object to decrypt these events so we can
276-
// calculate if the event should count as a notification, i.e. get the push
277-
// actions. I thing we do, what happens if the room can't be found? We fallback
278-
// to this?
279-
let machine = client.olm_machine().await;
280-
let machine = machine.as_ref()?;
281-
282-
match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
283-
Ok(decrypted) => Some(decrypted),
284-
Err(e) => {
285-
warn!("Failed to redecrypt an event despite receiving a room key for it {e:?}");
286-
None
289+
) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
290+
if let Some(room) = room {
291+
match room
292+
.decrypt_event(
293+
event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
294+
push_context,
295+
)
296+
.await
297+
{
298+
Ok(maybe_decrypted) => {
299+
let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());
300+
301+
if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
302+
Some((decrypted, actions))
303+
} else {
304+
warn!(
305+
"Failed to redecrypt an event despite receiving a room key or request to redecrypt"
306+
);
307+
None
308+
}
309+
}
310+
Err(e) => {
311+
warn!(
312+
"Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
313+
);
314+
None
315+
}
316+
}
317+
} else {
318+
let client = self.inner.client().ok()?;
319+
let machine = client.olm_machine().await;
320+
let machine = machine.as_ref()?;
321+
322+
match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
323+
Ok(decrypted) => Some((decrypted, None)),
324+
Err(e) => {
325+
warn!(
326+
"Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
327+
);
328+
None
329+
}
287330
}
288331
}
289332
}
@@ -301,14 +344,26 @@ impl EventCache {
301344
// Get all the relevant UTDs.
302345
let events = self.get_utds(room_id, session_id).await?;
303346

347+
let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
348+
let push_context =
349+
if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
350+
304351
// Let's attempt to decrypt them them.
305352
let mut decrypted_events = Vec::with_capacity(events.len());
306353

307354
for (event_id, event) in events {
308355
// If we managed to decrypt the event, and we should have to since we received
309356
// the room key for this specific event, then replace the event.
310-
if let Some(decrypted) = self.decrypt_event(room_id, event.cast_ref_unchecked()).await {
311-
decrypted_events.push((event_id, decrypted));
357+
if let Some((decrypted, actions)) = self
358+
.decrypt_event(
359+
room_id,
360+
room.as_ref(),
361+
push_context.as_ref(),
362+
event.cast_ref_unchecked(),
363+
)
364+
.await
365+
{
366+
decrypted_events.push((event_id, decrypted, actions));
312367
}
313368
}
314369

@@ -324,8 +379,13 @@ impl EventCache {
324379
room_id: &RoomId,
325380
session_id: SessionId<'_>,
326381
) -> Result<(), EventCacheError> {
327-
let client = self.inner.client().ok().unwrap();
328-
let room = client.get_room(room_id).unwrap();
382+
let Ok(client) = self.inner.client() else {
383+
return Ok(());
384+
};
385+
386+
let Some(room) = client.get_room(room_id) else {
387+
return Ok(());
388+
};
329389

330390
// Get all the relevant events.
331391
let events = self.get_decrypted_events(room_id, session_id).await?;
@@ -342,7 +402,7 @@ impl EventCache {
342402
&& event.encryption_info != new_encryption_info
343403
{
344404
event.encryption_info = new_encryption_info;
345-
updated_events.push((event_id, event));
405+
updated_events.push((event_id, event, None));
346406
}
347407
}
348408

0 commit comments

Comments
 (0)