Skip to content

Commit e4e3d33

Browse files
committed
feat(redecryptor): Use the room to redecrypt events
This allows us to properly calculate the push actions.
1 parent 2554c8b commit e4e3d33

File tree

1 file changed

+85
-25
lines changed

1 file changed

+85
-25
lines changed

crates/matrix-sdk/src/event_cache/redecryptor.rs

Lines changed: 85 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,12 @@ use matrix_sdk_base::{
8686
#[cfg(doc)]
8787
use matrix_sdk_common::deserialized_responses::EncryptionInfo;
8888
use matrix_sdk_common::executor::{JoinHandle, spawn};
89-
use ruma::{OwnedEventId, OwnedRoomId, RoomId, events::AnySyncTimelineEvent, serde::Raw};
89+
use ruma::{
90+
OwnedEventId, OwnedRoomId, RoomId,
91+
events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
92+
push::Action,
93+
serde::Raw,
94+
};
9095
use tokio::sync::{
9196
broadcast,
9297
mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
@@ -96,8 +101,12 @@ use tokio_stream::wrappers::{
96101
};
97102
use tracing::{info, instrument, trace, warn};
98103

99-
use crate::event_cache::{
100-
EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate,
104+
use crate::{
105+
Room,
106+
event_cache::{
107+
EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate,
108+
},
109+
room::PushContext,
101110
};
102111

103112
type SessionId<'a> = &'a str;
@@ -218,24 +227,29 @@ impl EventCache {
218227
async fn on_resolved_utds(
219228
&self,
220229
room_id: &RoomId,
221-
events: Vec<(OwnedEventId, DecryptedRoomEvent)>,
230+
events: Vec<(OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>)>,
222231
) -> Result<(), EventCacheError> {
223232
// Get the cache for this particular room and lock the state for the duration of
224233
// the decryption.
225234
let (room_cache, _drop_handles) = self.for_room(room_id).await?;
226235
let mut state = room_cache.inner.state.write().await;
227236

228-
let event_ids: BTreeSet<_> = events.iter().cloned().map(|(event_id, _)| event_id).collect();
237+
let event_ids: BTreeSet<_> =
238+
events.iter().cloned().map(|(event_id, _, _)| event_id).collect();
229239

230240
trace!(?event_ids, "Replacing successfully re-decrypted events");
231241

232-
for (event_id, decrypted) in events {
242+
for (event_id, decrypted, actions) in events {
233243
// The event isn't in the cache, nothing to replace. Realistically this can't
234244
// happen since we retrieved the list of events from the cache itself and
235245
// `find_event()` will look into the store as well.
236246
if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
237247
target_event.kind = TimelineEventKind::Decrypted(decrypted);
238248

249+
if let Some(actions) = actions {
250+
target_event.set_push_actions(actions);
251+
}
252+
239253
// TODO: `replace_event_at()` propagates changes to the store for every event,
240254
// we should probably have a bulk version of this?
241255
state.replace_event_at(location, target_event).await?
@@ -266,21 +280,50 @@ impl EventCache {
266280
async fn decrypt_event(
267281
&self,
268282
room_id: &RoomId,
283+
room: Option<&Room>,
284+
push_context: Option<&PushContext>,
269285
event: &Raw<EncryptedEvent>,
270-
) -> Option<DecryptedRoomEvent> {
271-
let client = self.inner.client().ok()?;
272-
// TODO: Do we need to use the `Room` object to decrypt these events so we can
273-
// calculate if the event should count as a notification, i.e. get the push
274-
// actions. I thing we do, what happens if the room can't be found? We fallback
275-
// to this?
276-
let machine = client.olm_machine().await;
277-
let machine = machine.as_ref()?;
278-
279-
match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
280-
Ok(decrypted) => Some(decrypted),
281-
Err(e) => {
282-
warn!("Failed to redecrypt an event despite receiving a room key for it {e:?}");
283-
None
286+
) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
287+
if let Some(room) = room {
288+
match room
289+
.decrypt_event(
290+
event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
291+
push_context,
292+
)
293+
.await
294+
{
295+
Ok(maybe_decrypted) => {
296+
let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());
297+
298+
if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
299+
Some((decrypted, actions))
300+
} else {
301+
warn!(
302+
"Failed to redecrypt an event despite receiving a room key or request to redecrypt"
303+
);
304+
None
305+
}
306+
}
307+
Err(e) => {
308+
warn!(
309+
"Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
310+
);
311+
None
312+
}
313+
}
314+
} else {
315+
let client = self.inner.client().ok()?;
316+
let machine = client.olm_machine().await;
317+
let machine = machine.as_ref()?;
318+
319+
match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
320+
Ok(decrypted) => Some((decrypted, None)),
321+
Err(e) => {
322+
warn!(
323+
"Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
324+
);
325+
None
326+
}
284327
}
285328
}
286329
}
@@ -298,14 +341,26 @@ impl EventCache {
298341
// Get all the relevant UTDs.
299342
let events = self.get_utds(room_id, session_id).await?;
300343

344+
let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
345+
let push_context =
346+
if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
347+
301348
// Let's attempt to decrypt them them.
302349
let mut decrypted_events = Vec::with_capacity(events.len());
303350

304351
for (event_id, event) in events {
305352
// If we managed to decrypt the event, and we should have to since we received
306353
// the room key for this specific event, then replace the event.
307-
if let Some(decrypted) = self.decrypt_event(room_id, event.cast_ref_unchecked()).await {
308-
decrypted_events.push((event_id, decrypted));
354+
if let Some((decrypted, actions)) = self
355+
.decrypt_event(
356+
room_id,
357+
room.as_ref(),
358+
push_context.as_ref(),
359+
event.cast_ref_unchecked(),
360+
)
361+
.await
362+
{
363+
decrypted_events.push((event_id, decrypted, actions));
309364
}
310365
}
311366

@@ -321,8 +376,13 @@ impl EventCache {
321376
room_id: &RoomId,
322377
session_id: SessionId<'_>,
323378
) -> Result<(), EventCacheError> {
324-
let client = self.inner.client().ok().unwrap();
325-
let room = client.get_room(room_id).unwrap();
379+
let Ok(client) = self.inner.client() else {
380+
return Ok(());
381+
};
382+
383+
let Some(room) = client.get_room(room_id) else {
384+
return Ok(());
385+
};
326386

327387
// Get all the relevant events.
328388
let events = self.get_decrypted_events(room_id, session_id).await?;
@@ -339,7 +399,7 @@ impl EventCache {
339399
&& event.encryption_info != new_encryption_info
340400
{
341401
event.encryption_info = new_encryption_info;
342-
updated_events.push((event_id, event));
402+
updated_events.push((event_id, event, None));
343403
}
344404
}
345405

0 commit comments

Comments
 (0)