diff --git a/src/admin/query/sync.rs b/src/admin/query/sync.rs index 3e71db38..2fcbd463 100644 --- a/src/admin/query/sync.rs +++ b/src/admin/query/sync.rs @@ -15,7 +15,7 @@ pub(crate) enum SyncCommand { /// Show details of sliding sync connection by ID. ShowConnection { user_id: OwnedUserId, - device_id: OwnedDeviceId, + device_id: Option, conn_id: Option, }, @@ -43,7 +43,7 @@ pub(super) async fn list_connections(&self) -> Result { pub(super) async fn show_connection( &self, user_id: OwnedUserId, - device_id: OwnedDeviceId, + device_id: Option, conn_id: Option, ) -> Result { let key = into_connection_key(user_id, device_id, conn_id); diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 7e2ffdd4..3eacbf5b 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -32,8 +32,8 @@ pub(crate) async fn get_context_route( State(services): State, body: Ruma, ) -> Result { - let sender = body.sender(); - let (sender_user, sender_device) = sender; + let sender_user = body.sender_user(); + let sender_device = body.sender_device.as_deref(); let room_id = &body.room_id; let event_id = &body.event_id; let filter = &body.filter; @@ -110,7 +110,7 @@ pub(crate) async fn get_context_route( let lazy_loading_context = lazy_loading::Context { user_id: sender_user, - device_id: Some(sender_device), + device_id: sender_device, room_id, token: Some(base_count.into_unsigned()), options: Some(&filter.lazy_load_options), diff --git a/src/api/client/events.rs b/src/api/client/events.rs index c32aa7d5..8c27e155 100644 --- a/src/api/client/events.rs +++ b/src/api/client/events.rs @@ -23,7 +23,7 @@ pub(crate) async fn events_route( State(services): State, body: Ruma, ) -> Result { - let (sender_user, sender_device) = body.sender(); + let sender_user = body.sender_user(); let from = body .body @@ -62,9 +62,11 @@ pub(crate) async fn events_route( .expect("configuration must limit maximum timeout"); loop { - let watchers = services - .sync - .watch(sender_user, sender_device, once(room_id).stream()); + let watchers = services.sync.watch( + sender_user, + body.sender_device.as_deref(), + once(room_id).stream(), + ); let next_batch = services.globals.wait_pending().await?; diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index f186ca72..210e563d 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -35,7 +35,8 @@ pub(crate) async fn upload_keys_route( State(services): State, body: Ruma, ) -> Result { - let (sender_user, sender_device) = body.sender(); + let sender_user = body.sender_user(); + let sender_device = body.sender_device()?; let one_time_keys = body .one_time_keys diff --git a/src/api/client/push.rs b/src/api/client/push.rs index b0f383c6..2fb6495a 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -589,7 +589,7 @@ pub(crate) async fn set_pushers_route( services .pusher - .set_pusher(sender_user, body.sender_device(), &body.action) + .set_pusher(sender_user, body.sender_device()?, &body.action) .await?; Ok(set_pusher::v3::Response::new()) diff --git a/src/api/client/session/logout.rs b/src/api/client/session/logout.rs index 9d9f705d..10f53818 100644 --- a/src/api/client/session/logout.rs +++ b/src/api/client/session/logout.rs @@ -23,7 +23,7 @@ pub(crate) async fn logout_route( ) -> Result { services .users - .remove_device(body.sender_user(), body.sender_device()) + .remove_device(body.sender_user(), body.sender_device()?) .await; Ok(logout::v3::Response::new()) diff --git a/src/api/client/session/token.rs b/src/api/client/session/token.rs index b01c4bd1..8fe30e46 100644 --- a/src/api/client/session/token.rs +++ b/src/api/client/session/token.rs @@ -50,7 +50,8 @@ pub(crate) async fn login_token_route( // This route SHOULD have UIA // TODO: How do we make only UIA sessions that have not been used before valid? - let (sender_user, sender_device) = body.sender(); + let sender_user = body.sender_user(); + let sender_device = body.sender_device()?; let password_flow = uiaa::AuthFlow { stages: vec![uiaa::AuthType::Password] }; diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index bd77a2be..f5f31dd0 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -120,13 +120,15 @@ type PresenceUpdates = HashMap; skip_all, fields( user_id = %body.sender_user(), + device_id = %body.sender_device.as_deref().map_or("", |x| x.as_str()), ) )] pub(crate) async fn sync_events_route( State(services): State, body: Ruma, ) -> Result { - let (sender_user, sender_device) = body.sender(); + let sender_user = body.sender_user(); + let sender_device = body.sender_device.as_deref(); let filter: OptionFuture<_> = body .body @@ -195,7 +197,8 @@ pub(crate) async fn sync_events_route( if since < next_batch || full_state { let response = build_sync_events( &services, - body.sender(), + sender_user, + sender_device, since, next_batch, full_state, @@ -216,7 +219,8 @@ pub(crate) async fn sync_events_route( // Wait for activity if time::timeout_at(stop_at, watchers).await.is_err() || services.server.is_stopping() { - let response = build_empty_response(&services, body.sender(), next_batch).await; + let response = + build_empty_response(&services, sender_user, sender_device, next_batch).await; trace!(since, next_batch, "empty response"); return Ok(response); } @@ -235,14 +239,22 @@ pub(crate) async fn sync_events_route( async fn build_empty_response( services: &Services, - (sender_user, sender_device): (&UserId, &DeviceId), + sender_user: &UserId, + sender_device: Option<&DeviceId>, next_batch: u64, ) -> sync_events::v3::Response { + let device_one_time_keys_count: OptionFuture<_> = sender_device + .map(|sender_device| { + services + .users + .count_one_time_keys(sender_user, sender_device) + }) + .into(); + sync_events::v3::Response { - device_one_time_keys_count: services - .users - .count_one_time_keys(sender_user, sender_device) - .await, + device_one_time_keys_count: device_one_time_keys_count + .await + .unwrap_or_default(), ..sync_events::v3::Response::new(to_small_string(next_batch)) } @@ -261,7 +273,8 @@ async fn build_empty_response( )] async fn build_sync_events( services: &Services, - (sender_user, sender_device): (&UserId, &DeviceId), + sender_user: &UserId, + sender_device: Option<&DeviceId>, since: u64, next_batch: u64, full_state: bool, @@ -387,27 +400,38 @@ async fn build_sync_events( .map(ToOwned::to_owned) .collect::>(); - let to_device_events = services - .users - .get_to_device_events(sender_user, sender_device, Some(since), Some(next_batch)) - .map(at!(1)) - .collect::>(); + let to_device_events: OptionFuture<_> = sender_device + .map(|sender_device| { + services + .users + .get_to_device_events(sender_user, sender_device, Some(since), Some(next_batch)) + .map(at!(1)) + .collect::>() + }) + .into(); - let device_one_time_keys_count = services - .users - .count_one_time_keys(sender_user, sender_device); + let device_one_time_keys_count: OptionFuture<_> = sender_device + .map(|sender_device| { + services + .users + .count_one_time_keys(sender_user, sender_device) + }) + .into(); // Remove all to-device events the device received *last time* - let remove_to_device_events = - services - .users - .remove_to_device_events(sender_user, sender_device, since); + let remove_to_device_events: OptionFuture<_> = sender_device + .map(|sender_device| { + services + .users + .remove_to_device_events(sender_user, sender_device, since) + }) + .into(); let ( account_data, keys_changed, - device_one_time_keys_count, - ((), to_device_events, presence_updates), + presence_updates, + (_, to_device_events, device_one_time_keys_count), ( (joined_rooms, mut device_list_updates, left_encrypted_users), left_rooms, @@ -417,8 +441,8 @@ async fn build_sync_events( ) = join5( account_data, keys_changed, - device_one_time_keys_count, - join3(remove_to_device_events, to_device_events, presence_updates), + presence_updates, + join3(remove_to_device_events, to_device_events, device_one_time_keys_count), join4(joined_rooms, left_rooms, invited_rooms, knocked_rooms), ) .boxed() @@ -454,7 +478,7 @@ async fn build_sync_events( left: device_list_left, changed: device_list_updates.into_iter().collect(), }, - device_one_time_keys_count, + device_one_time_keys_count: device_one_time_keys_count.unwrap_or_default(), // Fallback keys are not yet supported device_unused_fallback_key_types: None, next_batch: to_small_string(next_batch), @@ -465,7 +489,9 @@ async fn build_sync_events( invite: invited_rooms, knock: knocked_rooms, }, - to_device: ToDevice { events: to_device_events }, + to_device: ToDevice { + events: to_device_events.unwrap_or_default(), + }, }) } @@ -732,7 +758,7 @@ async fn load_left_room( async fn load_joined_room( services: &Services, sender_user: &UserId, - sender_device: &DeviceId, + sender_device: Option<&DeviceId>, ref room_id: OwnedRoomId, since: u64, next_batch: u64, @@ -841,7 +867,7 @@ async fn load_joined_room( let lazy_loading_context = &lazy_loading::Context { user_id: sender_user, - device_id: Some(sender_device), + device_id: sender_device, room_id, token: Some(since), options: Some(&filter.room.state.lazy_load_options), diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index ac0cd6ff..e1d46d26 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -37,7 +37,7 @@ use crate::Ruma; struct SyncInfo<'a> { services: &'a Services, sender_user: &'a UserId, - sender_device: &'a DeviceId, + sender_device: Option<&'a DeviceId>, } #[derive(Clone, Debug)] @@ -69,7 +69,7 @@ type ListIds = SmallVec<[ListId; 1]>; skip_all, fields( user_id = %body.sender_user().localpart(), - device_id = %body.sender_device(), + device_id = %body.sender_device.as_deref().map_or("", |x| x.as_str()), conn_id = ?body.body.conn_id.clone().unwrap_or_default(), since = ?body.body.pos.clone().unwrap_or_default(), ) @@ -78,7 +78,8 @@ pub(crate) async fn sync_events_v5_route( State(ref services): State, body: Ruma, ) -> Result { - let (sender_user, sender_device) = body.sender(); + let sender_user = body.sender_user(); + let sender_device = body.sender_device.as_deref(); let request = &body.body; let since = request .pos @@ -108,7 +109,7 @@ pub(crate) async fn sync_events_v5_route( let conn = conn_val.lock(); let ping_presence = services .presence - .maybe_ping_presence(sender_user, Some(sender_device), &request.set_presence) + .maybe_ping_presence(sender_user, sender_device, &request.set_presence) .inspect_err(inspect_log) .ok(); diff --git a/src/api/client/sync/v5/extensions/e2ee.rs b/src/api/client/sync/v5/extensions/e2ee.rs index 72b322ec..0d9bdfa9 100644 --- a/src/api/client/sync/v5/extensions/e2ee.rs +++ b/src/api/client/sync/v5/extensions/e2ee.rs @@ -32,6 +32,9 @@ pub(super) async fn collect( conn: &Connection, ) -> Result { let SyncInfo { services, sender_user, sender_device, .. } = sync_info; + let Some(sender_device) = sender_device else { + return Ok(response::E2EE::default()); + }; let keys_changed = services .users diff --git a/src/api/client/sync/v5/extensions/to_device.rs b/src/api/client/sync/v5/extensions/to_device.rs index 2717ab9f..e1069003 100644 --- a/src/api/client/sync/v5/extensions/to_device.rs +++ b/src/api/client/sync/v5/extensions/to_device.rs @@ -9,6 +9,10 @@ pub(super) async fn collect( SyncInfo { services, sender_user, sender_device, .. }: SyncInfo<'_>, conn: &Connection, ) -> Result> { + let Some(sender_device) = sender_device else { + return Ok(None); + }; + services .users .remove_to_device_events(sender_user, sender_device, conn.globalsince) diff --git a/src/api/router/args.rs b/src/api/router/args.rs index 13c212b5..c5b0c670 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -40,11 +40,6 @@ pub(crate) struct Args { } impl Args { - #[inline] - pub(crate) fn sender(&self) -> (&UserId, &DeviceId) { - (self.sender_user(), self.sender_device()) - } - #[inline] pub(crate) fn sender_user(&self) -> &UserId { self.sender_user @@ -52,19 +47,19 @@ impl Args { .expect("user must be authenticated for this handler") } - #[inline] - pub(crate) fn sender_device(&self) -> &DeviceId { - self.sender_device - .as_deref() - .expect("user must be authenticated and device identified") - } - #[inline] pub(crate) fn origin(&self) -> &ServerName { self.origin .as_deref() .expect("server must be authenticated for this handler") } + + #[inline] + pub(crate) fn sender_device(&self) -> Result<&DeviceId> { + self.sender_device + .as_deref() + .ok_or(err!(Request(Forbidden("user must be authenticated and device identified")))) + } } impl Deref for Args diff --git a/src/api/router/auth/uiaa.rs b/src/api/router/auth/uiaa.rs index 25a72669..4410a7f5 100644 --- a/src/api/router/auth/uiaa.rs +++ b/src/api/router/auth/uiaa.rs @@ -14,6 +14,8 @@ pub(crate) async fn auth_uiaa(services: &Services, body: &Ruma) -> Result< where T: IncomingRequest + Send + Sync, { + let sender_device = body.sender_device()?; + let flows = [ AuthFlow::new([AuthType::Password].into()), AuthFlow::new([AuthType::Jwt].into()), @@ -51,7 +53,7 @@ where let (worked, uiaainfo) = services .uiaa - .try_auth(sender_user, body.sender_device(), auth, &uiaainfo) + .try_auth(sender_user, sender_device, auth, &uiaainfo) .await?; if !worked { @@ -71,7 +73,7 @@ where uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa - .create(sender_user, body.sender_device(), &uiaainfo, json); + .create(sender_user, sender_device, &uiaainfo, json); Err(Error::Uiaa(uiaainfo)) }, diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs index 8782ede1..caf43b0a 100644 --- a/src/service/sync/mod.rs +++ b/src/service/sync/mod.rs @@ -59,7 +59,7 @@ pub struct Room { type Connections = TokioMutex>; pub type ConnectionVal = Arc>; -pub type ConnectionKey = (OwnedUserId, OwnedDeviceId, Option); +pub type ConnectionKey = (OwnedUserId, Option, Option); pub type Subscriptions = BTreeMap; pub type Lists = BTreeMap; @@ -107,7 +107,7 @@ pub async fn clear_connections( .await .retain(|(conn_user_id, conn_device_id, conn_conn_id), _| { let retain = user_id.is_none_or(is_equal_to!(conn_user_id)) - && device_id.is_none_or(is_equal_to!(conn_device_id)) + && (device_id.is_none() || device_id == conn_device_id.as_deref()) && (conn_id.is_none() || conn_id == conn_conn_id.as_ref()); if !retain { @@ -215,13 +215,17 @@ pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool { } #[inline] -pub fn into_connection_key(user_id: U, device_id: D, conn_id: Option) -> ConnectionKey +pub fn into_connection_key( + user_id: U, + device_id: Option, + conn_id: Option, +) -> ConnectionKey where U: Into, D: Into, C: Into, { - (user_id.into(), device_id.into(), conn_id.map(Into::into)) + (user_id.into(), device_id.map(Into::into), conn_id.map(Into::into)) } #[implement(Connection)] diff --git a/src/service/sync/watch.rs b/src/service/sync/watch.rs index 69fe1616..a4c7020e 100644 --- a/src/service/sync/watch.rs +++ b/src/service/sync/watch.rs @@ -8,25 +8,18 @@ use tuwunel_database::{Interfix, Separator, serialize_key}; pub async fn watch<'a, Rooms>( &self, user_id: &UserId, - device_id: &DeviceId, + device_id: Option<&DeviceId>, rooms: Rooms, ) -> Result where Rooms: Stream + Send + 'a, { - let userdeviceid_prefix = (user_id, device_id, Interfix); let globaluserdata_prefix = (Separator, user_id, Interfix); let roomuserdataid_prefix = (Option::<&RoomId>::None, user_id, Interfix); let userid_prefix = serialize_key((user_id, Interfix)).expect("failed to serialize watch prefix"); let watchers = [ - // Return when *any* user changed their key - // TODO: only send for user they share a room with - self.db - .todeviceid_events - .watch_prefix(&userdeviceid_prefix) - .boxed(), self.db .userroomid_joined .watch_raw_prefix(&userid_prefix) @@ -72,8 +65,20 @@ where .boxed(), ]; - let mut futures = FuturesUnordered::new(); - futures.extend(watchers.into_iter()); + let device_watchers = device_id.into_iter().map(|device_id| { + // Return when *any* user changed their key + // TODO: only send for user they share a room with + let userdeviceid_prefix = (user_id, device_id, Interfix); + self.db + .todeviceid_events + .watch_prefix(&userdeviceid_prefix) + .boxed() + }); + + let mut futures: FuturesUnordered<_> = watchers + .into_iter() + .chain(device_watchers) + .collect(); pin_mut!(rooms); while let Some(room_id) = rooms.next().await {