From 46c940b8633d012a86256cdd4148e0a8d6a6d19a Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 7 Oct 2025 21:35:42 +0000 Subject: [PATCH] Refactor sliding window selector. (fixes #170) Refactor list filtering. Signed-off-by: Jason Volk --- Cargo.lock | 21 +- Cargo.toml | 2 +- docker/Dockerfile.matrix-rust-sdk | 2 +- src/admin/query/sync.rs | 4 +- src/api/client/sync/mod.rs | 10 +- src/api/client/sync/v5.rs | 571 +++++++------------------ src/api/client/sync/v5/account_data.rs | 26 +- src/api/client/sync/v5/e2ee.rs | 34 +- src/api/client/sync/v5/filter.rs | 156 +++++++ src/api/client/sync/v5/receipts.rs | 38 +- src/api/client/sync/v5/room.rs | 133 ++++-- src/api/client/sync/v5/selector.rs | 257 +++++++++++ src/api/client/sync/v5/to_device.rs | 18 +- src/api/client/sync/v5/typing.rs | 19 +- src/service/rooms/read_receipt/data.rs | 2 +- src/service/sync/mod.rs | 209 ++++----- 16 files changed, 818 insertions(+), 684 deletions(-) create mode 100644 src/api/client/sync/v5/filter.rs create mode 100644 src/api/client/sync/v5/selector.rs diff --git a/Cargo.lock b/Cargo.lock index bb77c544..be073502 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3509,7 +3509,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.13.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "assign", "js_int", @@ -3528,7 +3528,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.13.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "js_int", "ruma-common", @@ -3540,7 +3540,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.21.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "as_variant", "assign", @@ -3556,6 +3556,7 @@ dependencies = [ "serde_html_form", "serde_json", "smallstr", + "smallvec", "thiserror 2.0.16", "url", "web-time", @@ -3564,7 +3565,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.16.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "as_variant", "base64", @@ -3597,7 +3598,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.31.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "as_variant", "indexmap", @@ -3623,7 +3624,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.12.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "bytes", "headers", @@ -3645,7 +3646,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.11.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "js_int", "thiserror 2.0.16", @@ -3654,7 +3655,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.16.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "cfg-if", "proc-macro-crate", @@ -3669,7 +3670,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.12.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "js_int", "ruma-common", @@ -3681,7 +3682,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.18.0" -source = "git+https://github.com/matrix-construct/ruma?rev=e03efd9d228be89365ae0399e2178d0ffa3dc42f#e03efd9d228be89365ae0399e2178d0ffa3dc42f" +source = "git+https://github.com/matrix-construct/ruma?rev=3c26be1e4d670c2c71e2bc84b3f09c67c4188612#3c26be1e4d670c2c71e2bc84b3f09c67c4188612" dependencies = [ "base64", "ed25519-dalek", diff --git a/Cargo.toml b/Cargo.toml index 039f5b87..d82d3992 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -320,7 +320,7 @@ default-features = false [workspace.dependencies.ruma] git = "https://github.com/matrix-construct/ruma" -rev = "e03efd9d228be89365ae0399e2178d0ffa3dc42f" +rev = "3c26be1e4d670c2c71e2bc84b3f09c67c4188612" features = [ "__compat", "appservice-api-c", diff --git a/docker/Dockerfile.matrix-rust-sdk b/docker/Dockerfile.matrix-rust-sdk index 141189cd..c301c000 100644 --- a/docker/Dockerfile.matrix-rust-sdk +++ b/docker/Dockerfile.matrix-rust-sdk @@ -14,7 +14,7 @@ ARG mrsdk_target_share="${MRSDK_TARGET_DIR}/${sys_name}/${sys_version}/${rust_ta ARG mrsdk_test_args="" ARG mrsdk_test_opts="" #TODO!!! -ARG mrsdk_skip_list="--skip test_history_share_on_invite_pin_violation --skip delayed_invite_response_and_sent_message_decryption --skip test_mutual_sas_verification_with_notification_client_ignores_verification_events" +ARG mrsdk_skip_list="--skip test_history_share_on_invite_pin_violation --skip test_room_notification_count" WORKDIR / COPY --link --from=input . . diff --git a/src/admin/query/sync.rs b/src/admin/query/sync.rs index 86646a06..d72d7d27 100644 --- a/src/admin/query/sync.rs +++ b/src/admin/query/sync.rs @@ -32,7 +32,7 @@ pub(super) async fn list_connections(&self) -> Result { let connections = self.services.sync.list_connections(); for connection_key in connections { - self.write_str(&format!("{connection_key:?}")) + self.write_str(&format!("{connection_key:?}\n")) .await?; } @@ -51,7 +51,7 @@ pub(super) async fn show_connection( let out; { - let cached = cache.lock()?; + let cached = cache.lock().await; out = format!("{cached:#?}"); }; diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index 61a7c59b..7a3aa777 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -2,12 +2,7 @@ mod v3; mod v5; use futures::{StreamExt, pin_mut}; -use ruma::{ - RoomId, UserId, - events::TimelineEventType::{ - self, Beacon, CallInvite, PollStart, RoomEncrypted, RoomMessage, Sticker, - }, -}; +use ruma::{RoomId, UserId}; use tuwunel_core::{ Error, PduCount, Result, matrix::pdu::PduEvent, @@ -17,9 +12,6 @@ use tuwunel_service::Services; pub(crate) use self::{v3::sync_events_route, v5::sync_events_v5_route}; -pub(crate) const DEFAULT_BUMP_TYPES: &[TimelineEventType; 6] = - &[CallInvite, PollStart, Beacon, RoomEncrypted, RoomMessage, Sticker]; - async fn load_timeline( services: &Services, sender_user: &UserId, diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index 97fabd9e..32b5e785 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -1,73 +1,73 @@ mod account_data; mod e2ee; +mod filter; mod receipts; mod room; +mod selector; mod to_device; mod typing; -use std::{ - collections::{BTreeMap, BTreeSet}, - mem::take, - ops::Deref, - time::Duration, -}; +use std::{collections::BTreeMap, fmt::Debug, time::Duration}; use axum::extract::State; use futures::{ - FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, - future::{OptionFuture, join, join3, join5, try_join}, - pin_mut, + FutureExt, TryFutureExt, TryStreamExt, + future::{OptionFuture, join, join5, try_join}, }; use ruma::{ - DeviceId, OwnedRoomId, RoomId, UInt, UserId, + DeviceId, OwnedRoomId, RoomId, UserId, api::client::sync::sync_events::v5::{ - Request, Response, request::ExtensionRoomConfig, response, + ListId, Request, Response, request::ExtensionRoomConfig, response, }, - directory::RoomTypeFilter, events::room::member::MembershipState, - uint, }; use tokio::time::{Instant, timeout_at}; use tuwunel_core::{ - Err, Result, apply, + Result, apply, at, + debug::INFO_SPAN_LEVEL, + err, error::inspect_log, - extract_variant, is_equal_to, - matrix::TypeStateKey, + extract_variant, + smallvec::SmallVec, trace, utils::{ - FutureBoolExt, IterStream, ReadyExt, TryFutureExtExt, - future::ReadyEqExt, - math::{ruma_from_usize, usize_from_ruma}, + BoolExt, IterStream, TryFutureExtExt, result::FlatOk, - stream::{BroadbandExt, TryBroadbandExt, TryReadyExt, WidebandExt}, + stream::{TryBroadbandExt, TryReadyExt}, }, - warn, }; use tuwunel_service::{ Services, - sync::{KnownRooms, ListId, into_connection_key}, + sync::{Connection, into_connection_key}, }; +use self::{ + filter::{filter_room, filter_room_meta}, + selector::selector, +}; use super::share_encrypted_room; -use crate::{Ruma, client::DEFAULT_BUMP_TYPES}; +use crate::Ruma; #[derive(Copy, Clone)] struct SyncInfo<'a> { + services: &'a Services, sender_user: &'a UserId, sender_device: &'a DeviceId, request: &'a Request, - globalsince: u64, } -struct TodoRoom { - membership: MembershipState, - requested_state: BTreeSet, - timeline_limit: usize, - roomsince: u64, +#[derive(Clone, Debug)] +struct WindowRoom { + room_id: OwnedRoomId, + membership: Option, + lists: ListIds, + ranked: usize, + last_count: u64, } -type TodoRooms = BTreeMap; +type Window = BTreeMap; type ResponseLists = BTreeMap; +type ListIds = SmallVec<[ListId; 1]>; /// `POST /_matrix/client/unstable/org.matrix.simplified_msc3575/sync` /// ([MSC4186]) @@ -81,65 +81,59 @@ type ResponseLists = BTreeMap; /// [MSC4186]: https://github.com/matrix-org/matrix-spec-proposals/pull/4186 #[tracing::instrument( name = "sync", - level = "debug", + level = INFO_SPAN_LEVEL, skip_all, fields( - user_id = %body.sender_user(), + user_id = %body.sender_user().localpart(), device_id = %body.sender_device(), + conn_id = ?body.body.conn_id.clone().unwrap_or_default(), + since = ?body.body.pos.clone().or_else(|| body.body.pos_qrs_.clone()).unwrap_or_default(), ) )] pub(crate) async fn sync_events_v5_route( State(ref services): State, - mut body: Ruma, + body: Ruma, ) -> Result { - debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted"); - - let mut request = take(&mut body.body); - let mut globalsince = request + let request = &body.body; + let since = request .pos .as_ref() + .or(request.pos_qrs_.as_ref()) .and_then(|string| string.parse().ok()) .unwrap_or(0); let (sender_user, sender_device) = body.sender(); let conn_key = into_connection_key(sender_user, sender_device, request.conn_id.as_deref()); + let conn_val = since + .ne(&0) + .then(|| services.sync.find_connection(&conn_key)) + .unwrap_or_else(|| Ok(services.sync.init_connection(&conn_key))) + .map_err(|_| err!(Request(UnknownPos("Connection lost; restarting sync stream."))))?; - if globalsince != 0 && !services.sync.is_connection_cached(&conn_key) { - return Err!(Request(UnknownPos( - "Connection data unknown to server; restarting sync stream." - ))); - } - - // Client / User requested an initial sync - if globalsince == 0 { - services.sync.drop_connection(&conn_key); - } - - // Get sticky parameters from cache - let known_rooms = services - .sync - .update_cache(&conn_key, &mut request); - - let sync_info = SyncInfo { - sender_user, - sender_device, - globalsince, - request: &request, - }; - - let lists = handle_lists(services, sync_info, known_rooms); - + let conn = conn_val.lock(); let ping_presence = services .presence .maybe_ping_presence(sender_user, &request.set_presence) .inspect_err(inspect_log) .ok(); - let ((known_rooms, todo_rooms, lists), _) = join(lists, ping_presence).await; + let (mut conn, _) = join(conn, ping_presence).await; + + conn.update_cache(request); + conn.globalsince = since; + conn.next_batch = since; + + let sync_info = SyncInfo { + services, + sender_user, + sender_device, + request, + }; let timeout = request .timeout .as_ref() + .or(request.timeout_qrs_.as_ref()) .map(Duration::as_millis) .map(TryInto::try_into) .flat_ok() @@ -153,53 +147,64 @@ pub(crate) async fn sync_events_v5_route( let mut response = Response { txn_id: request.txn_id.clone(), - lists, + lists: Default::default(), pos: Default::default(), rooms: Default::default(), extensions: Default::default(), }; loop { - let watch_rooms = todo_rooms.keys().map(AsRef::as_ref).stream(); + let window; + (window, response.lists) = selector(&mut conn, sync_info).boxed().await; + let watch_rooms = window.keys().map(AsRef::as_ref).stream(); let watchers = services .sync .watch(sender_user, sender_device, watch_rooms); - let next_batch = services.globals.wait_pending().await?; - debug_assert!(globalsince <= next_batch, "next_batch is monotonic"); + conn.next_batch = services.globals.wait_pending().await?; + debug_assert!( + conn.globalsince <= conn.next_batch, + "next_batch should not be greater than since." + ); - if globalsince < next_batch { - let rooms = handle_rooms(services, sync_info, next_batch, &todo_rooms) - .map_ok(|rooms| response.rooms = rooms); + if conn.globalsince < conn.next_batch { + let rooms = + handle_rooms(sync_info, &conn, &window).map_ok(|rooms| response.rooms = rooms); - let extensions = - handle_extensions(services, sync_info, next_batch, &known_rooms, &todo_rooms) - .map_ok(|extensions| response.extensions = extensions); + let extensions = handle_extensions(sync_info, &conn, &window) + .map_ok(|extensions| response.extensions = extensions); try_join(rooms, extensions).boxed().await?; + for room_id in window.keys() { + conn.rooms + .entry(room_id.into()) + .or_default() + .roomsince = conn.next_batch; + } + if !is_empty_response(&response) { - trace!(globalsince, next_batch, "response {response:?}"); - response.pos = next_batch.to_string().into(); + response.pos = conn.next_batch.to_string().into(); + trace!(conn.globalsince, conn.next_batch, "response {response:?}"); return Ok(response); } } - if timeout_at(stop_at, watchers).await.is_err() { - trace!(globalsince, next_batch, "timeout; empty response"); - response.pos = next_batch.to_string().into(); + if timeout_at(stop_at, watchers).await.is_err() || services.server.is_stopping() { + response.pos = conn.next_batch.to_string().into(); + trace!(conn.globalsince, conn.next_batch, "timeout; empty response"); return Ok(response); } trace!( - globalsince, - last_batch = ?next_batch, + conn.globalsince, + last_batch = ?conn.next_batch, count = ?services.globals.pending_count(), stop_at = ?stop_at, "notified by watcher" ); - globalsince = next_batch; + conn.globalsince = conn.next_batch; } } @@ -212,360 +217,89 @@ fn is_empty_response(response: &Response) -> bool { } #[tracing::instrument( - level = "debug", - skip_all, - fields( - known_rooms = known_rooms.len(), - ) -)] -#[allow(clippy::too_many_arguments)] -async fn handle_lists( - services: &Services, - sync_info: SyncInfo<'_>, - known_rooms: KnownRooms, -) -> (KnownRooms, TodoRooms, ResponseLists) { - let &SyncInfo { - sender_user, - sender_device, - request, - globalsince, - } = &sync_info; - - let all_joined_rooms = services - .state_cache - .rooms_joined(sender_user) - .map(ToOwned::to_owned) - .collect::>(); - - let all_invited_rooms = services - .state_cache - .rooms_invited(sender_user) - .map(ToOwned::to_owned) - .collect::>(); - - let all_knocked_rooms = services - .state_cache - .rooms_knocked(sender_user) - .map(ToOwned::to_owned) - .collect::>(); - - let (all_joined_rooms, all_invited_rooms, all_knocked_rooms) = - join3(all_joined_rooms, all_invited_rooms, all_knocked_rooms).await; - - let all_invited_rooms = all_invited_rooms.iter().map(AsRef::as_ref); - let all_knocked_rooms = all_knocked_rooms.iter().map(AsRef::as_ref); - let all_joined_rooms = all_joined_rooms.iter().map(AsRef::as_ref); - let all_rooms = all_joined_rooms - .clone() - .chain(all_invited_rooms.clone()) - .chain(all_knocked_rooms.clone()); - - let mut todo_rooms: TodoRooms = BTreeMap::new(); - let mut response_lists = ResponseLists::new(); - for (list_id, list) in &request.lists { - let active_rooms: Vec<_> = match list.filters.as_ref().and_then(|f| f.is_invite) { - | None => all_rooms.clone().collect(), - | Some(true) => all_invited_rooms.clone().collect(), - | Some(false) => all_joined_rooms.clone().collect(), - }; - - let active_rooms = match list.filters.as_ref().map(|f| &f.not_room_types) { - | None => active_rooms, - | Some(filter) if filter.is_empty() => active_rooms, - | Some(value) => - filter_rooms( - services, - value, - &true, - active_rooms.iter().stream().map(Deref::deref), - ) - .collect() - .await, - }; - - let mut new_known_rooms: BTreeSet = BTreeSet::new(); - let ranges = list.ranges.clone(); - for mut range in ranges { - range.0 = uint!(0); - range.1 = range.1.checked_add(uint!(1)).unwrap_or(range.1); - range.1 = range - .1 - .clamp(range.0, UInt::try_from(active_rooms.len()).unwrap_or(UInt::MAX)); - - let room_ids = - active_rooms[usize_from_ruma(range.0)..usize_from_ruma(range.1)].to_vec(); - - let new_rooms: BTreeSet = room_ids - .clone() - .into_iter() - .map(From::from) - .collect(); - - new_known_rooms.extend(new_rooms); - for room_id in room_ids { - let todo_room = todo_rooms - .entry(room_id.to_owned()) - .or_insert(TodoRoom { - membership: MembershipState::Join, - requested_state: BTreeSet::new(), - timeline_limit: 0_usize, - roomsince: u64::MAX, - }); - - todo_room.membership = if all_invited_rooms - .clone() - .any(is_equal_to!(room_id)) - { - MembershipState::Invite - } else { - MembershipState::Join - }; - - todo_room.requested_state.extend( - list.room_details - .required_state - .iter() - .map(|(ty, sk)| (ty.clone(), sk.as_str().into())), - ); - - let limit: usize = usize_from_ruma(list.room_details.timeline_limit).min(100); - todo_room.timeline_limit = todo_room.timeline_limit.max(limit); - - // 0 means unknown because it got out of date - todo_room.roomsince = todo_room.roomsince.min( - known_rooms - .get(list_id.as_str()) - .and_then(|k| k.get(room_id)) - .copied() - .unwrap_or(0), - ); - } - } - - if let Some(conn_id) = request.conn_id.as_deref() { - let conn_key = into_connection_key(sender_user, sender_device, conn_id.into()); - let list_id = list_id.as_str().into(); - services - .sync - .update_known_rooms(&conn_key, list_id, new_known_rooms, globalsince); - } - - response_lists.insert(list_id.clone(), response::List { - count: ruma_from_usize(active_rooms.len()), - }); - } - - let (known_rooms, todo_rooms) = - fetch_subscriptions(services, sync_info, known_rooms, todo_rooms).await; - - (known_rooms, todo_rooms, response_lists) -} - -#[tracing::instrument( + name = "rooms", level = "debug", skip_all, fields( - global_since, - known_rooms = known_rooms.len(), - todo_rooms = todo_rooms.len(), + next_batch = conn.next_batch, + window = window.len(), ) )] -async fn fetch_subscriptions( - services: &Services, - SyncInfo { - sender_user, - sender_device, - globalsince, - request, - }: SyncInfo<'_>, - known_rooms: KnownRooms, - todo_rooms: TodoRooms, -) -> (KnownRooms, TodoRooms) { - let subs = (todo_rooms, BTreeSet::new()); - let (todo_rooms, known_subs) = request - .room_subscriptions - .iter() - .stream() - .broad_filter_map(async |(room_id, room)| { - let not_exists = services.metadata.exists(room_id).eq(&false); - let is_disabled = services.metadata.is_disabled(room_id); - let is_banned = services.metadata.is_banned(room_id); - - pin_mut!(not_exists, is_disabled, is_banned); - not_exists - .or(is_disabled) - .or(is_banned) - .await - .eq(&false) - .then_some((room_id, room)) - }) - .ready_fold(subs, |(mut todo_rooms, mut known_subs), (room_id, room)| { - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert(TodoRoom { - membership: MembershipState::Join, - requested_state: BTreeSet::new(), - timeline_limit: 0_usize, - roomsince: u64::MAX, - }); - - todo_room.requested_state.extend( - room.required_state - .iter() - .map(|(ty, sk)| (ty.clone(), sk.as_str().into())), - ); - - let limit: UInt = room.timeline_limit; - todo_room.timeline_limit = todo_room - .timeline_limit - .max(usize_from_ruma(limit)); - - // 0 means unknown because it got out of date - todo_room.roomsince = todo_room.roomsince.min( - known_rooms - .get("subscriptions") - .and_then(|k| k.get(room_id)) - .copied() - .unwrap_or(0), - ); - - known_subs.insert(room_id.clone()); - (todo_rooms, known_subs) - }) - .await; - - if let Some(conn_id) = request.conn_id.as_deref() { - let conn_key = into_connection_key(sender_user, sender_device, conn_id.into()); - let list_id = "subscriptions".into(); - services - .sync - .update_known_rooms(&conn_key, list_id, known_subs, globalsince); - } - - (known_rooms, todo_rooms) -} - -#[tracing::instrument( - level = "debug", - skip_all, - fields(?filters, negate) -)] -fn filter_rooms<'a, Rooms>( - services: &'a Services, - filters: &'a [RoomTypeFilter], - negate: &'a bool, - rooms: Rooms, -) -> impl Stream + Send + 'a -where - Rooms: Stream + Send + 'a, -{ - rooms - .wide_filter_map(async |room_id| { - match services - .state_accessor - .get_room_type(room_id) - .await - { - | Ok(room_type) => Some((room_id, Some(room_type))), - | Err(e) if e.is_not_found() => Some((room_id, None)), - | Err(_) => None, - } - }) - .map(|(room_id, room_type)| (room_id, RoomTypeFilter::from(room_type))) - .ready_filter_map(|(room_id, room_type_filter)| { - let contains = filters.contains(&room_type_filter); - let pos = !*negate && (filters.is_empty() || contains); - let neg = *negate && !contains; - - (pos || neg).then_some(room_id) - }) -} - -#[tracing::instrument( - level = "debug", - skip_all, - fields( - next_batch, - todo_rooms = todo_rooms.len(), - ) -)] async fn handle_rooms( - services: &Services, sync_info: SyncInfo<'_>, - next_batch: u64, - todo_rooms: &TodoRooms, + conn: &Connection, + window: &Window, ) -> Result> { - let rooms: BTreeMap<_, _> = todo_rooms + window .iter() .try_stream() - .broad_and_then(async |(room_id, todo_room)| { - let room = room::handle(services, next_batch, sync_info, room_id, todo_room).await?; - - Ok((room_id, room)) + .broad_and_then(async |(room_id, room)| { + room::handle(sync_info, conn, room) + .map_ok(|room| (room_id, room)) + .await }) .ready_try_filter_map(|(room_id, room)| Ok(room.map(|room| (room_id, room)))) .map_ok(|(room_id, room)| (room_id.to_owned(), room)) .try_collect() - .await?; - - Ok(rooms) + .await } #[tracing::instrument( + name = "extensions", level = "debug", skip_all, fields( - global_since, - known_rooms = known_rooms.len(), + next_batch = conn.next_batch, + window = window.len(), + rooms = conn.rooms.len(), + subs = conn.subscriptions.len(), ) )] async fn handle_extensions( - services: &Services, sync_info: SyncInfo<'_>, - next_batch: u64, - known_rooms: &KnownRooms, - todo_rooms: &TodoRooms, + conn: &Connection, + window: &Window, ) -> Result { - let SyncInfo { request, .. } = sync_info; + let SyncInfo { .. } = sync_info; - let account_data: OptionFuture<_> = request + let account_data: OptionFuture<_> = conn .extensions .account_data .enabled .unwrap_or(false) - .then(|| account_data::collect(services, sync_info, next_batch, known_rooms, todo_rooms)) + .then(|| account_data::collect(sync_info, conn, window)) .into(); - let receipts: OptionFuture<_> = request + let receipts: OptionFuture<_> = conn .extensions .receipts .enabled .unwrap_or(false) - .then(|| receipts::collect(services, sync_info, next_batch, known_rooms, todo_rooms)) + .then(|| receipts::collect(sync_info, conn, window)) .into(); - let typing: OptionFuture<_> = request + let typing: OptionFuture<_> = conn .extensions .typing .enabled .unwrap_or(false) - .then(|| typing::collect(services, sync_info, next_batch, known_rooms, todo_rooms)) + .then(|| typing::collect(sync_info, conn, window)) .into(); - let to_device: OptionFuture<_> = request + let to_device: OptionFuture<_> = conn .extensions .to_device .enabled .unwrap_or(false) - .then(|| to_device::collect(services, sync_info, next_batch)) + .then(|| to_device::collect(sync_info, conn)) .into(); - let e2ee: OptionFuture<_> = request + let e2ee: OptionFuture<_> = conn .extensions .e2ee .enabled .unwrap_or(false) - .then(|| e2ee::collect(services, sync_info, next_batch)) + .then(|| e2ee::collect(sync_info, conn)) .into(); let (account_data, receipts, typing, to_device, e2ee) = @@ -582,45 +316,60 @@ async fn handle_extensions( }) } -fn extension_rooms_todo<'a, ListIter, ConfigIter>( - SyncInfo { request, .. }: SyncInfo<'a>, - known_rooms: &'a KnownRooms, - todo_rooms: &'a TodoRooms, - lists: Option, - rooms: Option, +#[tracing::instrument( + name = "selector", + level = "trace", + skip_all, + fields(?implicit, ?explicit), +)] +fn extension_rooms_selector<'a, ListIter, SubsIter>( + SyncInfo { .. }: SyncInfo<'a>, + conn: &'a Connection, + window: &'a Window, + implicit: Option, + explicit: Option, ) -> impl Iterator + Send + Sync + 'a where - ListIter: Iterator + Clone + Send + Sync + 'a, - ConfigIter: Iterator + Clone + Send + Sync + 'a, + ListIter: Iterator + Clone + Debug + Send + Sync + 'a, + SubsIter: Iterator + Clone + Debug + Send + Sync + 'a, { - let lists_explicit = lists.clone().into_iter().flatten(); - - let rooms_explicit = rooms + let has_all_subscribed = explicit .clone() .into_iter() .flatten() - .filter_map(|erc| extract_variant!(erc, ExtensionRoomConfig::Room)) - .map(AsRef::::as_ref); + .any(|erc| matches!(erc, ExtensionRoomConfig::AllSubscribed)); - let lists_requested = request - .lists - .keys() - .filter(move |_| lists.is_none()); + let all_subscribed = has_all_subscribed + .then(|| conn.subscriptions.keys()) + .into_iter() + .flatten() + .map(AsRef::as_ref); - let rooms_implicit = todo_rooms - .keys() - .map(AsRef::as_ref) - .filter(move |_| rooms.is_none()); - - lists_explicit - .chain(lists_requested) - .flat_map(|list_id| { - known_rooms - .get(list_id.as_str()) + let rooms_explicit = has_all_subscribed + .is_false() + .then(move || { + explicit .into_iter() - .flat_map(BTreeMap::keys) + .flatten() + .filter_map(|erc| extract_variant!(erc, ExtensionRoomConfig::Room)) + .map(AsRef::as_ref) }) - .map(AsRef::as_ref) + .into_iter() + .flatten(); + + let rooms_selected = window + .iter() + .filter(move |(_, room)| { + implicit.as_ref().is_none_or(|lists| { + lists + .clone() + .any(|list| room.lists.contains(list)) + }) + }) + .map(at!(0)) + .map(AsRef::as_ref); + + all_subscribed .chain(rooms_explicit) - .chain(rooms_implicit) + .chain(rooms_selected) } diff --git a/src/api/client/sync/v5/account_data.rs b/src/api/client/sync/v5/account_data.rs index 986caa0f..ddc5a284 100644 --- a/src/api/client/sync/v5/account_data.rs +++ b/src/api/client/sync/v5/account_data.rs @@ -4,41 +4,39 @@ use tuwunel_core::{ Result, extract_variant, utils::{IterStream, ReadyExt, stream::BroadbandExt}, }; -use tuwunel_service::Services; +use tuwunel_service::sync::Room; -use super::{KnownRooms, SyncInfo, TodoRoom, TodoRooms, extension_rooms_todo}; +use super::{Connection, SyncInfo, Window, extension_rooms_selector}; -#[tracing::instrument(level = "trace", skip_all, fields(globalsince, next_batch))] +#[tracing::instrument(level = "trace", skip_all)] pub(super) async fn collect( - services: &Services, sync_info: SyncInfo<'_>, - next_batch: u64, - known_rooms: &KnownRooms, - todo_rooms: &TodoRooms, + conn: &Connection, + window: &Window, ) -> Result { - let SyncInfo { sender_user, globalsince, request, .. } = sync_info; + let SyncInfo { services, sender_user, .. } = sync_info; - let lists = request + let implicit = conn .extensions .account_data .lists .as_deref() .map(<[_]>::iter); - let rooms = request + let explicit = conn .extensions .account_data .rooms .as_deref() .map(<[_]>::iter); - let rooms = extension_rooms_todo(sync_info, known_rooms, todo_rooms, lists, rooms) + let rooms = extension_rooms_selector(sync_info, conn, window, implicit, explicit) .stream() .broad_filter_map(async |room_id| { - let &TodoRoom { roomsince, .. } = todo_rooms.get(room_id)?; + let &Room { roomsince, .. } = conn.rooms.get(room_id)?; let changes: Vec<_> = services .account_data - .changes_since(Some(room_id), sender_user, roomsince, Some(next_batch)) + .changes_since(Some(room_id), sender_user, roomsince, Some(conn.next_batch)) .ready_filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Room)) .collect() .await; @@ -52,7 +50,7 @@ pub(super) async fn collect( let global = services .account_data - .changes_since(None, sender_user, globalsince, Some(next_batch)) + .changes_since(None, sender_user, conn.globalsince, Some(conn.next_batch)) .ready_filter_map(|e| extract_variant!(e, AnyRawAccountDataEvent::Global)) .collect(); diff --git a/src/api/client/sync/v5/e2ee.rs b/src/api/client/sync/v5/e2ee.rs index 07cea3e4..32d373a3 100644 --- a/src/api/client/sync/v5/e2ee.rs +++ b/src/api/client/sync/v5/e2ee.rs @@ -22,23 +22,20 @@ use tuwunel_core::{ stream::BroadbandExt, }, }; -use tuwunel_service::Services; +use tuwunel_service::sync::Connection; use super::{SyncInfo, share_encrypted_room}; -#[tracing::instrument(level = "trace", skip_all, fields(globalsince, next_batch,))] +#[tracing::instrument(level = "trace", skip_all)] pub(super) async fn collect( - services: &Services, sync_info: SyncInfo<'_>, - next_batch: u64, + conn: &Connection, ) -> Result { - let SyncInfo { - sender_user, sender_device, globalsince, .. - } = sync_info; + let SyncInfo { services, sender_user, sender_device, .. } = sync_info; let keys_changed = services .users - .keys_changed(sender_user, globalsince, Some(next_batch)) + .keys_changed(sender_user, conn.globalsince, Some(conn.next_batch)) .map(ToOwned::to_owned) .collect::>() .map(|changed| (changed, HashSet::new())); @@ -48,11 +45,7 @@ pub(super) async fn collect( .state_cache .rooms_joined(sender_user) .map(ToOwned::to_owned) - .broad_filter_map(async |room_id| { - collect_room(services, sync_info, next_batch, &room_id) - .await - .ok() - }) + .broad_filter_map(async |room_id| collect_room(sync_info, conn, &room_id).await.ok()) .chain(once(keys_changed)) .ready_fold((changed, left), |(mut changed, mut left), room| { changed.extend(room.0); @@ -77,7 +70,7 @@ pub(super) async fn collect( .last_one_time_keys_update(sender_user) .then(|since| -> OptionFuture<_> { since - .gt(&globalsince) + .gt(&conn.globalsince) .then(|| { services .users @@ -103,9 +96,8 @@ pub(super) async fn collect( #[tracing::instrument(level = "trace", skip_all, fields(room_id))] async fn collect_room( - services: &Services, - SyncInfo { sender_user, globalsince, .. }: SyncInfo<'_>, - next_batch: u64, + SyncInfo { services, sender_user, .. }: SyncInfo<'_>, + conn: &Connection, room_id: &RoomId, ) -> Result)> { let current_shortstatehash = services @@ -115,7 +107,7 @@ async fn collect_room( let room_keys_changed = services .users - .room_keys_changed(room_id, globalsince, Some(next_batch)) + .room_keys_changed(room_id, conn.globalsince, Some(conn.next_batch)) .map(|(user_id, _)| user_id) .map(ToOwned::to_owned) .collect::>(); @@ -130,13 +122,13 @@ async fn collect_room( return Ok(lists); }; - if current_shortstatehash <= globalsince { + if current_shortstatehash <= conn.globalsince { return Ok(lists); } let Ok(since_shortstatehash) = services .timeline - .prev_shortstatehash(room_id, PduCount::Normal(globalsince).saturating_add(1)) + .prev_shortstatehash(room_id, PduCount::Normal(conn.globalsince).saturating_add(1)) .await else { return Ok(lists); @@ -168,7 +160,7 @@ async fn collect_room( } let encrypted_since_last_sync = !since_encryption; - let joined_since_last_sync = sender_joined_count.is_ok_and(|count| count > globalsince); + let joined_since_last_sync = sender_joined_count.is_ok_and(|count| count > conn.globalsince); let joined_members_burst: OptionFuture<_> = (joined_since_last_sync || encrypted_since_last_sync) .then(|| { diff --git a/src/api/client/sync/v5/filter.rs b/src/api/client/sync/v5/filter.rs new file mode 100644 index 00000000..26008cfd --- /dev/null +++ b/src/api/client/sync/v5/filter.rs @@ -0,0 +1,156 @@ +use futures::{StreamExt, future::OptionFuture, pin_mut}; +use ruma::{ + RoomId, api::client::sync::sync_events::v5::request::ListFilters, directory::RoomTypeFilter, + events::room::member::MembershipState, +}; +use tuwunel_core::{ + is_equal_to, is_true, + utils::{ + BoolExt, FutureBoolExt, IterStream, ReadyExt, + future::{OptionExt, ReadyEqExt}, + }, +}; + +use super::SyncInfo; + +#[tracing::instrument(name = "filter", level = "trace", skip_all)] +pub(super) async fn filter_room( + SyncInfo { services, sender_user, .. }: SyncInfo<'_>, + filter: &ListFilters, + room_id: &RoomId, + membership: Option<&MembershipState>, +) -> bool { + let match_invite: OptionFuture<_> = filter + .is_invite + .map(async |is_invite| match (membership, is_invite) { + | (Some(MembershipState::Invite), true) => true, + | (Some(MembershipState::Invite), false) => false, + | (Some(_), true) => false, + | (Some(_), false) => true, + | _ => + services + .state_cache + .is_invited(sender_user, room_id) + .await == is_invite, + }) + .into(); + + let match_direct: OptionFuture<_> = filter + .is_dm + .map(async |is_dm| { + services + .account_data + .is_direct(sender_user, room_id) + .await == is_dm + }) + .into(); + + let match_direct_member: OptionFuture<_> = filter + .is_dm + .map(async |is_dm| { + services + .state_accessor + .is_direct(room_id, sender_user) + .await == is_dm + }) + .into(); + + let match_encrypted: OptionFuture<_> = filter + .is_encrypted + .map(async |is_encrypted| { + services + .state_accessor + .is_encrypted_room(room_id) + .await == is_encrypted + }) + .into(); + + let match_space_child: OptionFuture<_> = filter + .spaces + .is_empty() + .is_false() + .then(async || { + filter + .spaces + .iter() + .stream() + .flat_map(|room_id| services.spaces.get_space_children(room_id)) + .ready_any(is_equal_to!(room_id)) + .await + }) + .into(); + + let fetch_tags = !filter.tags.is_empty() || !filter.not_tags.is_empty(); + let match_room_tag: OptionFuture<_> = fetch_tags + .then(async || { + if let Some(tags) = services + .account_data + .get_room_tags(sender_user, room_id) + .await + .ok() + .filter(|tags| !tags.is_empty()) + { + tags.keys().any(|tag| { + (filter.not_tags.is_empty() || !filter.not_tags.contains(tag)) + || (!filter.tags.is_empty() && filter.tags.contains(tag)) + }) + } else { + filter.tags.is_empty() + } + }) + .into(); + + let fetch_room_type = !filter.room_types.is_empty() || !filter.not_room_types.is_empty(); + let match_room_type: OptionFuture<_> = fetch_room_type + .then(async || { + let room_type = services + .state_accessor + .get_room_type(room_id) + .await + .ok(); + + let room_type = RoomTypeFilter::from(room_type); + (filter.not_room_types.is_empty() || !filter.not_room_types.contains(&room_type)) + && (filter.room_types.is_empty() || filter.room_types.contains(&room_type)) + }) + .into(); + + match_encrypted + .is_none_or(is_true!()) + .and3( + match_invite.is_none_or(is_true!()), + match_direct.is_none_or(is_true!()), + match_direct_member.is_none_or(is_true!()), + ) + .and3( + match_space_child.is_none_or(is_true!()), + match_room_type.is_none_or(is_true!()), + match_room_tag.is_none_or(is_true!()), + ) + .await +} + +#[tracing::instrument(name = "filter_meta", level = "trace", skip_all)] +pub(super) async fn filter_room_meta( + SyncInfo { services, sender_user, .. }: SyncInfo<'_>, + room_id: &RoomId, +) -> bool { + let not_exists = services.metadata.exists(room_id).eq(&false); + + let is_disabled = services.metadata.is_disabled(room_id); + + let is_banned = services.metadata.is_banned(room_id); + + let not_visible = services + .state_accessor + .user_can_see_state_events(sender_user, room_id) + .eq(&false); + + pin_mut!(not_visible, not_exists, is_disabled, is_banned); + not_visible + .or(not_exists) + .or(is_disabled) + .or(is_banned) + .await + .eq(&false) +} diff --git a/src/api/client/sync/v5/receipts.rs b/src/api/client/sync/v5/receipts.rs index 99ca1d38..9f37031f 100644 --- a/src/api/client/sync/v5/receipts.rs +++ b/src/api/client/sync/v5/receipts.rs @@ -9,58 +9,54 @@ use tuwunel_core::{ Result, utils::{BoolExt, IterStream, stream::BroadbandExt}, }; -use tuwunel_service::{Services, rooms::read_receipt::pack_receipts}; +use tuwunel_service::{rooms::read_receipt::pack_receipts, sync::Room}; -use super::{KnownRooms, SyncInfo, TodoRoom, TodoRooms, extension_rooms_todo}; +use super::{Connection, SyncInfo, Window, extension_rooms_selector}; #[tracing::instrument(level = "trace", skip_all)] pub(super) async fn collect( - services: &Services, sync_info: SyncInfo<'_>, - next_batch: u64, - known_rooms: &KnownRooms, - todo_rooms: &TodoRooms, + conn: &Connection, + window: &Window, ) -> Result { - let SyncInfo { request, .. } = sync_info; + let SyncInfo { .. } = sync_info; - let lists = request + let implicit = conn .extensions .receipts .lists .as_deref() .map(<[_]>::iter); - let rooms = request + let explicit = conn .extensions .receipts .rooms .as_deref() .map(<[_]>::iter); - let rooms = extension_rooms_todo(sync_info, known_rooms, todo_rooms, lists, rooms) + let rooms = extension_rooms_selector(sync_info, conn, window, implicit, explicit) .stream() - .broad_filter_map(async |room_id| { - collect_room(services, sync_info, next_batch, todo_rooms, room_id).await - }) + .broad_filter_map(|room_id| collect_room(sync_info, conn, window, room_id)) .collect() .await; Ok(response::Receipts { rooms }) } +#[tracing::instrument(level = "trace", skip_all, fields(room_id))] async fn collect_room( - services: &Services, - SyncInfo { sender_user, .. }: SyncInfo<'_>, - next_batch: u64, - todo_rooms: &TodoRooms, + SyncInfo { services, sender_user, .. }: SyncInfo<'_>, + conn: &Connection, + _window: &Window, room_id: &RoomId, ) -> Option<(OwnedRoomId, Raw)> { - let &TodoRoom { roomsince, .. } = todo_rooms.get(room_id)?; + let &Room { roomsince, .. } = conn.rooms.get(room_id)?; let private_receipt = services .read_receipt .last_privateread_update(sender_user, room_id) .then(async |last_private_update| { - if last_private_update <= roomsince || last_private_update > next_batch { + if last_private_update <= roomsince || last_private_update > conn.next_batch { return None; } @@ -77,7 +73,7 @@ async fn collect_room( let receipts: Vec> = services .read_receipt - .readreceipts_since(room_id, roomsince, Some(next_batch)) + .readreceipts_since(room_id, roomsince, Some(conn.next_batch)) .filter_map(async |(read_user, _ts, v)| { services .users @@ -92,6 +88,6 @@ async fn collect_room( receipts .is_empty() - .eq(&false) + .is_false() .then(|| (room_id.to_owned(), pack_receipts(receipts.into_iter()))) } diff --git a/src/api/client/sync/v5/room.rs b/src/api/client/sync/v5/room.rs index e47c1229..f1ea54b3 100644 --- a/src/api/client/sync/v5/room.rs +++ b/src/api/client/sync/v5/room.rs @@ -1,4 +1,4 @@ -use std::cmp::Ordering; +use std::{cmp::Ordering, collections::HashSet}; use futures::{ FutureExt, StreamExt, TryFutureExt, TryStreamExt, @@ -6,43 +6,79 @@ use futures::{ }; use ruma::{ JsOption, MxcUri, OwnedMxcUri, RoomId, UInt, UserId, - api::client::sync::sync_events::{UnreadNotificationsCount, v5::response}, - events::{StateEventType, room::member::MembershipState}, + api::client::sync::sync_events::{ + UnreadNotificationsCount, + v5::{DisplayName, response}, + }, + events::{ + StateEventType, + TimelineEventType::{ + self, Beacon, CallInvite, PollStart, RoomEncrypted, RoomMessage, Sticker, + }, + room::member::MembershipState, + }, }; use tuwunel_core::{ - Result, at, debug_error, is_equal_to, + Result, at, debug_error, err, is_equal_to, matrix::{Event, StateKey, pdu::PduCount}, ref_at, - utils::{IterStream, ReadyExt, TryFutureExtExt, result::FlatOk, stream::BroadbandExt}, + utils::{ + BoolExt, IterStream, ReadyExt, TryFutureExtExt, math::usize_from_ruma, result::FlatOk, + stream::BroadbandExt, + }, }; -use tuwunel_service::Services; +use tuwunel_service::{Services, sync::Room}; -use super::{SyncInfo, TodoRoom}; -use crate::client::{DEFAULT_BUMP_TYPES, ignored_filter, sync::load_timeline}; +use super::{super::load_timeline, Connection, SyncInfo, WindowRoom}; +use crate::client::ignored_filter; -#[tracing::instrument(level = "debug", skip_all, fields(room_id, roomsince))] +static DEFAULT_BUMP_TYPES: [TimelineEventType; 6] = + [CallInvite, PollStart, Beacon, RoomEncrypted, RoomMessage, Sticker]; + +#[tracing::instrument( + name = "room", + level = "debug", + skip_all, + fields(room_id, roomsince) +)] #[allow(clippy::too_many_arguments)] pub(super) async fn handle( - services: &Services, - next_batch: u64, - SyncInfo { sender_user, .. }: SyncInfo<'_>, - room_id: &RoomId, - &TodoRoom { - ref membership, - ref requested_state, - timeline_limit, - roomsince, - }: &TodoRoom, + SyncInfo { services, sender_user, .. }: SyncInfo<'_>, + conn: &Connection, + WindowRoom { lists, membership, room_id, .. }: &WindowRoom, ) -> Result> { - let timeline: OptionFuture<_> = membership - .ne(&MembershipState::Invite) + debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted"); + + let &Room { roomsince } = conn + .rooms + .get(room_id) + .ok_or_else(|| err!("Missing connection state for {room_id}"))?; + + let is_invite = *membership == Some(MembershipState::Invite); + let default_details = (0_usize, HashSet::new()); + let (timeline_limit, required_state) = lists + .iter() + .filter_map(|list_id| conn.lists.get(list_id)) + .map(|list| &list.room_details) + .chain(conn.subscriptions.get(room_id).into_iter()) + .fold(default_details, |(mut timeline_limit, mut required_state), config| { + let limit = usize_from_ruma(config.timeline_limit); + + timeline_limit = timeline_limit.max(limit); + required_state.extend(config.required_state.clone()); + + (timeline_limit, required_state) + }); + + let timeline: OptionFuture<_> = is_invite + .is_false() .then(|| { load_timeline( services, sender_user, room_id, PduCount::Normal(roomsince), - Some(PduCount::from(next_batch)), + Some(PduCount::from(conn.next_batch)), timeline_limit, ) }) @@ -56,7 +92,7 @@ pub(super) async fn handle( let (timeline_pdus, limited, _lastcount) = timeline.unwrap_or_else(|| (Vec::new(), true, PduCount::default())); - if roomsince != 0 && timeline_pdus.is_empty() && membership.ne(&MembershipState::Invite) { + if roomsince != 0 && timeline_pdus.is_empty() && !is_invite { return Ok(None); } @@ -76,7 +112,7 @@ pub(super) async fn handle( .is_ok() }) .fold(Option::::None, |mut bump_stamp, (_, pdu)| { - let ts = pdu.origin_server_ts().get(); + let ts = pdu.origin_server_ts().0; if bump_stamp.is_none_or(|bump_stamp| bump_stamp < ts) { bump_stamp.replace(ts); } @@ -84,7 +120,7 @@ pub(super) async fn handle( bump_stamp }); - let lazy = requested_state + let lazy = required_state .iter() .any(is_equal_to!(&(StateEventType::RoomMember, "$LAZY".into()))); @@ -102,7 +138,7 @@ pub(super) async fn handle( .map(|sender| (StateEventType::RoomMember, StateKey::from_str(sender.as_str()))) .stream(); - let wildcard_state = requested_state + let wildcard_state = required_state .iter() .filter(|(_, state_key)| state_key == "*") .map(|(event_type, _)| { @@ -115,7 +151,7 @@ pub(super) async fn handle( .stream() .flatten(); - let required_state = requested_state + let required_state = required_state .iter() .cloned() .stream() @@ -138,8 +174,7 @@ pub(super) async fn handle( .collect(); // TODO: figure out a timestamp we can use for remote invites - let invite_state: OptionFuture<_> = membership - .eq(&MembershipState::Invite) + let invite_state: OptionFuture<_> = is_invite .then(|| { services .state_cache @@ -159,6 +194,7 @@ pub(super) async fn handle( let room_name = services .state_accessor .get_name(room_id) + .map_ok(Into::into) .map(Result::ok); let room_avatar = services @@ -194,12 +230,17 @@ pub(super) async fn handle( .map_ok(Result::ok) .map(FlatOk::flat_ok); - let meta = join(room_name, room_avatar); + let is_dm = services + .state_accessor + .is_direct(room_id, sender_user) + .map(|is_dm| is_dm.then_some(is_dm)); + + let meta = join3(room_name, room_avatar, is_dm); let events = join3(timeline, required_state, invite_state); let member_counts = join(joined_count, invited_count); let notification_counts = join(highlight_count, notification_count); let ( - (room_name, room_avatar), + (room_name, room_avatar, is_dm), (timeline, required_state, invite_state), (joined_count, invited_count), (highlight_count, notification_count), @@ -211,7 +252,7 @@ pub(super) async fn handle( services, sender_user, room_id, - room_name.as_deref(), + room_name.as_ref(), room_avatar.as_deref(), ) .await?; @@ -220,14 +261,16 @@ pub(super) async fn handle( Ok(Some(response::Room { initial: Some(roomsince == 0), + lists: lists.clone(), + membership: membership.clone(), name: room_name.or(hero_name), avatar: JsOption::from_option(room_avatar.or(heroes_avatar)), - invite_state: invite_state.flatten(), + is_dm, required_state, - timeline, - is_dm: None, - prev_batch, + invite_state: invite_state.flatten(), + prev_batch: prev_batch.as_deref().map(Into::into), limited, + timeline, bump_stamp, heroes, num_live, @@ -237,15 +280,15 @@ pub(super) async fn handle( })) } -#[tracing::instrument(level = "debug", skip_all, fields(room_id, roomsince))] +#[tracing::instrument(name = "heroes", level = "trace", skip_all)] #[allow(clippy::type_complexity)] async fn calculate_heroes( services: &Services, sender_user: &UserId, room_id: &RoomId, - room_name: Option<&str>, + room_name: Option<&DisplayName>, room_avatar: Option<&MxcUri>, -) -> Result<(Option>, Option, Option)> { +) -> Result<(Option>, Option, Option)> { const MAX_HEROES: usize = 5; let heroes: Vec<_> = services .state_cache @@ -275,8 +318,10 @@ async fn calculate_heroes( let (name, avatar) = join(name, avatar).await; let hero = response::Hero { user_id, - name: name.unwrap_or(content.displayname), avatar: avatar.unwrap_or(content.avatar_url), + name: name + .unwrap_or(content.displayname) + .map(Into::into), }; Some(hero) @@ -291,7 +336,7 @@ async fn calculate_heroes( heroes[0] .name .clone() - .unwrap_or_else(|| heroes[0].user_id.to_string()), + .unwrap_or_else(|| heroes[0].user_id.as_str().into()), ), | Ordering::Greater => { let firsts = heroes[1..] @@ -299,7 +344,7 @@ async fn calculate_heroes( .map(|h| { h.name .clone() - .unwrap_or_else(|| h.user_id.to_string()) + .unwrap_or_else(|| h.user_id.as_str().into()) }) .collect::>() .join(", "); @@ -307,9 +352,9 @@ async fn calculate_heroes( let last = heroes[0] .name .clone() - .unwrap_or_else(|| heroes[0].user_id.to_string()); + .unwrap_or_else(|| heroes[0].user_id.as_str().into()); - Some(format!("{firsts} and {last}")) + Some(format!("{firsts} and {last}")).map(Into::into) }, }; diff --git a/src/api/client/sync/v5/selector.rs b/src/api/client/sync/v5/selector.rs new file mode 100644 index 00000000..a7d7bfc9 --- /dev/null +++ b/src/api/client/sync/v5/selector.rs @@ -0,0 +1,257 @@ +use std::cmp::Ordering; + +use futures::{ + FutureExt, StreamExt, TryFutureExt, + future::{OptionFuture, join3}, +}; +use ruma::{OwnedRoomId, UInt, events::room::member::MembershipState, uint}; +use tuwunel_core::{ + apply, is_true, + matrix::PduCount, + trace, + utils::{ + BoolExt, + future::TryExtExt, + math::usize_from_ruma, + stream::{BroadbandExt, IterStream}, + }, +}; +use tuwunel_service::sync::Connection; + +use super::{ + ListIds, ResponseLists, SyncInfo, Window, WindowRoom, filter_room, filter_room_meta, +}; + +#[tracing::instrument(level = "debug", skip_all)] +pub(super) async fn selector( + conn: &mut Connection, + sync_info: SyncInfo<'_>, +) -> (Window, ResponseLists) { + use MembershipState::*; + + let SyncInfo { services, sender_user, request, .. } = sync_info; + + trace!(?request); + let mut rooms = services + .state_cache + .user_memberships(sender_user, Some(&[Join, Invite, Knock])) + .map(|(membership, room_id)| (room_id.to_owned(), Some(membership))) + .broad_filter_map(|(room_id, membership)| { + match_lists_for_room(sync_info, conn, room_id, membership) + }) + .collect::>() + .await; + + rooms.sort_by(room_sort); + rooms + .iter_mut() + .enumerate() + .for_each(|(i, room)| { + room.ranked = i; + }); + + trace!(?rooms); + let lists = response_lists(rooms.iter()); + + trace!(?lists); + let window = select_window(sync_info, conn, rooms.iter(), &lists).await; + + trace!(?window); + for room in &rooms { + conn.rooms + .entry(room.room_id.clone()) + .or_default(); + } + + (window, lists) +} + +async fn select_window<'a, Rooms>( + sync_info: SyncInfo<'_>, + conn: &Connection, + rooms: Rooms, + lists: &ResponseLists, +) -> Window +where + Rooms: Iterator + Clone + Send + Sync, +{ + static FULL_RANGE: (UInt, UInt) = (UInt::MIN, UInt::MAX); + + let selections = lists + .keys() + .cloned() + .filter_map(|id| conn.lists.get(&id).map(|list| (id, list))) + .flat_map(|(id, list)| { + let full_range = list + .ranges + .is_empty() + .then_some(&FULL_RANGE) + .into_iter(); + + list.ranges + .iter() + .chain(full_range) + .map(apply!(2, usize_from_ruma)) + .map(move |range| (id.clone(), range)) + }) + .flat_map(|(id, (start, end))| { + let list = rooms + .clone() + .filter(move |&room| room.lists.contains(&id)); + + let cycled = list.clone().all(|room| { + conn.rooms + .get(&room.room_id) + .is_some_and(|room| room.roomsince != 0) + }); + + list.enumerate() + .skip_while(move |&(i, room)| { + i < start + || conn + .rooms + .get(&room.room_id) + .is_some_and(|conn_room| { + conn_room.roomsince >= room.last_count + || (!cycled && conn_room.roomsince != 0) + }) + }) + .take(end.saturating_add(1).saturating_sub(start)) + .map(|(_, room)| (room.room_id.clone(), room.clone())) + }); + + conn.subscriptions + .iter() + .stream() + .broad_filter_map(async |(room_id, _)| { + filter_room_meta(sync_info, room_id) + .await + .then(|| WindowRoom { + room_id: room_id.clone(), + membership: None, + lists: Default::default(), + ranked: usize::MAX, + last_count: 0, + }) + }) + .map(|room| (room.room_id.clone(), room)) + .chain(selections.stream()) + .collect() + .await +} + +#[tracing::instrument( + name = "matcher", + level = "trace", + skip_all, + fields(?room_id, ?membership) +)] +async fn match_lists_for_room( + sync_info: SyncInfo<'_>, + conn: &Connection, + room_id: OwnedRoomId, + membership: Option, +) -> Option { + let SyncInfo { services, sender_user, .. } = sync_info; + + let lists = conn + .lists + .iter() + .stream() + .filter_map(async |(id, list)| { + let filter: OptionFuture<_> = list + .filters + .clone() + .map(async |filters| { + filter_room(sync_info, &filters, &room_id, membership.as_ref()).await + }) + .into(); + + filter + .await + .is_none_or(is_true!()) + .then(|| id.clone()) + }) + .collect::() + .await; + + let last_timeline_count: OptionFuture<_> = lists + .is_empty() + .is_false() + .then(|| { + services + .timeline + .last_timeline_count(None, &room_id, None) + .map_ok(PduCount::into_unsigned) + .ok() + }) + .into(); + + let last_account_count: OptionFuture<_> = lists + .is_empty() + .is_false() + .then(|| { + services + .account_data + .last_count(Some(room_id.as_ref()), sender_user, conn.next_batch) + .ok() + }) + .into(); + + let last_receipt_count: OptionFuture<_> = lists + .is_empty() + .is_false() + .then(|| { + services + .read_receipt + .last_receipt_count(&room_id, sender_user.into(), conn.globalsince.into()) + .map(Result::ok) + }) + .into(); + + let (last_timeline_count, last_account_count, last_receipt_count) = + join3(last_timeline_count, last_account_count, last_receipt_count).await; + + Some(WindowRoom { + room_id: room_id.clone(), + membership, + lists, + ranked: 0, + last_count: [last_timeline_count, last_account_count, last_receipt_count] + .into_iter() + .map(Option::flatten) + .map(Option::unwrap_or_default) + .max() + .unwrap_or_default(), + }) +} + +fn response_lists<'a, Rooms>(rooms: Rooms) -> ResponseLists +where + Rooms: Iterator, +{ + rooms + .flat_map(|room| room.lists.iter()) + .fold(ResponseLists::default(), |mut lists, id| { + let list = lists.entry(id.clone()).or_default(); + list.count = list + .count + .checked_add(uint!(1)) + .expect("list count must not overflow JsInt"); + + lists + }) +} + +fn room_sort(a: &WindowRoom, b: &WindowRoom) -> Ordering { + if a.membership != b.membership { + if a.membership == Some(MembershipState::Invite) { + return Ordering::Less; + } + if b.membership == Some(MembershipState::Invite) { + return Ordering::Greater; + } + } + + b.last_count.cmp(&a.last_count) +} diff --git a/src/api/client/sync/v5/to_device.rs b/src/api/client/sync/v5/to_device.rs index 6a4298fe..918a926c 100644 --- a/src/api/client/sync/v5/to_device.rs +++ b/src/api/client/sync/v5/to_device.rs @@ -1,26 +1,22 @@ use futures::StreamExt; use ruma::api::client::sync::sync_events::v5::response; use tuwunel_core::{self, Result}; -use tuwunel_service::Services; -use super::SyncInfo; +use super::{Connection, SyncInfo}; -#[tracing::instrument(level = "trace", skip_all, fields(globalsince, next_batch))] +#[tracing::instrument(level = "trace", skip_all)] pub(super) async fn collect( - services: &Services, - SyncInfo { - sender_user, sender_device, globalsince, .. - }: SyncInfo<'_>, - next_batch: u64, + SyncInfo { services, sender_user, sender_device, .. }: SyncInfo<'_>, + conn: &Connection, ) -> Result> { services .users - .remove_to_device_events(sender_user, sender_device, globalsince) + .remove_to_device_events(sender_user, sender_device, conn.globalsince) .await; let events: Vec<_> = services .users - .get_to_device_events(sender_user, sender_device, None, Some(next_batch)) + .get_to_device_events(sender_user, sender_device, None, Some(conn.next_batch)) .collect() .await; @@ -28,7 +24,7 @@ pub(super) async fn collect( .is_empty() .eq(&false) .then(|| response::ToDevice { - next_batch: next_batch.to_string(), + next_batch: conn.next_batch.to_string().into(), events, }); diff --git a/src/api/client/sync/v5/typing.rs b/src/api/client/sync/v5/typing.rs index 650f3864..fb7293b7 100644 --- a/src/api/client/sync/v5/typing.rs +++ b/src/api/client/sync/v5/typing.rs @@ -10,37 +10,34 @@ use tuwunel_core::{ Result, debug_error, utils::{IterStream, ReadyExt}, }; -use tuwunel_service::Services; -use super::{KnownRooms, SyncInfo, TodoRooms, extension_rooms_todo}; +use super::{Connection, SyncInfo, Window, extension_rooms_selector}; -#[tracing::instrument(level = "trace", skip_all, fields(globalsince))] +#[tracing::instrument(level = "trace", skip_all)] pub(super) async fn collect( - services: &Services, sync_info: SyncInfo<'_>, - _next_batch: u64, - known_rooms: &KnownRooms, - todo_rooms: &TodoRooms, + conn: &Connection, + window: &Window, ) -> Result { use response::Typing; - let SyncInfo { sender_user, request, .. } = sync_info; + let SyncInfo { services, sender_user, .. } = sync_info; - let lists = request + let implicit = conn .extensions .typing .lists .as_deref() .map(<[_]>::iter); - let rooms = request + let explicit = conn .extensions .typing .rooms .as_deref() .map(<[_]>::iter); - extension_rooms_todo(sync_info, known_rooms, todo_rooms, lists, rooms) + extension_rooms_selector(sync_info, conn, window, implicit, explicit) .stream() .filter_map(async |room_id| { services diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 44f272e9..d7584366 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -97,7 +97,7 @@ impl Data { self.readreceiptid_readreceipt .rev_keys_prefix(&key) .ignore_err() - .ready_take_while(|(_, c, _): &Key<'_>| since.is_none_or(|since| since.gt(&c))) + .ready_take_while(|(_, c, _): &Key<'_>| since.is_none_or(|since| since.gt(c))) .ready_filter(|(_, _, u): &Key<'_>| user_id.as_ref().is_none_or(is_equal_to!(u))) .map(|(_, c, _): Key<'_>| c) .boxed() diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs index 2962bacd..4bfeaf90 100644 --- a/src/service/sync/mod.rs +++ b/src/service/sync/mod.rs @@ -1,18 +1,19 @@ mod watch; use std::{ - collections::{BTreeMap, BTreeSet}, - sync::{Arc, Mutex, Mutex as StdMutex}, + collections::BTreeMap, + sync::{Arc, Mutex as StdMutex}, }; use ruma::{ DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, UserId, api::client::sync::sync_events::v5::{ - Request, request, + ConnId as ConnectionId, ListId, Request, request, request::{AccountData, E2EE, Receipts, ToDevice, Typing}, }, }; -use tuwunel_core::{Result, err, implement, is_equal_to, smallstr::SmallString}; +use tokio::sync::Mutex as TokioMutex; +use tuwunel_core::{Result, err, implement, is_equal_to}; use tuwunel_database::Map; pub struct Service { @@ -38,23 +39,27 @@ pub struct Data { } #[derive(Debug, Default)] -pub struct Cache { - lists: Lists, - known_rooms: KnownRooms, - subscriptions: Subscriptions, - extensions: request::Extensions, +pub struct Connection { + pub lists: Lists, + pub rooms: Rooms, + pub subscriptions: Subscriptions, + pub extensions: request::Extensions, + pub globalsince: u64, + pub next_batch: u64, } -type Connections = Mutex>; -type Connection = Arc>; -pub type ConnectionKey = (OwnedUserId, OwnedDeviceId, Option); -pub type ConnectionId = SmallString<[u8; 16]>; +#[derive(Clone, Debug, Default)] +pub struct Room { + pub roomsince: u64, +} -pub type Subscriptions = BTreeMap; +type Connections = StdMutex>; +pub type ConnectionVal = Arc>; +pub type ConnectionKey = (OwnedUserId, OwnedDeviceId, Option); + +pub type Subscriptions = BTreeMap; pub type Lists = BTreeMap; -pub type KnownRooms = BTreeMap; -pub type ListRooms = BTreeMap; -pub type ListId = SmallString<[u8; 16]>; +pub type Rooms = BTreeMap; impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { @@ -75,31 +80,26 @@ impl crate::Service for Service { userid_lastonetimekeyupdate: args.db["userid_lastonetimekeyupdate"].clone(), }, services: args.services.clone(), - connections: StdMutex::new(BTreeMap::new()), + connections: Default::default(), })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -#[implement(Service)] -pub fn update_cache(&self, key: &ConnectionKey, request: &mut Request) -> KnownRooms { - let cache = self.get_connection(key); - let mut cached = cache.lock().expect("locked"); - - Self::update_cache_lists(request, &mut cached); - Self::update_cache_subscriptions(request, &mut cached); - Self::update_cache_extensions(request, &mut cached); - - cached.known_rooms.clone() +#[implement(Connection)] +pub fn update_cache(&mut self, request: &Request) { + Self::update_cache_lists(request, self); + Self::update_cache_subscriptions(request, self); + Self::update_cache_extensions(request, self); } -#[implement(Service)] -fn update_cache_lists(request: &mut Request, cached: &mut Cache) { - for (list_id, request_list) in &mut request.lists { +#[implement(Connection)] +fn update_cache_lists(request: &Request, cached: &mut Self) { + for (list_id, request_list) in &request.lists { cached .lists - .entry(list_id.as_str().into()) + .entry(list_id.clone()) .and_modify(|cached_list| { Self::update_cache_list(request_list, cached_list); }) @@ -107,134 +107,88 @@ fn update_cache_lists(request: &mut Request, cached: &mut Cache) { } } -#[implement(Service)] -fn update_cache_list(request: &mut request::List, cached: &mut request::List) { - list_or_sticky( - &mut request.room_details.required_state, - &mut cached.room_details.required_state, - ); +#[implement(Connection)] +fn update_cache_list(request: &request::List, cached: &mut request::List) { + list_or_sticky(&request.room_details.required_state, &mut cached.room_details.required_state); - match (&mut request.filters, &mut cached.filters) { + match (&request.filters, &mut cached.filters) { | (None, None) => {}, - | (None, Some(cached)) => request.filters = Some(cached.clone()), + | (None, Some(_cached)) => {}, | (Some(request), None) => cached.filters = Some(request.clone()), | (Some(request), Some(cached)) => { - some_or_sticky(&mut request.is_dm, &mut cached.is_dm); - some_or_sticky(&mut request.is_encrypted, &mut cached.is_encrypted); - some_or_sticky(&mut request.is_invite, &mut cached.is_invite); - list_or_sticky(&mut request.room_types, &mut cached.room_types); - list_or_sticky(&mut request.not_room_types, &mut cached.not_room_types); - list_or_sticky(&mut request.tags, &mut cached.not_tags); - list_or_sticky(&mut request.spaces, &mut cached.spaces); + some_or_sticky(request.is_dm.as_ref(), &mut cached.is_dm); + some_or_sticky(request.is_encrypted.as_ref(), &mut cached.is_encrypted); + some_or_sticky(request.is_invite.as_ref(), &mut cached.is_invite); + list_or_sticky(&request.room_types, &mut cached.room_types); + list_or_sticky(&request.not_room_types, &mut cached.not_room_types); + list_or_sticky(&request.tags, &mut cached.not_tags); + list_or_sticky(&request.spaces, &mut cached.spaces); }, } } -#[implement(Service)] -fn update_cache_subscriptions(request: &mut Request, cached: &mut Cache) { +#[implement(Connection)] +fn update_cache_subscriptions(request: &Request, cached: &mut Self) { cached .subscriptions .extend(request.room_subscriptions.clone()); - - request - .room_subscriptions - .extend(cached.subscriptions.clone()); } -#[implement(Service)] -fn update_cache_extensions(request: &mut Request, cached: &mut Cache) { - let request = &mut request.extensions; +#[implement(Connection)] +fn update_cache_extensions(request: &Request, cached: &mut Self) { + let request = &request.extensions; let cached = &mut cached.extensions; - Self::update_cache_account_data(&mut request.account_data, &mut cached.account_data); - Self::update_cache_receipts(&mut request.receipts, &mut cached.receipts); - Self::update_cache_typing(&mut request.typing, &mut cached.typing); - Self::update_cache_to_device(&mut request.to_device, &mut cached.to_device); - Self::update_cache_e2ee(&mut request.e2ee, &mut cached.e2ee); + Self::update_cache_account_data(&request.account_data, &mut cached.account_data); + Self::update_cache_receipts(&request.receipts, &mut cached.receipts); + Self::update_cache_typing(&request.typing, &mut cached.typing); + Self::update_cache_to_device(&request.to_device, &mut cached.to_device); + Self::update_cache_e2ee(&request.e2ee, &mut cached.e2ee); } -#[implement(Service)] -fn update_cache_account_data(request: &mut AccountData, cached: &mut AccountData) { - some_or_sticky(&mut request.enabled, &mut cached.enabled); - some_or_sticky(&mut request.lists, &mut cached.lists); - some_or_sticky(&mut request.rooms, &mut cached.rooms); +#[implement(Connection)] +fn update_cache_account_data(request: &AccountData, cached: &mut AccountData) { + some_or_sticky(request.enabled.as_ref(), &mut cached.enabled); + some_or_sticky(request.lists.as_ref(), &mut cached.lists); + some_or_sticky(request.rooms.as_ref(), &mut cached.rooms); } -#[implement(Service)] -fn update_cache_receipts(request: &mut Receipts, cached: &mut Receipts) { - some_or_sticky(&mut request.enabled, &mut cached.enabled); - some_or_sticky(&mut request.rooms, &mut cached.rooms); - some_or_sticky(&mut request.lists, &mut cached.lists); +#[implement(Connection)] +fn update_cache_receipts(request: &Receipts, cached: &mut Receipts) { + some_or_sticky(request.enabled.as_ref(), &mut cached.enabled); + some_or_sticky(request.rooms.as_ref(), &mut cached.rooms); + some_or_sticky(request.lists.as_ref(), &mut cached.lists); } -#[implement(Service)] -fn update_cache_typing(request: &mut Typing, cached: &mut Typing) { - some_or_sticky(&mut request.enabled, &mut cached.enabled); - some_or_sticky(&mut request.rooms, &mut cached.rooms); - some_or_sticky(&mut request.lists, &mut cached.lists); +#[implement(Connection)] +fn update_cache_typing(request: &Typing, cached: &mut Typing) { + some_or_sticky(request.enabled.as_ref(), &mut cached.enabled); + some_or_sticky(request.rooms.as_ref(), &mut cached.rooms); + some_or_sticky(request.lists.as_ref(), &mut cached.lists); } -#[implement(Service)] -fn update_cache_to_device(request: &mut ToDevice, cached: &mut ToDevice) { - some_or_sticky(&mut request.enabled, &mut cached.enabled); +#[implement(Connection)] +fn update_cache_to_device(request: &ToDevice, cached: &mut ToDevice) { + some_or_sticky(request.enabled.as_ref(), &mut cached.enabled); } -#[implement(Service)] -fn update_cache_e2ee(request: &mut E2EE, cached: &mut E2EE) { - some_or_sticky(&mut request.enabled, &mut cached.enabled); +#[implement(Connection)] +fn update_cache_e2ee(request: &E2EE, cached: &mut E2EE) { + some_or_sticky(request.enabled.as_ref(), &mut cached.enabled); } -/// load params from cache if body doesn't contain it, as long as it's allowed -/// in some cases we may need to allow an empty list as an actual value -fn list_or_sticky(target: &mut Vec, cached: &mut Vec) { +fn list_or_sticky(target: &Vec, cached: &mut Vec) { if !target.is_empty() { cached.clone_from(target); - } else { - target.clone_from(cached); } } -fn some_or_sticky(target: &mut Option, cached: &mut Option) { +fn some_or_sticky(target: Option<&T>, cached: &mut Option) { if let Some(target) = target { cached.replace(target.clone()); - } else { - target.clone_from(cached); } } -#[implement(Service)] -pub fn update_known_rooms( - &self, - key: &ConnectionKey, - list_id: ListId, - new_rooms: BTreeSet, - globalsince: u64, -) { - assert!(key.2.is_some(), "Some(conn_id) required for this call"); - - let cache = self.get_connection(key); - let mut cached = cache.lock().expect("locked"); - let list_rooms = cached.known_rooms.entry(list_id).or_default(); - - for (room_id, lastsince) in list_rooms.iter_mut() { - if !new_rooms.contains(room_id) { - *lastsince = 0; - } - } - - for room_id in new_rooms { - list_rooms.insert(room_id, globalsince); - } -} - -#[implement(Service)] -pub fn update_subscriptions(&self, key: &ConnectionKey, subscriptions: Subscriptions) { - self.get_connection(key) - .lock() - .expect("locked") - .subscriptions = subscriptions; -} - #[implement(Service)] pub fn clear_connections( &self, @@ -270,17 +224,18 @@ pub fn list_connections(&self) -> Vec { } #[implement(Service)] -pub fn get_connection(&self, key: &ConnectionKey) -> Arc> { +pub fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal { self.connections .lock() .expect("locked") .entry(key.clone()) - .or_insert_with(|| Arc::new(Mutex::new(Cache::default()))) + .and_modify(|existing| *existing = ConnectionVal::default()) + .or_default() .clone() } #[implement(Service)] -pub fn find_connection(&self, key: &ConnectionKey) -> Result>> { +pub fn find_connection(&self, key: &ConnectionKey) -> Result { self.connections .lock() .expect("locked") @@ -290,7 +245,7 @@ pub fn find_connection(&self, key: &ConnectionKey) -> Result>> } #[implement(Service)] -pub fn is_connection_cached(&self, key: &ConnectionKey) -> bool { +pub fn contains_connection(&self, key: &ConnectionKey) -> bool { self.connections .lock() .expect("locked")