From ac89116316e02ecb415265c6ef8d52bee2c680ad Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 4 Feb 2026 08:16:29 +0000 Subject: [PATCH] Optimize several container types in state res. Optimize mainline sort. Signed-off-by: Jason Volk --- src/core/matrix/state_res/resolve.rs | 51 ++++---- .../state_res/resolve/conflicted_subgraph.rs | 31 +++-- .../state_res/resolve/iterative_auth_check.rs | 5 +- .../matrix/state_res/resolve/mainline_sort.rs | 118 +++++++++++------- .../matrix/state_res/resolve/power_sort.rs | 6 +- .../state_res/resolve/split_conflicted.rs | 6 +- src/core/matrix/state_res/resolve/tests.rs | 21 ++-- 7 files changed, 141 insertions(+), 97 deletions(-) diff --git a/src/core/matrix/state_res/resolve.rs b/src/core/matrix/state_res/resolve.rs index a7003653..931ef912 100644 --- a/src/core/matrix/state_res/resolve.rs +++ b/src/core/matrix/state_res/resolve.rs @@ -9,9 +9,10 @@ mod power_sort; mod split_conflicted; mod topological_sort; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeMap, BTreeSet, HashSet}; use futures::{FutureExt, Stream, StreamExt, TryFutureExt}; +use itertools::Itertools; use ruma::{OwnedEventId, events::StateEventType, room_version_rules::RoomVersionRules}; pub use self::topological_sort::topological_sort; @@ -25,6 +26,7 @@ use super::test_utils; use crate::{ Result, debug, matrix::{Event, TypeStateKey}, + smallvec::SmallVec, trace, utils::{ BoolExt, @@ -33,9 +35,6 @@ use crate::{ }, }; -/// ConflictMap of OwnedEventId specifically. -pub type ConflictMap = StateMap; - /// A mapping of event type and state_key to some value `T`, usually an /// `EventId`. pub type StateMap = BTreeMap; @@ -43,8 +42,11 @@ pub type StateMap = BTreeMap; /// Full recursive set of `auth_events` for each event in a StateMap. pub type AuthSet = BTreeSet; +/// ConflictMap of OwnedEventId specifically. +pub type ConflictMap = StateMap>; + /// List of conflicting event_ids -type ConflictVec = Vec; +type ConflictVec = SmallVec<[Id; 2]>; /// Apply the [state resolution] algorithm introduced in room version 2 to /// resolve the state of a room. @@ -116,24 +118,25 @@ where .is_some_and(|rules| rules.consider_conflicted_state_subgraph) || backport_css; + let conflicted_states = conflicted_states.values().flatten().cloned(); + // Since `org.matrix.hydra.11`, fetch the conflicted state subgraph. let conflicted_subgraph = consider_conflicted_subgraph - .then(|| conflicted_states.clone().into_values().flatten()) - .map_async(async |ids| conflicted_subgraph_dfs(ids.stream(), fetch)); - - let conflicted_subgraph = conflicted_subgraph - .await - .into_iter() - .stream() - .flatten(); + .then(|| conflicted_states.clone().stream()) + .map_async(async |ids| conflicted_subgraph_dfs(ids, fetch)) + .map(Option::into_iter) + .map(IterStream::stream) + .flatten_stream() + .flatten() + .boxed(); // 0. The full conflicted set is the union of the conflicted state set and the // auth difference. Don't honor events that don't exist. - let full_conflicted_set: AuthSet<_> = auth_difference(auth_sets) - .chain(conflicted_states.into_values().flatten().stream()) + let full_conflicted_set = auth_difference(auth_sets) + .chain(conflicted_states.stream()) .broad_filter_map(async |id| exists(id.clone()).await.then_some(id)) .chain(conflicted_subgraph) - .collect::>() + .collect::>() .inspect(|set| debug!(count = set.len(), "full conflicted set")) .inspect(|set| trace!(?set, "full conflicted set")) .await; @@ -142,14 +145,15 @@ where // set. For each such power event P, enlarge X by adding the events in the // auth chain of P which also belong to the full conflicted set. Sort X into // a list using the reverse topological power ordering. - let sorted_power_events: Vec<_> = power_sort(rules, &full_conflicted_set, fetch) + let sorted_power_set: Vec<_> = power_sort(rules, &full_conflicted_set, fetch) .inspect_ok(|list| debug!(count = list.len(), "sorted power events")) .inspect_ok(|list| trace!(?list, "sorted power events")) + .boxed() .await?; - let sorted_power_events_set: AuthSet<_> = sorted_power_events.iter().collect(); + let power_set_event_ids: Vec<_> = sorted_power_set.iter().sorted().collect(); - let sorted_power_events = sorted_power_events + let sorted_power_set = sorted_power_set .iter() .stream() .map(AsRef::as_ref); @@ -167,9 +171,10 @@ where // state map, to the list of events from the previous step to get a partially // resolved state. let partially_resolved_state = - iterative_auth_check(rules, sorted_power_events, initial_state, fetch) + iterative_auth_check(rules, sorted_power_set, initial_state, fetch) .inspect_ok(|map| debug!(count = map.len(), "partially resolved power state")) .inspect_ok(|map| trace!(?map, "partially resolved power state")) + .boxed() .await?; // This "epochs" power level event @@ -179,7 +184,7 @@ where let remaining_events: Vec<_> = full_conflicted_set .into_iter() - .filter(|id| !sorted_power_events_set.contains(id)) + .filter(|id| power_set_event_ids.binary_search(&id).is_err()) .collect(); debug!(count = remaining_events.len(), "remaining events"); @@ -195,7 +200,8 @@ where // the mainline ordering based on the power level in the partially resolved // state obtained in step 2. let sorted_remaining_events = have_remaining_events - .then_async(move || mainline_sort(power_event.cloned(), remaining_events, fetch)); + .then_async(move || mainline_sort(power_event.cloned(), remaining_events, fetch)) + .boxed(); let sorted_remaining_events = sorted_remaining_events .await @@ -213,6 +219,7 @@ where // and the list of events from the previous step. let mut resolved_state = iterative_auth_check(rules, sorted_remaining_events, partially_resolved_state, fetch) + .boxed() .await?; // 5. Update the result by replacing any event with the event with the same key diff --git a/src/core/matrix/state_res/resolve/conflicted_subgraph.rs b/src/core/matrix/state_res/resolve/conflicted_subgraph.rs index 18649d00..237d9fae 100644 --- a/src/core/matrix/state_res/resolve/conflicted_subgraph.rs +++ b/src/core/matrix/state_res/resolve/conflicted_subgraph.rs @@ -1,5 +1,6 @@ use std::{ collections::HashSet as Set, + iter::once, mem::take, sync::{Arc, Mutex}, }; @@ -9,7 +10,8 @@ use ruma::OwnedEventId; use crate::{ Result, debug, - matrix::Event, + matrix::{Event, pdu::AuthEvents}, + smallvec::SmallVec, utils::stream::{IterStream, automatic_width}, }; @@ -19,12 +21,19 @@ struct Global { seen: Mutex>, } -#[derive(Default)] +#[derive(Default, Debug)] struct Local { - path: Vec, - stack: Vec>, + path: Path, + stack: Stack, } +type Path = SmallVec<[OwnedEventId; PATH_INLINE]>; +type Stack = SmallVec<[Frame; STACK_INLINE]>; +type Frame = AuthEvents; + +const PATH_INLINE: usize = 48; +const STACK_INLINE: usize = 48; + #[tracing::instrument(name = "conflicted_subgraph", level = "debug", skip_all)] pub(super) fn conflicted_subgraph_dfs( conflicted_event_ids: ConflictedEventIds, @@ -52,9 +61,11 @@ where }) .await; + let seen = state.seen.lock().expect("locked"); let mut state = state.subgraph.lock().expect("locked"); debug!( - input_event = conflicted_event_ids.len(), + input_events = conflicted_event_ids.len(), + seen_events = seen.len(), output_events = state.len(), "conflicted subgraph state" ); @@ -71,8 +82,8 @@ where level = "trace", skip_all, fields( - event_id = %conflicted_event_id, event_ids = conflicted_event_ids.len(), + event_id = %conflicted_event_id, ) )] async fn subgraph_descent( @@ -89,8 +100,8 @@ where let Global { subgraph, seen } = &*state; let mut local = Local { - path: vec![conflicted_event_id.clone()], - stack: vec![vec![conflicted_event_id]], + path: once(conflicted_event_id.clone()).collect(), + stack: once(once(conflicted_event_id).collect()).collect(), }; while let Some(event_id) = pop(&mut local) { @@ -128,13 +139,13 @@ where fn pop(local: &mut Local) -> Option { let Local { path, stack } = local; - while stack.last().is_some_and(Vec::is_empty) { + while stack.last().is_some_and(Frame::is_empty) { stack.pop(); path.pop(); } stack .last_mut() - .and_then(Vec::pop) + .and_then(Frame::pop) .inspect(|event_id| path.push(event_id.clone())) } diff --git a/src/core/matrix/state_res/resolve/iterative_auth_check.rs b/src/core/matrix/state_res/resolve/iterative_auth_check.rs index ab4a668a..2cf842ec 100644 --- a/src/core/matrix/state_res/resolve/iterative_auth_check.rs +++ b/src/core/matrix/state_res/resolve/iterative_auth_check.rs @@ -12,6 +12,7 @@ use super::{ use crate::{ Error, Result, debug_warn, err, error, matrix::{Event, EventTypeExt, StateKey}, + smallvec::SmallVec, trace, utils::stream::{IterStream, ReadyExt, TryReadyExt, TryWidebandExt}, }; @@ -170,10 +171,10 @@ where Ok(key_val) }); - let auth_events: Vec<_> = auth_events + let auth_events = auth_events .chain(auth_types_events) .try_collect() - .map_ok(|mut vec: Vec<_>| { + .map_ok(|mut vec: SmallVec<[_; 4]>| { vec.sort_by(|a, b| a.0.cmp(&b.0)); vec.reverse(); vec.dedup_by(|a, b| a.0.eq(&b.0)); diff --git a/src/core/matrix/state_res/resolve/mainline_sort.rs b/src/core/matrix/state_res/resolve/mainline_sort.rs index 03561700..b5347bbe 100644 --- a/src/core/matrix/state_res/resolve/mainline_sort.rs +++ b/src/core/matrix/state_res/resolve/mainline_sort.rs @@ -1,13 +1,13 @@ -use std::collections::HashMap; - -use futures::{Stream, StreamExt, TryStreamExt, pin_mut}; +use futures::{ + FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, stream::try_unfold, +}; use ruma::{EventId, OwnedEventId, events::TimelineEventType}; use crate::{ - Result, + Error, Result, at, is_equal_to, matrix::Event, trace, - utils::stream::{IterStream, TryReadyExt, WidebandExt}, + utils::stream::{BroadbandExt, IterStream, TryReadyExt}, }; /// Perform mainline ordering of the given events. @@ -39,11 +39,14 @@ use crate::{ level = "debug", skip_all, fields( - ?power_level, + event_id = power_level_event_id + .as_deref() + .map(EventId::as_str) + .unwrap_or_default(), ) )] pub(super) async fn mainline_sort<'a, RemainingEvents, Fetch, Fut, Pdu>( - mut power_level: Option, + power_level_event_id: Option, events: RemainingEvents, fetch: &Fetch, ) -> Result> @@ -54,49 +57,60 @@ where Pdu: Event, { // Populate the mainline of the power level. - let mut mainline = vec![]; - while let Some(power_level_event_id) = power_level { + let mainline: Vec<_> = try_unfold(power_level_event_id, async |power_level_event_id| { + let Some(power_level_event_id) = power_level_event_id else { + return Ok::<_, Error>(None); + }; + let power_level_event = fetch(power_level_event_id).await?; + let this_event_id = power_level_event.event_id().to_owned(); + let next_event_id = get_power_levels_auth_event(&power_level_event, fetch) + .map_ok(|event| { + event + .as_ref() + .map(Event::event_id) + .map(ToOwned::to_owned) + }) + .await?; - mainline.push(power_level_event.event_id().to_owned()); - power_level = get_power_levels_auth_event(&power_level_event, fetch) - .await? - .map(|event| event.event_id().to_owned()); - } + trace!(?this_event_id, ?next_event_id, "mainline descent",); - let mainline_map: HashMap<_, _> = mainline - .iter() - .rev() - .enumerate() - .map(|(idx, event_id)| (event_id.clone(), idx)) - .collect(); + Ok(Some((this_event_id, next_event_id))) + }) + .try_collect() + .await?; - let order_map: HashMap<_, _> = events - .wide_filter_map(async |event_id| { - let event = fetch(event_id.to_owned()).await.ok()?; - let position = mainline_position(&event, &mainline_map, fetch) + let mainline = mainline.iter().rev().map(AsRef::as_ref); + + events + .map(ToOwned::to_owned) + .broad_filter_map(async |event_id| { + let event = fetch(event_id.clone()).await.ok()?; + let origin_server_ts = event.origin_server_ts(); + let position = mainline_position(Some(event), &mainline, fetch) .await .ok()?; - let event_id = event.event_id().to_owned(); - let origin_server_ts = event.origin_server_ts(); Some((event_id, (position, origin_server_ts))) }) + .inspect(|(event_id, (position, origin_server_ts))| { + trace!(position, ?origin_server_ts, ?event_id, "mainline position"); + }) .collect() - .await; + .map(|mut vec: Vec<_>| { + vec.sort_by(|a, b| { + let (a_pos, a_ots) = &a.1; + let (b_pos, b_ots) = &b.1; + a_pos + .cmp(b_pos) + .then(a_ots.cmp(b_ots)) + .then(a.cmp(b)) + }); - let mut sorted_event_ids: Vec<_> = order_map.keys().cloned().collect(); - - sorted_event_ids.sort_by(|a, b| { - let (a_pos, a_ots) = &order_map[a]; - let (b_pos, b_ots) = &order_map[b]; - a_pos - .cmp(b_pos) - .then(a_ots.cmp(b_ots)) - .then(a.cmp(b)) - }); - - Ok(sorted_event_ids) + vec.into_iter().map(at!(0)).collect() + }) + .map(Ok) + .await } /// Get the mainline position of the given event from the given mainline map. @@ -112,30 +126,38 @@ where /// Returns the mainline position of the event, or an `Err(_)` if one of the /// events in the auth chain of the event was not found. #[tracing::instrument( + name = "position", level = "trace", + ret(level = "trace"), skip_all, fields( - event = ?event.event_id(), - mainline = mainline_map.len(), + mainline = mainline.clone().count(), + event = ?current_event.as_ref().map(Event::event_id).map(ToOwned::to_owned), ) )] -async fn mainline_position( - event: &Pdu, - mainline_map: &HashMap, +async fn mainline_position<'a, Mainline, Fetch, Fut, Pdu>( + mut current_event: Option, + mainline: &Mainline, fetch: &Fetch, ) -> Result where + Mainline: Iterator + Clone + Send + Sync, Fetch: Fn(OwnedEventId) -> Fut + Sync, Fut: Future> + Send, Pdu: Event, { - let mut current_event = Some(event.clone()); while let Some(event) = current_event { - trace!(event_id = ?event.event_id(), "mainline"); + trace!( + event_id = ?event.event_id(), + "mainline position search", + ); // If the current event is in the mainline map, return its position. - if let Some(position) = mainline_map.get(event.event_id()) { - return Ok(*position); + if let Some(position) = mainline + .clone() + .position(is_equal_to!(event.event_id())) + { + return Ok(position); } // Look for the power levels event in the auth events. diff --git a/src/core/matrix/state_res/resolve/power_sort.rs b/src/core/matrix/state_res/resolve/power_sort.rs index b192cc00..c95eb4c2 100644 --- a/src/core/matrix/state_res/resolve/power_sort.rs +++ b/src/core/matrix/state_res/resolve/power_sort.rs @@ -15,7 +15,7 @@ use super::{ RoomCreateEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField, is_power_event, power_levels::RoomPowerLevelsEventOptionExt, }, - AuthSet, topological_sort, + topological_sort, }; use crate::{ Result, err, @@ -50,7 +50,7 @@ use crate::{ )] pub(super) async fn power_sort( rules: &RoomVersionRules, - full_conflicted_set: &AuthSet, + full_conflicted_set: &HashSet, fetch: &Fetch, ) -> Result> where @@ -108,7 +108,7 @@ where )] async fn add_event_auth_chain( mut graph: HashMap>, - full_conflicted_set: &AuthSet, + full_conflicted_set: &HashSet, event_id: &EventId, fetch: &Fetch, ) -> HashMap> diff --git a/src/core/matrix/state_res/resolve/split_conflicted.rs b/src/core/matrix/state_res/resolve/split_conflicted.rs index 9313a213..117bb407 100644 --- a/src/core/matrix/state_res/resolve/split_conflicted.rs +++ b/src/core/matrix/state_res/resolve/split_conflicted.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, hash::Hash}; use futures::{Stream, StreamExt}; -use super::StateMap; +use super::{ConflictMap, StateMap}; use crate::validated; /// Split the unconflicted state map and the conflicted state set. @@ -28,7 +28,7 @@ use crate::validated; #[tracing::instrument(name = "split", level = "debug", skip_all)] pub(super) async fn split_conflicted_state<'a, Maps, Id>( state_maps: Maps, -) -> (StateMap, StateMap>) +) -> (StateMap, ConflictMap) where Maps: Stream>, Id: Clone + Eq + Hash + Ord + Send + Sync + 'a, @@ -52,7 +52,7 @@ where } let mut unconflicted_state_map = StateMap::new(); - let mut conflicted_state_set = StateMap::>::new(); + let mut conflicted_state_set = ConflictMap::new(); for (k, v) in occurrences { for (id, occurrence_count) in v { diff --git a/src/core/matrix/state_res/resolve/tests.rs b/src/core/matrix/state_res/resolve/tests.rs index 88b32769..e2f53705 100644 --- a/src/core/matrix/state_res/resolve/tests.rs +++ b/src/core/matrix/state_res/resolve/tests.rs @@ -1,4 +1,7 @@ -use std::collections::HashMap; +use std::{ + collections::{HashMap, HashSet}, + iter::once, +}; use futures::StreamExt; use maplit::{hashmap, hashset}; @@ -16,7 +19,7 @@ use ruma::{ use serde_json::{json, value::to_raw_value as to_raw_json_value}; use super::{ - AuthSet, StateMap, + StateMap, test_utils::{ INITIAL_EVENTS, TestStore, alice, bob, charlie, do_check, ella, event_id, member_content_ban, member_content_join, not_found, room_id, to_init_pdu_event, @@ -39,7 +42,7 @@ async fn test_event_sort() { let rules = RoomVersionRules::V6; let events = INITIAL_EVENTS(); - let auth_chain: AuthSet = AuthSet::new(); + let auth_chain: HashSet = HashSet::new(); let sorted_power_events = super::power_sort(&rules, &auth_chain, &async |id| { events.get(&id).cloned().ok_or_else(not_found) @@ -752,9 +755,9 @@ async fn split_conflicted_state_set_conflicted_unique_state_keys() { assert_eq!(unconflicted, StateMap::new()); assert_eq!(conflicted, state_set![ - StateEventType::RoomMember => "@a:hs1" => vec![0], - StateEventType::RoomMember => "@b:hs1" => vec![1], - StateEventType::RoomMember => "@c:hs1" => vec![2], + StateEventType::RoomMember => "@a:hs1" => once(0).collect(), + StateEventType::RoomMember => "@b:hs1" => once(1).collect(), + StateEventType::RoomMember => "@c:hs1" => once(2).collect(), ],); } @@ -781,7 +784,7 @@ async fn split_conflicted_state_set_conflicted_same_state_key() { assert_eq!(unconflicted, StateMap::new()); assert_eq!(conflicted, state_set![ - StateEventType::RoomMember => "@a:hs1" => vec![0, 1, 2], + StateEventType::RoomMember => "@a:hs1" => [0, 1, 2].into_iter().collect(), ],); } @@ -833,7 +836,7 @@ async fn split_conflicted_state_set_mixed() { StateEventType::RoomMember => "@a:hs1" => 0, ],); assert_eq!(conflicted, state_set![ - StateEventType::RoomMember => "@b:hs1" => vec![1], - StateEventType::RoomMember => "@c:hs1" => vec![2], + StateEventType::RoomMember => "@b:hs1" => once(1).collect(), + StateEventType::RoomMember => "@c:hs1" => once(2).collect(), ],); }