diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index af02c022..be68d6e2 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -94,6 +94,7 @@ pub(crate) async fn sync_events_v5_route( State(ref services): State, body: Ruma, ) -> Result { + let (sender_user, sender_device) = body.sender(); let request = &body.body; let since = request .pos @@ -102,38 +103,6 @@ pub(crate) async fn sync_events_v5_route( .and_then(|string| string.parse().ok()) .unwrap_or(0); - let (sender_user, sender_device) = body.sender(); - let conn_key = into_connection_key(sender_user, sender_device, request.conn_id.as_deref()); - let conn_val = since - .ne(&0) - .then(|| services.sync.find_connection(&conn_key)) - .unwrap_or_else(|| Ok(services.sync.init_connection(&conn_key))) - .map_err(|_| err!(Request(UnknownPos("Connection lost; restarting sync stream."))))?; - - let conn = conn_val.lock(); - let ping_presence = services - .presence - .maybe_ping_presence(sender_user, &request.set_presence) - .inspect_err(inspect_log) - .ok(); - - let (mut conn, _) = join(conn, ping_presence).await; - - if conn.next_batch != since { - return Err!(Request(UnknownPos("Requesting unknown or duplicate 'pos' parameter."))); - } - - conn.update_cache(request); - conn.globalsince = since; - conn.next_batch = since; - - let sync_info = SyncInfo { - services, - sender_user, - sender_device, - request, - }; - let timeout = request .timeout .as_ref() @@ -149,6 +118,49 @@ pub(crate) async fn sync_events_v5_route( .checked_add(Duration::from_millis(timeout)) .expect("configuration must limit maximum timeout"); + let conn_key = into_connection_key(sender_user, sender_device, request.conn_id.as_deref()); + let conn_val = since + .ne(&0) + .then(|| services.sync.find_connection(&conn_key)) + .unwrap_or_else(|| Ok(services.sync.init_connection(&conn_key))) + .map_err(|_| err!(Request(UnknownPos("Connection lost; restarting sync stream."))))?; + + let conn = conn_val.lock(); + let ping_presence = services + .presence + .maybe_ping_presence(sender_user, &request.set_presence) + .inspect_err(inspect_log) + .ok(); + + let (mut conn, _) = join(conn, ping_presence).await; + + // The client must either use the last returned next_batch or replay the + // next_batch from the penultimate request: it's either up-to-date or + // one-behind. If we receive anything else we can boot them. + let advancing = since == conn.next_batch; + let replaying = since == conn.globalsince; + if !advancing && !replaying { + return Err!(Request(UnknownPos("Requesting unknown or stale stream position."))); + } + + debug_assert!( + advancing || replaying, + "Request should either be advancing or replaying the last request." + ); + + // Update parameters regardless of replay or advance + conn.update_cache(request); + conn.update_rooms_prologue(advancing); + conn.globalsince = since; + conn.next_batch = services.globals.current_count(); + + let sync_info = SyncInfo { + services, + sender_user, + sender_device, + request, + }; + let mut response = Response { txn_id: request.txn_id.clone(), lists: Default::default(), @@ -158,19 +170,20 @@ pub(crate) async fn sync_events_v5_route( }; loop { + debug_assert!( + conn.globalsince <= conn.next_batch, + "next_batch should not be greater than since." + ); + let window; (window, response.lists) = selector(&mut conn, sync_info).boxed().await; + let watch_rooms = window.keys().map(AsRef::as_ref).stream(); let watchers = services .sync .watch(sender_user, sender_device, watch_rooms); conn.next_batch = services.globals.wait_pending().await?; - debug_assert!( - conn.globalsince <= conn.next_batch, - "next_batch should not be greater than since." - ); - if conn.globalsince < conn.next_batch { let rooms = handle_rooms(sync_info, &conn, &window).map_ok(|rooms| response.rooms = rooms); @@ -180,12 +193,7 @@ pub(crate) async fn sync_events_v5_route( try_join(rooms, extensions).boxed().await?; - for room_id in window.keys() { - conn.rooms - .entry(room_id.into()) - .or_default() - .roomsince = conn.next_batch; - } + conn.update_rooms_epilogue(window.keys().map(AsRef::as_ref)); if !is_empty_response(&response) { response.pos = conn.next_batch.to_string().into(); diff --git a/src/api/client/sync/v5/room.rs b/src/api/client/sync/v5/room.rs index f1ea54b3..01415808 100644 --- a/src/api/client/sync/v5/room.rs +++ b/src/api/client/sync/v5/room.rs @@ -49,7 +49,7 @@ pub(super) async fn handle( ) -> Result> { debug_assert!(DEFAULT_BUMP_TYPES.is_sorted(), "DEFAULT_BUMP_TYPES is not sorted"); - let &Room { roomsince } = conn + let &Room { roomsince, .. } = conn .rooms .get(room_id) .ok_or_else(|| err!("Missing connection state for {room_id}"))?; diff --git a/src/api/client/sync/v5/selector.rs b/src/api/client/sync/v5/selector.rs index a7d7bfc9..3cf3e236 100644 --- a/src/api/client/sync/v5/selector.rs +++ b/src/api/client/sync/v5/selector.rs @@ -95,26 +95,16 @@ where .map(move |range| (id.clone(), range)) }) .flat_map(|(id, (start, end))| { - let list = rooms + rooms .clone() - .filter(move |&room| room.lists.contains(&id)); - - let cycled = list.clone().all(|room| { - conn.rooms - .get(&room.room_id) - .is_some_and(|room| room.roomsince != 0) - }); - - list.enumerate() + .filter(move |&room| room.lists.contains(&id)) + .enumerate() .skip_while(move |&(i, room)| { i < start || conn .rooms .get(&room.room_id) - .is_some_and(|conn_room| { - conn_room.roomsince >= room.last_count - || (!cycled && conn_room.roomsince != 0) - }) + .is_some_and(|conn_room| conn_room.roomsince >= room.last_count) }) .take(end.saturating_add(1).saturating_sub(start)) .map(|(_, room)| (room.room_id.clone(), room.clone())) @@ -193,7 +183,7 @@ async fn match_lists_for_room( .then(|| { services .account_data - .last_count(Some(room_id.as_ref()), sender_user, conn.next_batch) + .last_count(Some(room_id.as_ref()), sender_user, None) .ok() }) .into(); @@ -204,7 +194,7 @@ async fn match_lists_for_room( .then(|| { services .read_receipt - .last_receipt_count(&room_id, sender_user.into(), conn.globalsince.into()) + .last_receipt_count(&room_id, sender_user.into(), None) .map(Result::ok) }) .into(); diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs index 4bfeaf90..c06a563f 100644 --- a/src/service/sync/mod.rs +++ b/src/service/sync/mod.rs @@ -6,7 +6,7 @@ use std::{ }; use ruma::{ - DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, UserId, + DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId, api::client::sync::sync_events::v5::{ ConnId as ConnectionId, ListId, Request, request, request::{AccountData, E2EE, Receipts, ToDevice, Typing}, @@ -51,6 +51,8 @@ pub struct Connection { #[derive(Clone, Debug, Default)] pub struct Room { pub roomsince: u64, + pub last_batch: u64, + pub next_batch: u64, } type Connections = StdMutex>; @@ -87,6 +89,32 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } +#[implement(Connection)] +pub fn update_rooms_prologue(&mut self, advance: bool) { + self.rooms.values_mut().for_each(|room| { + if advance { + room.roomsince = room.next_batch; + room.last_batch = room.next_batch; + } else { + room.roomsince = room.last_batch; + room.next_batch = room.last_batch; + } + }); +} + +#[implement(Connection)] +pub fn update_rooms_epilogue<'a, Rooms>(&mut self, window: Rooms) +where + Rooms: Iterator + Send + 'a, +{ + window.for_each(|room_id| { + let room = self.rooms.entry(room_id.into()).or_default(); + + room.roomsince = self.next_batch; + room.next_batch = self.next_batch; + }); +} + #[implement(Connection)] pub fn update_cache(&mut self, request: &Request) { Self::update_cache_lists(request, self);