Unify calculate_state_changes in syncv3

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-06-29 01:21:21 +00:00
parent 3c47516c85
commit c9adee86f5

View File

@@ -7,7 +7,7 @@ use std::{
use axum::extract::State;
use futures::{
FutureExt, StreamExt, TryFutureExt, TryStreamExt,
future::{OptionFuture, join, join3, join4, join5, try_join, try_join4},
future::{OptionFuture, join, join3, join4, join5, try_join4},
pin_mut,
};
use ruma::{
@@ -917,124 +917,22 @@ async fn load_joined_room(
)
)]
#[allow(clippy::too_many_arguments)]
async fn calculate_state_changes(
async fn calculate_state_changes<'a>(
services: &Services,
sender_user: &UserId,
room_id: &RoomId,
full_state: bool,
filter: &FilterDefinition,
since_shortstatehash: Option<ShortStateHash>,
current_shortstatehash: ShortStateHash,
joined_since_last_sync: bool,
witness: Option<&Witness>,
) -> Result<StateChanges> {
if since_shortstatehash.is_none() {
calculate_state_initial(
services,
sender_user,
room_id,
full_state,
filter,
current_shortstatehash,
witness,
)
.await
} else {
calculate_state_incremental(
services,
sender_user,
room_id,
full_state,
filter,
since_shortstatehash,
current_shortstatehash,
joined_since_last_sync,
witness,
)
.await
}
}
#[tracing::instrument(name = "initial", level = "trace", skip_all)]
#[allow(clippy::too_many_arguments)]
async fn calculate_state_initial(
services: &Services,
sender_user: &UserId,
room_id: &RoomId,
full_state: bool,
_filter: &FilterDefinition,
current_shortstatehash: ShortStateHash,
witness: Option<&Witness>,
) -> Result<StateChanges> {
let (shortstatekeys, event_ids): (Vec<_>, Vec<_>) = services
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.unzip()
.await;
let state_events = services
.rooms
.short
.multi_get_statekey_from_short(shortstatekeys.into_iter().stream())
.zip(event_ids.into_iter().stream())
.ready_filter_map(|item| Some((item.0.ok()?, item.1)))
.ready_filter_map(|((event_type, state_key), event_id)| {
let lazy = !full_state
&& event_type == StateEventType::RoomMember
&& state_key
.as_str()
.try_into()
.is_ok_and(|user_id: &UserId| {
sender_user != user_id
&& witness.is_some_and(|witness| !witness.contains(user_id))
});
lazy.or_some(event_id)
})
.broad_filter_map(|event_id: OwnedEventId| async move {
services
.rooms
.timeline
.get_pdu(&event_id)
.await
.ok()
})
.collect()
.map(Ok);
let counts = calculate_counts(services, room_id, sender_user);
let ((joined_member_count, invited_member_count, heroes), state_events) =
try_join(counts, state_events).boxed().await?;
// The state_events above should contain all timeline_users, let's mark them as
// lazy loaded.
Ok(StateChanges {
heroes,
joined_member_count,
invited_member_count,
state_events,
..Default::default()
})
}
#[tracing::instrument(name = "incremental", level = "trace", skip_all)]
#[allow(clippy::too_many_arguments)]
async fn calculate_state_incremental<'a>(
services: &Services,
sender_user: &'a UserId,
room_id: &RoomId,
full_state: bool,
_filter: &FilterDefinition,
since_shortstatehash: Option<ShortStateHash>,
current_shortstatehash: ShortStateHash,
joined_since_last_sync: bool,
witness: Option<&'a Witness>,
) -> Result<StateChanges> {
let since_shortstatehash = since_shortstatehash.unwrap_or(current_shortstatehash);
let initial = full_state || since_shortstatehash.is_none() || joined_since_last_sync;
let state_changed = since_shortstatehash != current_shortstatehash;
let incremental = !initial && since_shortstatehash != Some(current_shortstatehash);
let since_shortstatehash = since_shortstatehash.unwrap_or(current_shortstatehash);
let encrypted_room = services
.rooms
@@ -1056,7 +954,7 @@ async fn calculate_state_incremental<'a>(
};
let lazy_state_ids: OptionFuture<_> = witness
.filter(|_| !full_state && !encrypted_room)
.filter(|_| !encrypted_room)
.map(|witness| {
StreamExt::into_future(
witness
@@ -1067,7 +965,7 @@ async fn calculate_state_incremental<'a>(
})
.into();
let state_diff_ids: OptionFuture<_> = (!full_state && state_changed)
let state_diff_ids: OptionFuture<_> = incremental
.then(|| {
StreamExt::into_future(
services
@@ -1079,7 +977,7 @@ async fn calculate_state_incremental<'a>(
})
.into();
let current_state_ids: OptionFuture<_> = full_state
let current_state_ids: OptionFuture<_> = initial
.then(|| {
StreamExt::into_future(
services
@@ -1094,7 +992,7 @@ async fn calculate_state_incremental<'a>(
let state_events = current_state_ids
.stream()
.chain(state_diff_ids.stream())
.broad_filter_map(|(shortstatekey, shorteventid)| async move {
.broad_filter_map(async |(shortstatekey, shorteventid)| {
if witness.is_none() || encrypted_room {
return Some(shorteventid);
}
@@ -1109,22 +1007,22 @@ async fn calculate_state_incremental<'a>(
.get_eventid_from_short(shorteventid)
.ok()
})
.broad_filter_map(|event_id: OwnedEventId| async move {
.broad_filter_map(async |event_id: OwnedEventId| {
services
.rooms
.timeline
.get_pdu(&event_id)
.await
.ok()
.await
})
.collect::<Vec<_>>()
.boxed()
.await;
let (device_list_updates, left_encrypted_users) = state_events
let device_updates = state_events
.iter()
.stream()
.ready_filter(|_| encrypted_room)
.ready_filter(|_| encrypted_room && !initial)
.ready_filter(|state_event| state_event.kind == RoomMember)
.ready_filter_map(|state_event| {
let content: RoomMemberEventContent = state_event.get_content().ok()?;
@@ -1132,7 +1030,7 @@ async fn calculate_state_incremental<'a>(
Some((content, user_id))
})
.fold_default(|(mut dlu, mut leu): pair_of!(HashSet<_>), (content, user_id)| async move {
.fold_default(async |(mut dlu, mut leu): pair_of!(HashSet<_>), (content, user_id)| {
use MembershipState::*;
let shares_encrypted_room = async |user_id| {
@@ -1147,18 +1045,22 @@ async fn calculate_state_incremental<'a>(
};
(dlu, leu)
})
.await;
});
let send_member_count = state_events
let send_member_counts = state_events
.iter()
.any(|event| event.kind == RoomMember);
let (joined_member_count, invited_member_count, heroes) = if send_member_count {
calculate_counts(services, room_id, sender_user).await?
} else {
(None, None, None)
};
let member_counts: OptionFuture<_> = send_member_counts
.then(|| calculate_counts(services, room_id, sender_user))
.into();
let (member_counts, device_updates) = join(member_counts, device_updates).await;
let (device_list_updates, left_encrypted_users) = device_updates;
let (joined_member_count, invited_member_count, heroes) =
member_counts.unwrap_or((None, None, None));
Ok(StateChanges {
heroes,
@@ -1191,7 +1093,7 @@ async fn calculate_counts(
services: &Services,
room_id: &RoomId,
sender_user: &UserId,
) -> Result<(Option<u64>, Option<u64>, Option<Vec<OwnedUserId>>)> {
) -> (Option<u64>, Option<u64>, Option<Vec<OwnedUserId>>) {
let joined_member_count = services
.rooms
.state_cache
@@ -1213,7 +1115,7 @@ async fn calculate_counts(
.then(|| calculate_heroes(services, room_id, sender_user))
.into();
Ok((Some(joined_member_count), Some(invited_member_count), heroes.await))
(Some(joined_member_count), Some(invited_member_count), heroes.await)
}
async fn calculate_heroes(