Optimize several container types in state res.

Optimize mainline sort.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-02-04 08:16:29 +00:00
parent ec1359ea5e
commit ac89116316
7 changed files with 141 additions and 97 deletions

View File

@@ -9,9 +9,10 @@ mod power_sort;
mod split_conflicted;
mod topological_sort;
use std::collections::{BTreeMap, BTreeSet};
use std::collections::{BTreeMap, BTreeSet, HashSet};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use itertools::Itertools;
use ruma::{OwnedEventId, events::StateEventType, room_version_rules::RoomVersionRules};
pub use self::topological_sort::topological_sort;
@@ -25,6 +26,7 @@ use super::test_utils;
use crate::{
Result, debug,
matrix::{Event, TypeStateKey},
smallvec::SmallVec,
trace,
utils::{
BoolExt,
@@ -33,9 +35,6 @@ use crate::{
},
};
/// ConflictMap of OwnedEventId specifically.
pub type ConflictMap = StateMap<ConflictVec>;
/// A mapping of event type and state_key to some value `T`, usually an
/// `EventId`.
pub type StateMap<Id> = BTreeMap<TypeStateKey, Id>;
@@ -43,8 +42,11 @@ pub type StateMap<Id> = BTreeMap<TypeStateKey, Id>;
/// Full recursive set of `auth_events` for each event in a StateMap.
pub type AuthSet<Id> = BTreeSet<Id>;
/// A `StateMap` whose values are the lists of conflicting ids for each key.
pub type ConflictMap<Id> = StateMap<ConflictVec<Id>>;
/// List of conflicting event_ids
type ConflictVec = Vec<OwnedEventId>;
type ConflictVec<Id> = SmallVec<[Id; 2]>;
/// Apply the [state resolution] algorithm introduced in room version 2 to
/// resolve the state of a room.
@@ -116,24 +118,25 @@ where
.is_some_and(|rules| rules.consider_conflicted_state_subgraph)
|| backport_css;
let conflicted_states = conflicted_states.values().flatten().cloned();
// Since `org.matrix.hydra.11`, fetch the conflicted state subgraph.
let conflicted_subgraph = consider_conflicted_subgraph
.then(|| conflicted_states.clone().into_values().flatten())
.map_async(async |ids| conflicted_subgraph_dfs(ids.stream(), fetch));
let conflicted_subgraph = conflicted_subgraph
.await
.into_iter()
.stream()
.flatten();
.then(|| conflicted_states.clone().stream())
.map_async(async |ids| conflicted_subgraph_dfs(ids, fetch))
.map(Option::into_iter)
.map(IterStream::stream)
.flatten_stream()
.flatten()
.boxed();
// 0. The full conflicted set is the union of the conflicted state set and the
// auth difference. Don't honor events that don't exist.
let full_conflicted_set: AuthSet<_> = auth_difference(auth_sets)
.chain(conflicted_states.into_values().flatten().stream())
let full_conflicted_set = auth_difference(auth_sets)
.chain(conflicted_states.stream())
.broad_filter_map(async |id| exists(id.clone()).await.then_some(id))
.chain(conflicted_subgraph)
.collect::<AuthSet<_>>()
.collect::<HashSet<_>>()
.inspect(|set| debug!(count = set.len(), "full conflicted set"))
.inspect(|set| trace!(?set, "full conflicted set"))
.await;
@@ -142,14 +145,15 @@ where
// set. For each such power event P, enlarge X by adding the events in the
// auth chain of P which also belong to the full conflicted set. Sort X into
// a list using the reverse topological power ordering.
let sorted_power_events: Vec<_> = power_sort(rules, &full_conflicted_set, fetch)
let sorted_power_set: Vec<_> = power_sort(rules, &full_conflicted_set, fetch)
.inspect_ok(|list| debug!(count = list.len(), "sorted power events"))
.inspect_ok(|list| trace!(?list, "sorted power events"))
.boxed()
.await?;
let sorted_power_events_set: AuthSet<_> = sorted_power_events.iter().collect();
let power_set_event_ids: Vec<_> = sorted_power_set.iter().sorted().collect();
let sorted_power_events = sorted_power_events
let sorted_power_set = sorted_power_set
.iter()
.stream()
.map(AsRef::as_ref);
@@ -167,9 +171,10 @@ where
// state map, to the list of events from the previous step to get a partially
// resolved state.
let partially_resolved_state =
iterative_auth_check(rules, sorted_power_events, initial_state, fetch)
iterative_auth_check(rules, sorted_power_set, initial_state, fetch)
.inspect_ok(|map| debug!(count = map.len(), "partially resolved power state"))
.inspect_ok(|map| trace!(?map, "partially resolved power state"))
.boxed()
.await?;
// This "epochs" power level event
@@ -179,7 +184,7 @@ where
let remaining_events: Vec<_> = full_conflicted_set
.into_iter()
.filter(|id| !sorted_power_events_set.contains(id))
.filter(|id| power_set_event_ids.binary_search(&id).is_err())
.collect();
debug!(count = remaining_events.len(), "remaining events");
@@ -195,7 +200,8 @@ where
// the mainline ordering based on the power level in the partially resolved
// state obtained in step 2.
let sorted_remaining_events = have_remaining_events
.then_async(move || mainline_sort(power_event.cloned(), remaining_events, fetch));
.then_async(move || mainline_sort(power_event.cloned(), remaining_events, fetch))
.boxed();
let sorted_remaining_events = sorted_remaining_events
.await
@@ -213,6 +219,7 @@ where
// and the list of events from the previous step.
let mut resolved_state =
iterative_auth_check(rules, sorted_remaining_events, partially_resolved_state, fetch)
.boxed()
.await?;
// 5. Update the result by replacing any event with the event with the same key

View File

@@ -1,5 +1,6 @@
use std::{
collections::HashSet as Set,
iter::once,
mem::take,
sync::{Arc, Mutex},
};
@@ -9,7 +10,8 @@ use ruma::OwnedEventId;
use crate::{
Result, debug,
matrix::Event,
matrix::{Event, pdu::AuthEvents},
smallvec::SmallVec,
utils::stream::{IterStream, automatic_width},
};
@@ -19,12 +21,19 @@ struct Global {
seen: Mutex<Set<OwnedEventId>>,
}
#[derive(Default)]
#[derive(Default, Debug)]
struct Local {
path: Vec<OwnedEventId>,
stack: Vec<Vec<OwnedEventId>>,
path: Path,
stack: Stack,
}
type Path = SmallVec<[OwnedEventId; PATH_INLINE]>;
type Stack = SmallVec<[Frame; STACK_INLINE]>;
type Frame = AuthEvents;
const PATH_INLINE: usize = 48;
const STACK_INLINE: usize = 48;
#[tracing::instrument(name = "conflicted_subgraph", level = "debug", skip_all)]
pub(super) fn conflicted_subgraph_dfs<ConflictedEventIds, Fetch, Fut, Pdu>(
conflicted_event_ids: ConflictedEventIds,
@@ -52,9 +61,11 @@ where
})
.await;
let seen = state.seen.lock().expect("locked");
let mut state = state.subgraph.lock().expect("locked");
debug!(
input_event = conflicted_event_ids.len(),
input_events = conflicted_event_ids.len(),
seen_events = seen.len(),
output_events = state.len(),
"conflicted subgraph state"
);
@@ -71,8 +82,8 @@ where
level = "trace",
skip_all,
fields(
event_id = %conflicted_event_id,
event_ids = conflicted_event_ids.len(),
event_id = %conflicted_event_id,
)
)]
async fn subgraph_descent<Fetch, Fut, Pdu>(
@@ -89,8 +100,8 @@ where
let Global { subgraph, seen } = &*state;
let mut local = Local {
path: vec![conflicted_event_id.clone()],
stack: vec![vec![conflicted_event_id]],
path: once(conflicted_event_id.clone()).collect(),
stack: once(once(conflicted_event_id).collect()).collect(),
};
while let Some(event_id) = pop(&mut local) {
@@ -128,13 +139,13 @@ where
fn pop(local: &mut Local) -> Option<OwnedEventId> {
let Local { path, stack } = local;
while stack.last().is_some_and(Vec::is_empty) {
while stack.last().is_some_and(Frame::is_empty) {
stack.pop();
path.pop();
}
stack
.last_mut()
.and_then(Vec::pop)
.and_then(Frame::pop)
.inspect(|event_id| path.push(event_id.clone()))
}

View File

@@ -12,6 +12,7 @@ use super::{
use crate::{
Error, Result, debug_warn, err, error,
matrix::{Event, EventTypeExt, StateKey},
smallvec::SmallVec,
trace,
utils::stream::{IterStream, ReadyExt, TryReadyExt, TryWidebandExt},
};
@@ -170,10 +171,10 @@ where
Ok(key_val)
});
let auth_events: Vec<_> = auth_events
let auth_events = auth_events
.chain(auth_types_events)
.try_collect()
.map_ok(|mut vec: Vec<_>| {
.map_ok(|mut vec: SmallVec<[_; 4]>| {
vec.sort_by(|a, b| a.0.cmp(&b.0));
vec.reverse();
vec.dedup_by(|a, b| a.0.eq(&b.0));

View File

@@ -1,13 +1,13 @@
use std::collections::HashMap;
use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, stream::try_unfold,
};
use ruma::{EventId, OwnedEventId, events::TimelineEventType};
use crate::{
Result,
Error, Result, at, is_equal_to,
matrix::Event,
trace,
utils::stream::{IterStream, TryReadyExt, WidebandExt},
utils::stream::{BroadbandExt, IterStream, TryReadyExt},
};
/// Perform mainline ordering of the given events.
@@ -39,11 +39,14 @@ use crate::{
level = "debug",
skip_all,
fields(
?power_level,
event_id = power_level_event_id
.as_deref()
.map(EventId::as_str)
.unwrap_or_default(),
)
)]
pub(super) async fn mainline_sort<'a, RemainingEvents, Fetch, Fut, Pdu>(
mut power_level: Option<OwnedEventId>,
power_level_event_id: Option<OwnedEventId>,
events: RemainingEvents,
fetch: &Fetch,
) -> Result<Vec<OwnedEventId>>
@@ -54,49 +57,60 @@ where
Pdu: Event,
{
// Populate the mainline of the power level.
let mut mainline = vec![];
while let Some(power_level_event_id) = power_level {
let mainline: Vec<_> = try_unfold(power_level_event_id, async |power_level_event_id| {
let Some(power_level_event_id) = power_level_event_id else {
return Ok::<_, Error>(None);
};
let power_level_event = fetch(power_level_event_id).await?;
let this_event_id = power_level_event.event_id().to_owned();
let next_event_id = get_power_levels_auth_event(&power_level_event, fetch)
.map_ok(|event| {
event
.as_ref()
.map(Event::event_id)
.map(ToOwned::to_owned)
})
.await?;
mainline.push(power_level_event.event_id().to_owned());
power_level = get_power_levels_auth_event(&power_level_event, fetch)
.await?
.map(|event| event.event_id().to_owned());
}
trace!(?this_event_id, ?next_event_id, "mainline descent",);
let mainline_map: HashMap<_, _> = mainline
.iter()
.rev()
.enumerate()
.map(|(idx, event_id)| (event_id.clone(), idx))
.collect();
Ok(Some((this_event_id, next_event_id)))
})
.try_collect()
.await?;
let order_map: HashMap<_, _> = events
.wide_filter_map(async |event_id| {
let event = fetch(event_id.to_owned()).await.ok()?;
let position = mainline_position(&event, &mainline_map, fetch)
let mainline = mainline.iter().rev().map(AsRef::as_ref);
events
.map(ToOwned::to_owned)
.broad_filter_map(async |event_id| {
let event = fetch(event_id.clone()).await.ok()?;
let origin_server_ts = event.origin_server_ts();
let position = mainline_position(Some(event), &mainline, fetch)
.await
.ok()?;
let event_id = event.event_id().to_owned();
let origin_server_ts = event.origin_server_ts();
Some((event_id, (position, origin_server_ts)))
})
.inspect(|(event_id, (position, origin_server_ts))| {
trace!(position, ?origin_server_ts, ?event_id, "mainline position");
})
.collect()
.await;
.map(|mut vec: Vec<_>| {
vec.sort_by(|a, b| {
let (a_pos, a_ots) = &a.1;
let (b_pos, b_ots) = &b.1;
a_pos
.cmp(b_pos)
.then(a_ots.cmp(b_ots))
.then(a.cmp(b))
});
let mut sorted_event_ids: Vec<_> = order_map.keys().cloned().collect();
sorted_event_ids.sort_by(|a, b| {
let (a_pos, a_ots) = &order_map[a];
let (b_pos, b_ots) = &order_map[b];
a_pos
.cmp(b_pos)
.then(a_ots.cmp(b_ots))
.then(a.cmp(b))
});
Ok(sorted_event_ids)
vec.into_iter().map(at!(0)).collect()
})
.map(Ok)
.await
}
/// Get the mainline position of the given event within the given mainline.
@@ -112,30 +126,38 @@ where
/// Returns the mainline position of the event, or an `Err(_)` if one of the
/// events in the auth chain of the event was not found.
#[tracing::instrument(
name = "position",
level = "trace",
ret(level = "trace"),
skip_all,
fields(
event = ?event.event_id(),
mainline = mainline_map.len(),
mainline = mainline.clone().count(),
event = ?current_event.as_ref().map(Event::event_id).map(ToOwned::to_owned),
)
)]
async fn mainline_position<Fetch, Fut, Pdu>(
event: &Pdu,
mainline_map: &HashMap<OwnedEventId, usize>,
async fn mainline_position<'a, Mainline, Fetch, Fut, Pdu>(
mut current_event: Option<Pdu>,
mainline: &Mainline,
fetch: &Fetch,
) -> Result<usize>
where
Mainline: Iterator<Item = &'a EventId> + Clone + Send + Sync,
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
let mut current_event = Some(event.clone());
while let Some(event) = current_event {
trace!(event_id = ?event.event_id(), "mainline");
trace!(
event_id = ?event.event_id(),
"mainline position search",
);
// If the current event is in the mainline map, return its position.
if let Some(position) = mainline_map.get(event.event_id()) {
return Ok(*position);
if let Some(position) = mainline
.clone()
.position(is_equal_to!(event.event_id()))
{
return Ok(position);
}
// Look for the power levels event in the auth events.

View File

@@ -15,7 +15,7 @@ use super::{
RoomCreateEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField, is_power_event,
power_levels::RoomPowerLevelsEventOptionExt,
},
AuthSet, topological_sort,
topological_sort,
};
use crate::{
Result, err,
@@ -50,7 +50,7 @@ use crate::{
)]
pub(super) async fn power_sort<Fetch, Fut, Pdu>(
rules: &RoomVersionRules,
full_conflicted_set: &AuthSet<OwnedEventId>,
full_conflicted_set: &HashSet<OwnedEventId>,
fetch: &Fetch,
) -> Result<Vec<OwnedEventId>>
where
@@ -108,7 +108,7 @@ where
)]
async fn add_event_auth_chain<Fetch, Fut, Pdu>(
mut graph: HashMap<OwnedEventId, HashSet<OwnedEventId>>,
full_conflicted_set: &AuthSet<OwnedEventId>,
full_conflicted_set: &HashSet<OwnedEventId>,
event_id: &EventId,
fetch: &Fetch,
) -> HashMap<OwnedEventId, HashSet<OwnedEventId>>

View File

@@ -2,7 +2,7 @@ use std::{collections::HashMap, hash::Hash};
use futures::{Stream, StreamExt};
use super::StateMap;
use super::{ConflictMap, StateMap};
use crate::validated;
/// Split the unconflicted state map and the conflicted state set.
@@ -28,7 +28,7 @@ use crate::validated;
#[tracing::instrument(name = "split", level = "debug", skip_all)]
pub(super) async fn split_conflicted_state<'a, Maps, Id>(
state_maps: Maps,
) -> (StateMap<Id>, StateMap<Vec<Id>>)
) -> (StateMap<Id>, ConflictMap<Id>)
where
Maps: Stream<Item = StateMap<Id>>,
Id: Clone + Eq + Hash + Ord + Send + Sync + 'a,
@@ -52,7 +52,7 @@ where
}
let mut unconflicted_state_map = StateMap::new();
let mut conflicted_state_set = StateMap::<Vec<Id>>::new();
let mut conflicted_state_set = ConflictMap::new();
for (k, v) in occurrences {
for (id, occurrence_count) in v {

View File

@@ -1,4 +1,7 @@
use std::collections::HashMap;
use std::{
collections::{HashMap, HashSet},
iter::once,
};
use futures::StreamExt;
use maplit::{hashmap, hashset};
@@ -16,7 +19,7 @@ use ruma::{
use serde_json::{json, value::to_raw_value as to_raw_json_value};
use super::{
AuthSet, StateMap,
StateMap,
test_utils::{
INITIAL_EVENTS, TestStore, alice, bob, charlie, do_check, ella, event_id,
member_content_ban, member_content_join, not_found, room_id, to_init_pdu_event,
@@ -39,7 +42,7 @@ async fn test_event_sort() {
let rules = RoomVersionRules::V6;
let events = INITIAL_EVENTS();
let auth_chain: AuthSet<OwnedEventId> = AuthSet::new();
let auth_chain: HashSet<OwnedEventId> = HashSet::new();
let sorted_power_events = super::power_sort(&rules, &auth_chain, &async |id| {
events.get(&id).cloned().ok_or_else(not_found)
@@ -752,9 +755,9 @@ async fn split_conflicted_state_set_conflicted_unique_state_keys() {
assert_eq!(unconflicted, StateMap::new());
assert_eq!(conflicted, state_set![
StateEventType::RoomMember => "@a:hs1" => vec![0],
StateEventType::RoomMember => "@b:hs1" => vec![1],
StateEventType::RoomMember => "@c:hs1" => vec![2],
StateEventType::RoomMember => "@a:hs1" => once(0).collect(),
StateEventType::RoomMember => "@b:hs1" => once(1).collect(),
StateEventType::RoomMember => "@c:hs1" => once(2).collect(),
],);
}
@@ -781,7 +784,7 @@ async fn split_conflicted_state_set_conflicted_same_state_key() {
assert_eq!(unconflicted, StateMap::new());
assert_eq!(conflicted, state_set![
StateEventType::RoomMember => "@a:hs1" => vec![0, 1, 2],
StateEventType::RoomMember => "@a:hs1" => [0, 1, 2].into_iter().collect(),
],);
}
@@ -833,7 +836,7 @@ async fn split_conflicted_state_set_mixed() {
StateEventType::RoomMember => "@a:hs1" => 0,
],);
assert_eq!(conflicted, state_set![
StateEventType::RoomMember => "@b:hs1" => vec![1],
StateEventType::RoomMember => "@c:hs1" => vec![2],
StateEventType::RoomMember => "@b:hs1" => once(1).collect(),
StateEventType::RoomMember => "@c:hs1" => once(2).collect(),
],);
}