Fix sync handling for appservices calling without device_id

This commit is contained in:
dasha_uwu
2025-12-14 11:25:58 +05:00
committed by Jason Volk
parent 7b2079f714
commit 0c7ba1dd5a
15 changed files with 118 additions and 74 deletions

View File

@@ -15,7 +15,7 @@ pub(crate) enum SyncCommand {
/// Show details of sliding sync connection by ID. /// Show details of sliding sync connection by ID.
ShowConnection { ShowConnection {
user_id: OwnedUserId, user_id: OwnedUserId,
device_id: OwnedDeviceId, device_id: Option<OwnedDeviceId>,
conn_id: Option<String>, conn_id: Option<String>,
}, },
@@ -43,7 +43,7 @@ pub(super) async fn list_connections(&self) -> Result {
pub(super) async fn show_connection( pub(super) async fn show_connection(
&self, &self,
user_id: OwnedUserId, user_id: OwnedUserId,
device_id: OwnedDeviceId, device_id: Option<OwnedDeviceId>,
conn_id: Option<String>, conn_id: Option<String>,
) -> Result { ) -> Result {
let key = into_connection_key(user_id, device_id, conn_id); let key = into_connection_key(user_id, device_id, conn_id);

View File

@@ -32,8 +32,8 @@ pub(crate) async fn get_context_route(
State(services): State<crate::State>, State(services): State<crate::State>,
body: Ruma<get_context::v3::Request>, body: Ruma<get_context::v3::Request>,
) -> Result<get_context::v3::Response> { ) -> Result<get_context::v3::Response> {
let sender = body.sender(); let sender_user = body.sender_user();
let (sender_user, sender_device) = sender; let sender_device = body.sender_device.as_deref();
let room_id = &body.room_id; let room_id = &body.room_id;
let event_id = &body.event_id; let event_id = &body.event_id;
let filter = &body.filter; let filter = &body.filter;
@@ -110,7 +110,7 @@ pub(crate) async fn get_context_route(
let lazy_loading_context = lazy_loading::Context { let lazy_loading_context = lazy_loading::Context {
user_id: sender_user, user_id: sender_user,
device_id: Some(sender_device), device_id: sender_device,
room_id, room_id,
token: Some(base_count.into_unsigned()), token: Some(base_count.into_unsigned()),
options: Some(&filter.lazy_load_options), options: Some(&filter.lazy_load_options),

View File

@@ -23,7 +23,7 @@ pub(crate) async fn events_route(
State(services): State<crate::State>, State(services): State<crate::State>,
body: Ruma<Request>, body: Ruma<Request>,
) -> Result<Response> { ) -> Result<Response> {
let (sender_user, sender_device) = body.sender(); let sender_user = body.sender_user();
let from = body let from = body
.body .body
@@ -62,9 +62,11 @@ pub(crate) async fn events_route(
.expect("configuration must limit maximum timeout"); .expect("configuration must limit maximum timeout");
loop { loop {
let watchers = services let watchers = services.sync.watch(
.sync sender_user,
.watch(sender_user, sender_device, once(room_id).stream()); body.sender_device.as_deref(),
once(room_id).stream(),
);
let next_batch = services.globals.wait_pending().await?; let next_batch = services.globals.wait_pending().await?;

View File

@@ -35,7 +35,8 @@ pub(crate) async fn upload_keys_route(
State(services): State<crate::State>, State(services): State<crate::State>,
body: Ruma<upload_keys::v3::Request>, body: Ruma<upload_keys::v3::Request>,
) -> Result<upload_keys::v3::Response> { ) -> Result<upload_keys::v3::Response> {
let (sender_user, sender_device) = body.sender(); let sender_user = body.sender_user();
let sender_device = body.sender_device()?;
let one_time_keys = body let one_time_keys = body
.one_time_keys .one_time_keys

View File

@@ -589,7 +589,7 @@ pub(crate) async fn set_pushers_route(
services services
.pusher .pusher
.set_pusher(sender_user, body.sender_device(), &body.action) .set_pusher(sender_user, body.sender_device()?, &body.action)
.await?; .await?;
Ok(set_pusher::v3::Response::new()) Ok(set_pusher::v3::Response::new())

View File

@@ -23,7 +23,7 @@ pub(crate) async fn logout_route(
) -> Result<logout::v3::Response> { ) -> Result<logout::v3::Response> {
services services
.users .users
.remove_device(body.sender_user(), body.sender_device()) .remove_device(body.sender_user(), body.sender_device()?)
.await; .await;
Ok(logout::v3::Response::new()) Ok(logout::v3::Response::new())

View File

@@ -50,7 +50,8 @@ pub(crate) async fn login_token_route(
// This route SHOULD have UIA // This route SHOULD have UIA
// TODO: How do we make only UIA sessions that have not been used before valid? // TODO: How do we make only UIA sessions that have not been used before valid?
let (sender_user, sender_device) = body.sender(); let sender_user = body.sender_user();
let sender_device = body.sender_device()?;
let password_flow = uiaa::AuthFlow { stages: vec![uiaa::AuthType::Password] }; let password_flow = uiaa::AuthFlow { stages: vec![uiaa::AuthType::Password] };

View File

@@ -120,13 +120,15 @@ type PresenceUpdates = HashMap<OwnedUserId, PresenceEventContent>;
skip_all, skip_all,
fields( fields(
user_id = %body.sender_user(), user_id = %body.sender_user(),
device_id = %body.sender_device.as_deref().map_or("<no device>", |x| x.as_str()),
) )
)] )]
pub(crate) async fn sync_events_route( pub(crate) async fn sync_events_route(
State(services): State<crate::State>, State(services): State<crate::State>,
body: Ruma<sync_events::v3::Request>, body: Ruma<sync_events::v3::Request>,
) -> Result<sync_events::v3::Response> { ) -> Result<sync_events::v3::Response> {
let (sender_user, sender_device) = body.sender(); let sender_user = body.sender_user();
let sender_device = body.sender_device.as_deref();
let filter: OptionFuture<_> = body let filter: OptionFuture<_> = body
.body .body
@@ -195,7 +197,8 @@ pub(crate) async fn sync_events_route(
if since < next_batch || full_state { if since < next_batch || full_state {
let response = build_sync_events( let response = build_sync_events(
&services, &services,
body.sender(), sender_user,
sender_device,
since, since,
next_batch, next_batch,
full_state, full_state,
@@ -216,7 +219,8 @@ pub(crate) async fn sync_events_route(
// Wait for activity // Wait for activity
if time::timeout_at(stop_at, watchers).await.is_err() || services.server.is_stopping() { if time::timeout_at(stop_at, watchers).await.is_err() || services.server.is_stopping() {
let response = build_empty_response(&services, body.sender(), next_batch).await; let response =
build_empty_response(&services, sender_user, sender_device, next_batch).await;
trace!(since, next_batch, "empty response"); trace!(since, next_batch, "empty response");
return Ok(response); return Ok(response);
} }
@@ -235,14 +239,22 @@ pub(crate) async fn sync_events_route(
async fn build_empty_response( async fn build_empty_response(
services: &Services, services: &Services,
(sender_user, sender_device): (&UserId, &DeviceId), sender_user: &UserId,
sender_device: Option<&DeviceId>,
next_batch: u64, next_batch: u64,
) -> sync_events::v3::Response { ) -> sync_events::v3::Response {
let device_one_time_keys_count: OptionFuture<_> = sender_device
.map(|sender_device| {
services
.users
.count_one_time_keys(sender_user, sender_device)
})
.into();
sync_events::v3::Response { sync_events::v3::Response {
device_one_time_keys_count: services device_one_time_keys_count: device_one_time_keys_count
.users .await
.count_one_time_keys(sender_user, sender_device) .unwrap_or_default(),
.await,
..sync_events::v3::Response::new(to_small_string(next_batch)) ..sync_events::v3::Response::new(to_small_string(next_batch))
} }
@@ -261,7 +273,8 @@ async fn build_empty_response(
)] )]
async fn build_sync_events( async fn build_sync_events(
services: &Services, services: &Services,
(sender_user, sender_device): (&UserId, &DeviceId), sender_user: &UserId,
sender_device: Option<&DeviceId>,
since: u64, since: u64,
next_batch: u64, next_batch: u64,
full_state: bool, full_state: bool,
@@ -387,27 +400,38 @@ async fn build_sync_events(
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let to_device_events = services let to_device_events: OptionFuture<_> = sender_device
.users .map(|sender_device| {
.get_to_device_events(sender_user, sender_device, Some(since), Some(next_batch)) services
.map(at!(1)) .users
.collect::<Vec<_>>(); .get_to_device_events(sender_user, sender_device, Some(since), Some(next_batch))
.map(at!(1))
.collect::<Vec<_>>()
})
.into();
let device_one_time_keys_count = services let device_one_time_keys_count: OptionFuture<_> = sender_device
.users .map(|sender_device| {
.count_one_time_keys(sender_user, sender_device); services
.users
.count_one_time_keys(sender_user, sender_device)
})
.into();
// Remove all to-device events the device received *last time* // Remove all to-device events the device received *last time*
let remove_to_device_events = let remove_to_device_events: OptionFuture<_> = sender_device
services .map(|sender_device| {
.users services
.remove_to_device_events(sender_user, sender_device, since); .users
.remove_to_device_events(sender_user, sender_device, since)
})
.into();
let ( let (
account_data, account_data,
keys_changed, keys_changed,
device_one_time_keys_count, presence_updates,
((), to_device_events, presence_updates), (_, to_device_events, device_one_time_keys_count),
( (
(joined_rooms, mut device_list_updates, left_encrypted_users), (joined_rooms, mut device_list_updates, left_encrypted_users),
left_rooms, left_rooms,
@@ -417,8 +441,8 @@ async fn build_sync_events(
) = join5( ) = join5(
account_data, account_data,
keys_changed, keys_changed,
device_one_time_keys_count, presence_updates,
join3(remove_to_device_events, to_device_events, presence_updates), join3(remove_to_device_events, to_device_events, device_one_time_keys_count),
join4(joined_rooms, left_rooms, invited_rooms, knocked_rooms), join4(joined_rooms, left_rooms, invited_rooms, knocked_rooms),
) )
.boxed() .boxed()
@@ -454,7 +478,7 @@ async fn build_sync_events(
left: device_list_left, left: device_list_left,
changed: device_list_updates.into_iter().collect(), changed: device_list_updates.into_iter().collect(),
}, },
device_one_time_keys_count, device_one_time_keys_count: device_one_time_keys_count.unwrap_or_default(),
// Fallback keys are not yet supported // Fallback keys are not yet supported
device_unused_fallback_key_types: None, device_unused_fallback_key_types: None,
next_batch: to_small_string(next_batch), next_batch: to_small_string(next_batch),
@@ -465,7 +489,9 @@ async fn build_sync_events(
invite: invited_rooms, invite: invited_rooms,
knock: knocked_rooms, knock: knocked_rooms,
}, },
to_device: ToDevice { events: to_device_events }, to_device: ToDevice {
events: to_device_events.unwrap_or_default(),
},
}) })
} }
@@ -732,7 +758,7 @@ async fn load_left_room(
async fn load_joined_room( async fn load_joined_room(
services: &Services, services: &Services,
sender_user: &UserId, sender_user: &UserId,
sender_device: &DeviceId, sender_device: Option<&DeviceId>,
ref room_id: OwnedRoomId, ref room_id: OwnedRoomId,
since: u64, since: u64,
next_batch: u64, next_batch: u64,
@@ -841,7 +867,7 @@ async fn load_joined_room(
let lazy_loading_context = &lazy_loading::Context { let lazy_loading_context = &lazy_loading::Context {
user_id: sender_user, user_id: sender_user,
device_id: Some(sender_device), device_id: sender_device,
room_id, room_id,
token: Some(since), token: Some(since),
options: Some(&filter.room.state.lazy_load_options), options: Some(&filter.room.state.lazy_load_options),

View File

@@ -37,7 +37,7 @@ use crate::Ruma;
struct SyncInfo<'a> { struct SyncInfo<'a> {
services: &'a Services, services: &'a Services,
sender_user: &'a UserId, sender_user: &'a UserId,
sender_device: &'a DeviceId, sender_device: Option<&'a DeviceId>,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -69,7 +69,7 @@ type ListIds = SmallVec<[ListId; 1]>;
skip_all, skip_all,
fields( fields(
user_id = %body.sender_user().localpart(), user_id = %body.sender_user().localpart(),
device_id = %body.sender_device(), device_id = %body.sender_device.as_deref().map_or("<no device>", |x| x.as_str()),
conn_id = ?body.body.conn_id.clone().unwrap_or_default(), conn_id = ?body.body.conn_id.clone().unwrap_or_default(),
since = ?body.body.pos.clone().unwrap_or_default(), since = ?body.body.pos.clone().unwrap_or_default(),
) )
@@ -78,7 +78,8 @@ pub(crate) async fn sync_events_v5_route(
State(ref services): State<crate::State>, State(ref services): State<crate::State>,
body: Ruma<Request>, body: Ruma<Request>,
) -> Result<Response> { ) -> Result<Response> {
let (sender_user, sender_device) = body.sender(); let sender_user = body.sender_user();
let sender_device = body.sender_device.as_deref();
let request = &body.body; let request = &body.body;
let since = request let since = request
.pos .pos
@@ -108,7 +109,7 @@ pub(crate) async fn sync_events_v5_route(
let conn = conn_val.lock(); let conn = conn_val.lock();
let ping_presence = services let ping_presence = services
.presence .presence
.maybe_ping_presence(sender_user, Some(sender_device), &request.set_presence) .maybe_ping_presence(sender_user, sender_device, &request.set_presence)
.inspect_err(inspect_log) .inspect_err(inspect_log)
.ok(); .ok();

View File

@@ -32,6 +32,9 @@ pub(super) async fn collect(
conn: &Connection, conn: &Connection,
) -> Result<response::E2EE> { ) -> Result<response::E2EE> {
let SyncInfo { services, sender_user, sender_device, .. } = sync_info; let SyncInfo { services, sender_user, sender_device, .. } = sync_info;
let Some(sender_device) = sender_device else {
return Ok(response::E2EE::default());
};
let keys_changed = services let keys_changed = services
.users .users

View File

@@ -9,6 +9,10 @@ pub(super) async fn collect(
SyncInfo { services, sender_user, sender_device, .. }: SyncInfo<'_>, SyncInfo { services, sender_user, sender_device, .. }: SyncInfo<'_>,
conn: &Connection, conn: &Connection,
) -> Result<Option<response::ToDevice>> { ) -> Result<Option<response::ToDevice>> {
let Some(sender_device) = sender_device else {
return Ok(None);
};
services services
.users .users
.remove_to_device_events(sender_user, sender_device, conn.globalsince) .remove_to_device_events(sender_user, sender_device, conn.globalsince)

View File

@@ -40,11 +40,6 @@ pub(crate) struct Args<T> {
} }
impl<T> Args<T> { impl<T> Args<T> {
#[inline]
pub(crate) fn sender(&self) -> (&UserId, &DeviceId) {
(self.sender_user(), self.sender_device())
}
#[inline] #[inline]
pub(crate) fn sender_user(&self) -> &UserId { pub(crate) fn sender_user(&self) -> &UserId {
self.sender_user self.sender_user
@@ -52,19 +47,19 @@ impl<T> Args<T> {
.expect("user must be authenticated for this handler") .expect("user must be authenticated for this handler")
} }
#[inline]
pub(crate) fn sender_device(&self) -> &DeviceId {
self.sender_device
.as_deref()
.expect("user must be authenticated and device identified")
}
#[inline] #[inline]
pub(crate) fn origin(&self) -> &ServerName { pub(crate) fn origin(&self) -> &ServerName {
self.origin self.origin
.as_deref() .as_deref()
.expect("server must be authenticated for this handler") .expect("server must be authenticated for this handler")
} }
#[inline]
pub(crate) fn sender_device(&self) -> Result<&DeviceId> {
self.sender_device
.as_deref()
.ok_or(err!(Request(Forbidden("user must be authenticated and device identified"))))
}
} }
impl<T> Deref for Args<T> impl<T> Deref for Args<T>

View File

@@ -14,6 +14,8 @@ pub(crate) async fn auth_uiaa<T>(services: &Services, body: &Ruma<T>) -> Result<
where where
T: IncomingRequest + Send + Sync, T: IncomingRequest + Send + Sync,
{ {
let sender_device = body.sender_device()?;
let flows = [ let flows = [
AuthFlow::new([AuthType::Password].into()), AuthFlow::new([AuthType::Password].into()),
AuthFlow::new([AuthType::Jwt].into()), AuthFlow::new([AuthType::Jwt].into()),
@@ -51,7 +53,7 @@ where
let (worked, uiaainfo) = services let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, body.sender_device(), auth, &uiaainfo) .try_auth(sender_user, sender_device, auth, &uiaainfo)
.await?; .await?;
if !worked { if !worked {
@@ -71,7 +73,7 @@ where
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services services
.uiaa .uiaa
.create(sender_user, body.sender_device(), &uiaainfo, json); .create(sender_user, sender_device, &uiaainfo, json);
Err(Error::Uiaa(uiaainfo)) Err(Error::Uiaa(uiaainfo))
}, },

View File

@@ -59,7 +59,7 @@ pub struct Room {
type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionVal>>; type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
pub type ConnectionVal = Arc<TokioMutex<Connection>>; pub type ConnectionVal = Arc<TokioMutex<Connection>>;
pub type ConnectionKey = (OwnedUserId, OwnedDeviceId, Option<ConnectionId>); pub type ConnectionKey = (OwnedUserId, Option<OwnedDeviceId>, Option<ConnectionId>);
pub type Subscriptions = BTreeMap<OwnedRoomId, request::ListConfig>; pub type Subscriptions = BTreeMap<OwnedRoomId, request::ListConfig>;
pub type Lists = BTreeMap<ListId, request::List>; pub type Lists = BTreeMap<ListId, request::List>;
@@ -107,7 +107,7 @@ pub async fn clear_connections(
.await .await
.retain(|(conn_user_id, conn_device_id, conn_conn_id), _| { .retain(|(conn_user_id, conn_device_id, conn_conn_id), _| {
let retain = user_id.is_none_or(is_equal_to!(conn_user_id)) let retain = user_id.is_none_or(is_equal_to!(conn_user_id))
&& device_id.is_none_or(is_equal_to!(conn_device_id)) && (device_id.is_none() || device_id == conn_device_id.as_deref())
&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref()); && (conn_id.is_none() || conn_id == conn_conn_id.as_ref());
if !retain { if !retain {
@@ -215,13 +215,17 @@ pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool {
} }
#[inline] #[inline]
pub fn into_connection_key<U, D, C>(user_id: U, device_id: D, conn_id: Option<C>) -> ConnectionKey pub fn into_connection_key<U, D, C>(
user_id: U,
device_id: Option<D>,
conn_id: Option<C>,
) -> ConnectionKey
where where
U: Into<OwnedUserId>, U: Into<OwnedUserId>,
D: Into<OwnedDeviceId>, D: Into<OwnedDeviceId>,
C: Into<ConnectionId>, C: Into<ConnectionId>,
{ {
(user_id.into(), device_id.into(), conn_id.map(Into::into)) (user_id.into(), device_id.map(Into::into), conn_id.map(Into::into))
} }
#[implement(Connection)] #[implement(Connection)]

View File

@@ -8,25 +8,18 @@ use tuwunel_database::{Interfix, Separator, serialize_key};
pub async fn watch<'a, Rooms>( pub async fn watch<'a, Rooms>(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: Option<&DeviceId>,
rooms: Rooms, rooms: Rooms,
) -> Result ) -> Result
where where
Rooms: Stream<Item = &'a RoomId> + Send + 'a, Rooms: Stream<Item = &'a RoomId> + Send + 'a,
{ {
let userdeviceid_prefix = (user_id, device_id, Interfix);
let globaluserdata_prefix = (Separator, user_id, Interfix); let globaluserdata_prefix = (Separator, user_id, Interfix);
let roomuserdataid_prefix = (Option::<&RoomId>::None, user_id, Interfix); let roomuserdataid_prefix = (Option::<&RoomId>::None, user_id, Interfix);
let userid_prefix = let userid_prefix =
serialize_key((user_id, Interfix)).expect("failed to serialize watch prefix"); serialize_key((user_id, Interfix)).expect("failed to serialize watch prefix");
let watchers = [ let watchers = [
// Return when *any* user changed their key
// TODO: only send for user they share a room with
self.db
.todeviceid_events
.watch_prefix(&userdeviceid_prefix)
.boxed(),
self.db self.db
.userroomid_joined .userroomid_joined
.watch_raw_prefix(&userid_prefix) .watch_raw_prefix(&userid_prefix)
@@ -72,8 +65,20 @@ where
.boxed(), .boxed(),
]; ];
let mut futures = FuturesUnordered::new(); let device_watchers = device_id.into_iter().map(|device_id| {
futures.extend(watchers.into_iter()); // Return when *any* user changed their key
// TODO: only send for user they share a room with
let userdeviceid_prefix = (user_id, device_id, Interfix);
self.db
.todeviceid_events
.watch_prefix(&userdeviceid_prefix)
.boxed()
});
let mut futures: FuturesUnordered<_> = watchers
.into_iter()
.chain(device_watchers)
.collect();
pin_mut!(rooms); pin_mut!(rooms);
while let Some(room_id) = rooms.next().await { while let Some(room_id) = rooms.next().await {