Optimize reference graph container value type for topological_sort.

Optimize initial container capacity estimates.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-02-14 10:23:47 +00:00
parent b7ea9714e8
commit 1bd4ab0601
15 changed files with 192 additions and 155 deletions

View File

@@ -53,9 +53,9 @@ type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
level = INFO_SPAN_LEVEL,
skip_all,
fields(
%client,
origin = body.origin().as_str(),
txn = str_truncated(body.transaction_id.as_str(), 20),
origin = body.origin().as_str(),
%client,
),
)]
pub(crate) async fn send_transaction_message_route(
@@ -189,7 +189,7 @@ async fn handle_pdus(
#[tracing::instrument(
name = "room",
level = "debug",
level = INFO_SPAN_LEVEL,
skip_all,
fields(%room_id)
)]

View File

@@ -41,16 +41,20 @@ criterion_main!(benches);
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
#[expect(
clippy::iter_on_single_items,
clippy::iter_on_empty_collections
)]
fn lexico_topo_sort(c: &mut Criterion) {
c.bench_function("lexico_topo_sort", |c| {
use maplit::{hashmap, hashset};
use maplit::hashmap;
let graph = hashmap! {
event_id("l") => hashset![event_id("o")],
event_id("m") => hashset![event_id("n"), event_id("o")],
event_id("n") => hashset![event_id("o")],
event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
event_id("p") => hashset![event_id("o")],
event_id("l") => [event_id("o")].into_iter().collect(),
event_id("m") => [event_id("n"), event_id("o")].into_iter().collect(),
event_id("n") => [event_id("o")].into_iter().collect(),
event_id("o") => [].into_iter().collect(), // "o" has zero outgoing edges but 4 incoming edges
event_id("p") => [event_id("o")].into_iter().collect(),
};
c.to_async(FuturesExecutor).iter(async || {

View File

@@ -65,7 +65,7 @@ mod fetch_state;
mod resolve;
#[cfg(test)]
mod test_utils;
mod topological_sort;
pub mod topological_sort;
use self::{event_auth::check_state_dependent_auth_rules, fetch_state::FetchStateExt};
pub use self::{

View File

@@ -8,7 +8,10 @@ mod mainline_sort;
mod power_sort;
mod split_conflicted;
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::{
collections::{BTreeMap, BTreeSet, HashSet},
ops::Deref,
};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
use itertools::Itertools;
@@ -28,7 +31,6 @@ use crate::{
trace,
utils::{
BoolExt,
option::OptionExt,
stream::{BroadbandExt, IterStream},
},
};
@@ -242,20 +244,25 @@ where
.is_some_and(|rules| rules.consider_conflicted_state_subgraph)
|| backport_css;
let conflicted_states = conflicted_states.values().flatten().cloned();
let conflicted_state_set: HashSet<_> = conflicted_states.values().flatten().collect();
// Since `org.matrix.hydra.11`, fetch the conflicted state subgraph.
let conflicted_subgraph = consider_conflicted_subgraph
.then(|| conflicted_states.clone().stream())
.map_async(async |ids| conflicted_subgraph_dfs(ids, fetch))
.then_async(async || conflicted_subgraph_dfs(&conflicted_state_set, fetch))
.map(Option::into_iter)
.map(IterStream::stream)
.flatten_stream()
.flatten()
.boxed();
let conflicted_state_ids = conflicted_state_set
.iter()
.map(Deref::deref)
.cloned()
.stream();
auth_difference(auth_sets)
.chain(conflicted_states.stream())
.chain(conflicted_state_ids)
.broad_filter_map(async |id| exists(id.clone()).await.then_some(id))
.chain(conflicted_subgraph)
.collect::<HashSet<_>>()

View File

@@ -35,32 +35,28 @@ const PATH_INLINE: usize = 48;
const STACK_INLINE: usize = 48;
#[tracing::instrument(name = "subgraph_dfs", level = "debug", skip_all)]
pub(super) fn conflicted_subgraph_dfs<ConflictedEventIds, Fetch, Fut, Pdu>(
conflicted_event_ids: ConflictedEventIds,
pub(super) fn conflicted_subgraph_dfs<Fetch, Fut, Pdu>(
conflicted_event_ids: &Set<&OwnedEventId>,
fetch: &Fetch,
) -> impl Stream<Item = OwnedEventId> + Send
where
ConflictedEventIds: Stream<Item = OwnedEventId> + Send,
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
conflicted_event_ids
.collect::<Set<_>>()
.map(|ids| (Arc::new(Global::default()), ids))
.then(async |(state, conflicted_event_ids)| {
let state = Arc::new(Global::default());
let state_ = state.clone();
conflicted_event_ids
.iter()
.stream()
.map(ToOwned::to_owned)
.map(|event_id| (state.clone(), event_id))
.for_each_concurrent(automatic_width(), async |(state, event_id)| {
subgraph_descent(state, event_id, &conflicted_event_ids, fetch)
.enumerate()
.map(move |(i, event_id)| (state_.clone(), event_id, i))
.for_each_concurrent(automatic_width(), async |(state, event_id, i)| {
subgraph_descent(conflicted_event_ids, fetch, &state, event_id, i)
.await
.expect("only mutex errors expected");
})
.await;
.map(move |()| {
let seen = state.seen.lock().expect("locked");
let mut state = state.subgraph.lock().expect("locked");
debug!(
@@ -70,10 +66,8 @@ where
"conflicted subgraph state"
);
take(&mut *state)
take(&mut *state).into_iter().stream()
})
.map(Set::into_iter)
.map(IterStream::stream)
.flatten_stream()
}
@@ -84,24 +78,26 @@ where
fields(
event_ids = conflicted_event_ids.len(),
event_id = %conflicted_event_id,
%i,
)
)]
async fn subgraph_descent<Fetch, Fut, Pdu>(
state: Arc<Global>,
conflicted_event_id: OwnedEventId,
conflicted_event_ids: &Set<OwnedEventId>,
conflicted_event_ids: &Set<&OwnedEventId>,
fetch: &Fetch,
) -> Result<Arc<Global>>
state: &Arc<Global>,
conflicted_event_id: &OwnedEventId,
i: usize,
) -> Result
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
let Global { subgraph, seen } = &*state;
let Global { subgraph, seen } = &**state;
let mut local = Local {
path: once(conflicted_event_id.clone()).collect(),
stack: once(once(conflicted_event_id).collect()).collect(),
stack: once(once(conflicted_event_id.clone()).collect()).collect(),
};
while let Some(event_id) = pop(&mut local) {
@@ -133,7 +129,7 @@ where
}
}
Ok(state)
Ok(())
}
fn pop(local: &mut Local) -> Option<OwnedEventId> {

View File

@@ -193,7 +193,12 @@ where
// Add authentic event to the partially resolved state.
if check_state_dependent_auth_rules(rules, &event, &fetch_state)
.await
.inspect_err(|e| debug_warn!("event failed auth check: {e}"))
.inspect_err(|e| {
debug_warn!(
event_type = ?event.event_type(), ?state_key, %event_id,
"event failed auth check: {e}"
);
})
.is_ok()
{
let key = event.event_type().with_state_key(state_key);

View File

@@ -16,6 +16,7 @@ use super::super::{
power_levels::RoomPowerLevelsEventOptionExt,
},
topological_sort,
topological_sort::ReferencedIds,
};
use crate::{
Result, err,
@@ -107,11 +108,11 @@ where
)
)]
async fn add_event_auth_chain<Fetch, Fut, Pdu>(
mut graph: HashMap<OwnedEventId, HashSet<OwnedEventId>>,
mut graph: HashMap<OwnedEventId, ReferencedIds>,
full_conflicted_set: &HashSet<OwnedEventId>,
event_id: &EventId,
fetch: &Fetch,
) -> HashMap<OwnedEventId, HashSet<OwnedEventId>>
) -> HashMap<OwnedEventId, ReferencedIds>
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
@@ -127,22 +128,17 @@ where
// Add the current event to the graph.
graph.entry(event_id).or_default();
let auth_events = event
event
.as_ref()
.map(Event::auth_events)
.into_iter()
.flatten();
for auth_event_id in auth_events {
// If the auth event ID is in the full conflicted set…
if !full_conflicted_set.contains(auth_event_id) {
continue;
}
// If the auth event ID is not in the graph, we need to check its auth events
// later.
if !graph.contains_key(auth_event_id) {
state.push(auth_event_id.to_owned());
.flatten()
.map(ToOwned::to_owned)
.filter(|auth_event_id| full_conflicted_set.contains(auth_event_id))
.for_each(|auth_event_id| {
// If the auth event ID is not in the graph, check its auth events later.
if !graph.contains_key(&auth_event_id) {
state.push(auth_event_id.clone());
}
let event_id = event
@@ -151,11 +147,14 @@ where
.event_id();
// Add the auth event ID to the list of incoming edges.
graph
let references = graph
.get_mut(event_id)
.expect("event_id must be added to graph")
.insert(auth_event_id.to_owned());
.expect("event_id must be added to graph");
if !references.contains(&auth_event_id) {
references.push(auth_event_id);
}
});
}
graph

View File

@@ -1,4 +1,4 @@
use std::{collections::HashMap, hash::Hash};
use std::{collections::HashMap, hash::Hash, iter::IntoIterator};
use futures::{Stream, StreamExt};
@@ -35,13 +35,18 @@ where
{
let state_maps: Vec<_> = state_maps.collect().await;
let mut state_set_count = 0_usize;
let mut occurrences = HashMap::<_, HashMap<_, usize>>::new();
let state_maps = state_maps.iter().inspect(|_state| {
state_set_count = validated!(state_set_count + 1);
});
let state_ids_est = state_maps.iter().flatten().count();
for (k, v) in state_maps.into_iter().flat_map(|s| s.iter()) {
let state_set_count = state_maps
.iter()
.fold(0_usize, |acc, _| validated!(acc + 1));
let mut occurrences = HashMap::<_, HashMap<_, usize>>::with_capacity(state_ids_est);
for (k, v) in state_maps
.into_iter()
.flat_map(IntoIterator::into_iter)
{
let acc = occurrences
.entry(k.clone())
.or_default()
@@ -59,10 +64,16 @@ where
if occurrence_count == state_set_count {
unconflicted_state_map.insert((k.0.clone(), k.1.clone()), id.clone());
} else {
conflicted_state_set
let conflicts = conflicted_state_set
.entry((k.0.clone(), k.1.clone()))
.or_default()
.push(id.clone());
.or_default();
debug_assert!(
!conflicts.contains(&id),
"Unexpected duplicate conflicted state event"
);
conflicts.push(id.clone());
}
}
}

View File

@@ -1,10 +1,7 @@
use std::{
collections::{HashMap, HashSet},
iter::once,
};
use std::{collections::HashMap, iter::once};
use futures::StreamExt;
use maplit::{hashmap, hashset};
use maplit::hashmap;
use rand::seq::SliceRandom;
use ruma::{
MilliSecondsSinceUnixEpoch, OwnedEventId,
@@ -42,7 +39,7 @@ async fn test_event_sort() {
let rules = RoomVersionRules::V6;
let events = INITIAL_EVENTS();
let auth_chain: HashSet<OwnedEventId> = HashSet::new();
let auth_chain = Default::default();
let sorted_power_events = super::power_sort(&rules, &auth_chain, &async |id| {
events.get(&id).cloned().ok_or_else(not_found)
@@ -484,6 +481,10 @@ async fn test_event_map_none() {
}
#[tokio::test]
#[expect(
clippy::iter_on_single_items,
clippy::iter_on_empty_collections
)]
async fn test_reverse_topological_power_sort() {
_ = tracing::subscriber::set_default(
tracing_subscriber::fmt()
@@ -492,11 +493,11 @@ async fn test_reverse_topological_power_sort() {
);
let graph = hashmap! {
event_id("l") => hashset![event_id("o")],
event_id("m") => hashset![event_id("n"), event_id("o")],
event_id("n") => hashset![event_id("o")],
event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
event_id("p") => hashset![event_id("o")],
event_id("l") => [event_id("o")].into_iter().collect(),
event_id("m") => [event_id("n"), event_id("o")].into_iter().collect(),
event_id("n") => [event_id("o")].into_iter().collect(),
event_id("o") => [].into_iter().collect(), // "o" has zero outgoing edges but 4 incoming edges
event_id("p") => [event_id("o")].into_iter().collect(),
};
let res = super::super::topological_sort(&graph, &async |_id| {

View File

@@ -1,6 +1,6 @@
use std::{
borrow::Borrow,
collections::{HashMap, HashSet},
collections::HashMap,
pin::Pin,
slice,
sync::{
@@ -30,7 +30,10 @@ use serde_json::{
use super::{AuthSet, StateMap, auth_types_for_event, events::RoomCreateEvent};
use crate::{
Error, Result, err, info,
matrix::{Event, EventHash, EventTypeExt, PduEvent, StateKey},
matrix::{
Event, EventHash, EventTypeExt, PduEvent, StateKey,
state_res::topological_sort::ReferencedIds,
},
utils::stream::IterStream,
};
@@ -64,14 +67,14 @@ pub(super) async fn do_check(
);
// This will be lexi_topo_sorted for resolution
let mut graph = HashMap::new();
let mut graph = HashMap::<OwnedEventId, ReferencedIds>::new();
// This is the same as in `resolve` event_id -> OriginalStateEvent
let mut fake_event_map = HashMap::new();
// Create the DB of events that led up to this point
// TODO maybe clean up some of these clones it is just tests but...
for ev in init_events.values().chain(events) {
graph.insert(ev.event_id().to_owned(), HashSet::new());
graph.insert(ev.event_id().to_owned(), Default::default());
fake_event_map.insert(ev.event_id().to_owned(), ev.clone());
}
@@ -79,8 +82,8 @@ pub(super) async fn do_check(
if let [a, b] = &pair {
graph
.entry(a.to_owned())
.or_insert_with(HashSet::new)
.insert(b.clone());
.or_default()
.push(b.clone());
}
}
@@ -89,8 +92,8 @@ pub(super) async fn do_check(
if let [a, b] = &pair {
graph
.entry(a.to_owned())
.or_insert_with(HashSet::new)
.insert(b.clone());
.or_default()
.push(b.clone());
}
}
}

View File

@@ -24,37 +24,40 @@
use std::{
cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet},
collections::{BinaryHeap, HashMap},
};
use futures::TryStreamExt;
use ruma::{
EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
};
use crate::{Result, utils::stream::IterStream};
#[derive(PartialEq, Eq)]
struct TieBreaker<'a> {
power_level: UserPowerLevel,
origin_server_ts: MilliSecondsSinceUnixEpoch,
event_id: &'a EventId,
}
use crate::{
Error, Result, is_not_equal_to, smallvec::SmallVec, utils::stream::IterStream, validated,
};
pub type ReferencedIds = SmallVec<[OwnedEventId; 3]>;
type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
#[derive(PartialEq, Eq)]
struct TieBreaker {
power_level: UserPowerLevel,
origin_server_ts: MilliSecondsSinceUnixEpoch,
event_id: OwnedEventId,
}
// NOTE: the power level comparison is "backwards" intentionally.
impl Ord for TieBreaker<'_> {
impl Ord for TieBreaker {
fn cmp(&self, other: &Self) -> Ordering {
other
.power_level
.cmp(&self.power_level)
.then(self.origin_server_ts.cmp(&other.origin_server_ts))
.then(self.event_id.cmp(other.event_id))
.then(self.event_id.cmp(&other.event_id))
}
}
impl PartialOrd for TieBreaker<'_> {
impl PartialOrd for TieBreaker {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}
@@ -78,21 +81,34 @@ impl PartialOrd for TieBreaker<'_> {
graph = graph.len(),
)
)]
#[expect(clippy::implicit_hasher)]
#[expect(clippy::implicit_hasher, clippy::or_fun_call)]
pub async fn topological_sort<Query, Fut>(
graph: &HashMap<OwnedEventId, HashSet<OwnedEventId>>,
graph: &HashMap<OwnedEventId, ReferencedIds>,
query: &Query,
) -> Result<Vec<OwnedEventId>>
where
Query: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<PduInfo>> + Send,
{
let query = async |event_id: OwnedEventId| {
let (power_level, origin_server_ts) = query(event_id.clone()).await?;
Ok::<_, Error>(TieBreaker { power_level, origin_server_ts, event_id })
};
let max_edges = graph
.values()
.map(ReferencedIds::len)
.fold(graph.len(), |a, c| validated!(a + c));
// We consider that the DAG is directed from most recent events to oldest
// events, so an event is an incoming edge to its referenced events.
// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
// edges), i.e. the oldest events. incoming_edges_map: Map of an event to the
// list of events that reference it among their referenced events.
let init = (Vec::new(), HashMap::<OwnedEventId, HashSet<OwnedEventId>>::new());
let init = (
Vec::with_capacity(max_edges),
HashMap::<OwnedEventId, ReferencedIds>::with_capacity(max_edges),
);
// Populate the list of events with an outdegree of zero, and the map of
// incoming edges.
@@ -103,24 +119,19 @@ where
init,
async |(mut zero_outdeg, mut incoming_edges), (event_id, outgoing_edges)| {
if outgoing_edges.is_empty() {
let (power_level, origin_server_ts) = query(event_id.clone()).await?;
// `Reverse` because `BinaryHeap` sorts largest -> smallest and we need
// smallest -> largest.
zero_outdeg.push(Reverse(TieBreaker {
power_level,
origin_server_ts,
event_id,
}));
zero_outdeg.push(Reverse(query(event_id.clone()).await?));
}
incoming_edges.entry(event_id.into()).or_default();
for referenced_event_id in outgoing_edges {
incoming_edges
let references = incoming_edges
.entry(referenced_event_id.into())
.or_default()
.insert(event_id.into());
.or_default();
if !references.contains(event_id) {
references.push(event_id.into());
}
}
Ok((zero_outdeg, incoming_edges))
@@ -133,34 +144,29 @@ where
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
let mut heap = BinaryHeap::from(zero_outdeg);
let mut sorted = vec![];
let mut sorted = Vec::with_capacity(max_edges);
// Apply Kahn's algorithm.
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
while let Some(Reverse(item)) = heap.pop() {
for parent_id in incoming_edges
.get(item.event_id)
.expect("event_id in heap should also be in incoming_edges")
.get(&item.event_id)
.unwrap_or(&ReferencedIds::new())
{
let outgoing_edges = outgoing_edges_map
.get_mut(parent_id)
.expect("outgoing_edges should contain all event_ids");
outgoing_edges.remove(item.event_id);
outgoing_edges.retain(is_not_equal_to!(&item.event_id));
if !outgoing_edges.is_empty() {
continue;
}
// Push on the heap once all the outgoing edges have been removed.
let (power_level, origin_server_ts) = query(parent_id.clone()).await?;
heap.push(Reverse(TieBreaker {
power_level,
origin_server_ts,
event_id: parent_id.as_ref(),
}));
heap.push(Reverse(query(parent_id.clone()).await?));
}
sorted.push(item.event_id.into());
sorted.push(item.event_id);
}
Ok(sorted)

View File

@@ -1,7 +1,4 @@
use std::{
collections::{HashMap, HashSet},
iter::once,
};
use std::{collections::HashMap, iter::once};
use futures::{FutureExt, StreamExt, stream::FuturesOrdered};
use ruma::{
@@ -50,7 +47,7 @@ where
while let Some((prev_event_id, mut outlier)) = todo_outlier_stack.next().await {
let Some((pdu, mut json_opt)) = outlier.pop() else {
// Fetch and handle failed
graph.insert(prev_event_id.clone(), HashSet::new());
graph.insert(prev_event_id.clone(), Default::default());
continue;
};
@@ -59,7 +56,7 @@ where
let limit = self.services.server.config.max_fetch_prev_events;
if amount > limit {
debug_warn!(?limit, "Max prev event limit reached!");
graph.insert(prev_event_id.clone(), HashSet::new());
graph.insert(prev_event_id.clone(), Default::default());
continue;
}
@@ -74,7 +71,7 @@ where
let Some(json) = json_opt else {
// Get json failed, so this was not fetched over federation
graph.insert(prev_event_id.clone(), HashSet::new());
graph.insert(prev_event_id.clone(), Default::default());
continue;
};
@@ -104,7 +101,7 @@ where
);
} else {
// Time based check failed
graph.insert(prev_event_id.clone(), HashSet::new());
graph.insert(prev_event_id.clone(), Default::default());
}
eventid_info.insert(prev_event_id.clone(), (pdu, json));

View File

@@ -110,11 +110,24 @@ fn is_backed_off(&self, event_id: &EventId, range: Range<Duration>) -> bool {
}
#[implement(Service)]
#[tracing::instrument(
name = "exists",
level = "trace",
ret(level = "trace"),
skip_all,
fields(%event_id)
)]
async fn event_exists(&self, event_id: &EventId) -> bool {
self.services.timeline.pdu_exists(event_id).await
}
#[implement(Service)]
#[tracing::instrument(
name = "fetch",
level = "trace",
skip_all,
fields(%event_id)
)]
async fn event_fetch(&self, event_id: &EventId) -> Result<PduEvent> {
self.services.timeline.get_pdu(event_id).await
}

View File

@@ -103,15 +103,10 @@ pub async fn resolve_state(
}
#[implement(super::Service)]
#[tracing::instrument(
name = "resolve",
level = "debug",
skip_all,
fields(%room_id),
)]
#[tracing::instrument(name = "resolve", level = "debug", skip_all)]
pub(super) async fn state_resolution<StateSets, AuthSets>(
&self,
room_id: &RoomId,
_room_id: &RoomId,
room_version: &RoomVersionId,
state_sets: StateSets,
auth_chains: AuthSets,

View File

@@ -22,8 +22,8 @@ use crate::rooms::{
#[tracing::instrument(
name = "upgrade",
level = "debug",
skip_all,
ret(level = "debug")
ret(level = "debug"),
skip_all
)]
pub(super) async fn upgrade_outlier_to_timeline_pdu(
&self,