diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index cb89937c..350a1fe1 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -1,6 +1,6 @@ use std::{ cmp::Ordering, - collections::{BTreeMap, BTreeSet, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashSet}, mem::take, ops::Deref, time::Duration, @@ -11,9 +11,10 @@ use futures::{ FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::{OptionFuture, join, join3, join4, join5, try_join}, pin_mut, + stream::once, }; use ruma::{ - DeviceId, JsOption, MxcUri, OwnedEventId, OwnedMxcUri, OwnedRoomId, RoomId, UInt, UserId, + DeviceId, JsOption, MxcUri, OwnedMxcUri, OwnedRoomId, OwnedUserId, RoomId, UInt, UserId, api::client::sync::sync_events::{ DeviceLists, UnreadNotificationsCount, v5::{Request, Response, request::ExtensionRoomConfig, response}, @@ -32,10 +33,10 @@ use tokio::time::{Instant, timeout_at}; use tuwunel_core::{ Err, Result, apply, at, debug_error, error, extract_variant, is_equal_to, matrix::{Event, StateKey, TypeStateKey, pdu::PduCount}, - ref_at, trace, + pair_of, ref_at, trace, utils::{ BoolExt, FutureBoolExt, IterStream, ReadyExt, TryFutureExtExt, - future::ReadyEqExt, + future::{OptionStream, ReadyEqExt}, math::{ruma_from_usize, usize_from_ruma}, result::FlatOk, stream::{BroadbandExt, TryBroadbandExt, TryReadyExt, WidebandExt}, @@ -189,15 +190,9 @@ pub(crate) async fn sync_events_v5_route( ) .map_ok(|rooms| response.rooms = rooms); - let extensions = handle_extensions( - services, - sync_info, - next_batch, - &known_rooms, - &todo_rooms, - all_joined_rooms.clone(), - ) - .map_ok(|extensions| response.extensions = extensions); + let extensions = + handle_extensions(services, sync_info, next_batch, &known_rooms, &todo_rooms) + .map_ok(|extensions| response.extensions = extensions); try_join(rooms, extensions).boxed().await?; @@ -794,17 +789,13 @@ async fn calculate_heroes( known_rooms = known_rooms.len(), ) )] -async fn handle_extensions<'a, Rooms>( +async fn handle_extensions( services: &Services, sync_info: SyncInfo<'_>, next_batch: u64, known_rooms: &KnownRooms, todo_rooms: &TodoRooms, - all_joined_rooms: Rooms, -) -> Result -where - Rooms: Iterator + Clone + Send + 'a, -{ +) -> Result { let &(_, _, _, request) = &sync_info; let account_data: OptionFuture<_> = request @@ -844,9 +835,7 @@ where .e2ee .enabled .unwrap_or(false) - .then(|| { - collect_e2ee(services, sync_info, next_batch, todo_rooms, all_joined_rooms.clone()) - }) + .then(|| collect_e2ee(services, sync_info, next_batch)) .into(); let (account_data, receipts, typing, to_device, e2ee) = @@ -1055,237 +1044,201 @@ async fn collect_to_device( Ok(to_device) } -// TODO ---------------------------------------------------------------------- - -#[tracing::instrument( - level = "trace", - skip_all, - fields( - globalsince, - next_batch, - all_joined_rooms = all_joined_rooms.clone().count(), - ) -)] -async fn collect_e2ee<'a, Rooms>( +#[tracing::instrument(level = "trace", skip_all, fields(globalsince, next_batch,))] +async fn collect_e2ee( services: &Services, - (sender_user, sender_device, globalsince, _): SyncInfo<'_>, + syncinfo: SyncInfo<'_>, next_batch: u64, - _todo_rooms: &TodoRooms, - all_joined_rooms: Rooms, -) -> Result -where - Rooms: Iterator + Clone + Send + 'a, -{ - // Users that have left any encrypted rooms the sender was in - let mut left_encrypted_users = HashSet::new(); - let mut device_list_changes = HashSet::new(); - let mut device_list_left = HashSet::new(); - // Look for device list updates of this account - device_list_changes.extend( - services - .users - .keys_changed(sender_user, globalsince, Some(next_batch)) - .map(ToOwned::to_owned) - .collect::>() - .await, - ); - - for room_id in all_joined_rooms { - let Ok(current_shortstatehash) = services - .state - .get_room_shortstatehash(room_id) - .await - else { - error!("Room {room_id} has no state"); - continue; - }; - - let since_shortstatehash = services - .timeline - .next_shortstatehash(room_id, PduCount::Normal(globalsince)) - .await - .ok(); - - let encrypted_room = services - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") - .await - .is_ok(); - - if let Some(since_shortstatehash) = since_shortstatehash { - // Skip if there are only timeline changes - if since_shortstatehash == current_shortstatehash { - continue; - } - - let synced_shortstatehash = services - .timeline - .prev_shortstatehash(room_id, PduCount::Normal(globalsince).saturating_add(1)) - .await - .ok(); - - let synced_sender_member: OptionFuture<_> = synced_shortstatehash - .map(|shortstatehash| { - services - .state_accessor - .state_get_content( - shortstatehash, - &StateEventType::RoomMember, - sender_user.as_str(), - ) - .map_ok(|content: RoomMemberEventContent| content) - }) - .into(); - - let joined_since_last_sync = synced_sender_member - .await - .and_then(Result::ok) - .as_ref() - .is_none_or(|member| member.membership != MembershipState::Join); - - let synced_encryption: OptionFuture<_> = synced_shortstatehash - .map(|shortstatehash| { - services.state_accessor.state_get( - shortstatehash, - &StateEventType::RoomEncryption, - "", - ) - }) - .into(); - - let synced_encryption = synced_encryption.await.and_then(Result::ok); - - let new_encrypted_room = encrypted_room && synced_encryption.is_none(); - - if encrypted_room { - let current_state_ids: HashMap<_, OwnedEventId> = services - .state_accessor - .state_full_ids(current_shortstatehash) - .collect() - .await; - - let since_state_ids: HashMap<_, _> = services - .state_accessor - .state_full_ids(since_shortstatehash) - .collect() - .await; - - for (key, id) in current_state_ids { - if since_state_ids.get(&key) == Some(&id) { - continue; - } - - let Ok(pdu) = services.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {id}"); - continue; - }; - - if pdu.kind != TimelineEventType::RoomMember { - continue; - } - - let Some(Ok(user_id)) = pdu.state_key.as_deref().map(UserId::parse) else { - continue; - }; - - if user_id == sender_user { - continue; - } - - let content: RoomMemberEventContent = pdu.get_content()?; - match content.membership { - | MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room( - services, - sender_user, - user_id, - Some(room_id), - ) - .await - { - device_list_changes.insert(user_id.to_owned()); - } - }, - | MembershipState::Leave => { - // Write down users that have left encrypted rooms we - // are in - left_encrypted_users.insert(user_id.to_owned()); - }, - | _ => {}, - } - } - - if joined_since_last_sync || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_changes.extend( - services - .state_cache - .room_members(room_id) - // Don't send key updates from the sender to the sender - .ready_filter(|user_id| sender_user != *user_id) - // Only send keys if the sender doesn't share an encrypted room with the target - // already - .filter_map(|user_id| { - share_encrypted_room(services, sender_user, user_id, Some(room_id)) - .map(|res| res.or_some(user_id.to_owned())) - }) - .collect::>() - .await, - ); - } - } - } - // Look for device list updates in this room - device_list_changes.extend( - services - .users - .room_keys_changed(room_id, globalsince, Some(next_batch)) - .map(|(user_id, _)| user_id) - .map(ToOwned::to_owned) - .collect::>() - .await, - ); - } - - for user_id in left_encrypted_users { - let dont_share_encrypted_room = - !share_encrypted_room(services, sender_user, &user_id, None).await; - - // If the user doesn't share an encrypted room with the target anymore, we need - // to tell them - if dont_share_encrypted_room { - device_list_left.insert(user_id); - } - } - - let last_otk_update = services +) -> Result { + let &(sender_user, sender_device, globalsince, _) = &syncinfo; + let keys_changed = services .users - .last_one_time_keys_update(sender_user) + .keys_changed(sender_user, globalsince, Some(next_batch)) + .map(ToOwned::to_owned) + .collect::>() + .map(|changed| (changed, HashSet::new())); + + let (changed, left) = (HashSet::new(), HashSet::new()); + let (changed, left) = services + .state_cache + .rooms_joined(sender_user) + .map(ToOwned::to_owned) + .broad_filter_map(async |room_id| { + collect_e2ee_room(services, syncinfo, next_batch, &room_id) + .await + .ok() + }) + .chain(once(keys_changed)) + .ready_fold((changed, left), |(mut changed, mut left), room| { + changed.extend(room.0); + left.extend(room.1); + (changed, left) + }) .await; - let device_otk_count: OptionFuture<_> = last_otk_update - .gt(&globalsince) - .then(|| { - services - .users - .count_one_time_keys(sender_user, sender_device) + let left = left + .into_iter() + .stream() + .filter_map(async |user_id| { + share_encrypted_room(services, sender_user, &user_id, None) + .await + .is_false() + .then_some(user_id) }) - .into(); + .collect(); + + let device_one_time_keys_count = services + .users + .last_one_time_keys_update(sender_user) + .then(|since| -> OptionFuture<_> { + since + .gt(&globalsince) + .then(|| { + services + .users + .count_one_time_keys(sender_user, sender_device) + }) + .into() + }) + .map(Option::unwrap_or_default); + + let (left, device_one_time_keys_count) = join(left, device_one_time_keys_count) + .boxed() + .await; Ok(response::E2EE { - device_one_time_keys_count: device_otk_count.await.unwrap_or_default(), - + device_one_time_keys_count, device_unused_fallback_key_types: None, - device_lists: DeviceLists { - changed: device_list_changes.into_iter().collect(), - left: device_list_left.into_iter().collect(), + changed: changed.into_iter().collect(), + left, }, }) } -// ---------------------------------------------------------------------------- +#[tracing::instrument(level = "trace", skip_all, fields(room_id))] +async fn collect_e2ee_room( + services: &Services, + (sender_user, _, globalsince, _): SyncInfo<'_>, + next_batch: u64, + room_id: &RoomId, +) -> Result)> { + let current_shortstatehash = services + .state + .get_room_shortstatehash(room_id) + .inspect_err(|e| error!("Room {room_id} has no state: {e}")); + + let room_keys_changed = services + .users + .room_keys_changed(room_id, globalsince, Some(next_batch)) + .map(|(user_id, _)| user_id) + .map(ToOwned::to_owned) + .collect::>(); + + let (current_shortstatehash, device_list_changed) = + join(current_shortstatehash, room_keys_changed) + .boxed() + .await; + + let lists = (device_list_changed, HashSet::new()); + let Ok(current_shortstatehash) = current_shortstatehash else { + return Ok(lists); + }; + + if current_shortstatehash <= globalsince { + return Ok(lists); + } + + let Ok(since_shortstatehash) = services + .timeline + .prev_shortstatehash(room_id, PduCount::Normal(globalsince).saturating_add(1)) + .await + else { + return Ok(lists); + }; + + if since_shortstatehash == current_shortstatehash { + return Ok(lists); + } + + let encrypted_room = services + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "") + .is_ok(); + + let since_encryption = services + .state_accessor + .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "") + .is_ok(); + + let sender_joined_count = services + .state_cache + .get_joined_count(room_id, sender_user); + + let (encrypted_room, since_encryption, sender_joined_count) = + join3(encrypted_room, since_encryption, sender_joined_count).await; + + if !encrypted_room { + return Ok(lists); + } + + let encrypted_since_last_sync = !since_encryption; + let joined_since_last_sync = sender_joined_count.is_ok_and(|count| count > globalsince); + let joined_members_burst: OptionFuture<_> = (joined_since_last_sync + || encrypted_since_last_sync) + .then(|| { + services + .state_cache + .room_members(room_id) + .ready_filter(|&user_id| user_id != sender_user) + .map(ToOwned::to_owned) + .map(|user_id| (MembershipState::Join, user_id)) + .into_future() + }) + .into(); + + services + .state_accessor + .state_added((since_shortstatehash, current_shortstatehash)) + .broad_filter_map(async |(_shortstatekey, shorteventid)| { + services + .timeline + .get_pdu_from_shorteventid(shorteventid) + .ok() + .await + }) + .ready_filter(|event| *event.kind() == TimelineEventType::RoomMember) + .ready_filter(|event| { + event + .state_key() + .is_some_and(|state_key| state_key != sender_user) + }) + .ready_filter_map(|event| { + let content: RoomMemberEventContent = event.get_content().ok()?; + let user_id: OwnedUserId = event.state_key()?.parse().ok()?; + + Some((content.membership, user_id)) + }) + .chain(joined_members_burst.stream()) + .fold(lists, async |(mut changed, mut left), (membership, user_id)| { + use MembershipState::*; + + let should_add = async |user_id| { + !share_encrypted_room(services, sender_user, user_id, Some(room_id)).await + }; + + match membership { + | Join if should_add(&user_id).await => changed.insert(user_id), + | Leave => left.insert(user_id), + | _ => false, + }; + + (changed, left) + }) + .map(Ok) + .boxed() + .await +} fn extension_rooms_todo<'a>( (_, _, _, request): SyncInfo<'a>,