diff --git a/src/service/rooms/state_res/topological_sort.rs b/src/service/rooms/state_res/topological_sort.rs
index 7432a632..4419461c 100644
--- a/src/service/rooms/state_res/topological_sort.rs
+++ b/src/service/rooms/state_res/topological_sort.rs
@@ -27,7 +27,7 @@ use std::{
 	collections::{BinaryHeap, HashMap},
 };
 
-use futures::TryStreamExt;
+use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
 use ruma::{
 	MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
 };
@@ -40,9 +40,9 @@ type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
 
 #[derive(PartialEq, Eq)]
 struct TieBreaker {
+	event_id: OwnedEventId,
 	power_level: UserPowerLevel,
 	origin_server_ts: MilliSecondsSinceUnixEpoch,
-	event_id: OwnedEventId,
 }
 
 // NOTE: the power level comparison is "backwards" intentionally.
@@ -73,6 +73,12 @@ impl PartialOrd for TieBreaker {
 /// ## Returns
 ///
 /// Returns the ordered list of event IDs from earliest to latest.
+///
+/// We consider that the DAG is directed from most recent events to oldest
+/// events, so an event is an incoming edge to its referenced events. The
+/// horizon is the set of events with an outdegree of zero (no outgoing
+/// edges), i.e. the oldest events, and `incoming` maps each event to the
+/// list of events that reference it in their referenced events.
 #[tracing::instrument(
 	level = "debug",
 	skip_all,
@@ -80,7 +86,7 @@ impl PartialOrd for TieBreaker {
 		graph = graph.len(),
 	)
 )]
-#[expect(clippy::implicit_hasher, clippy::or_fun_call)]
+#[expect(clippy::implicit_hasher)]
 pub async fn topological_sort<Query, Fut>(
 	graph: &HashMap<OwnedEventId, ReferencedIds>,
 	query: &Query,
@@ -91,7 +97,7 @@ where
 {
 	let query = async |event_id: OwnedEventId| {
 		let (power_level, origin_server_ts) = query(event_id.clone()).await?;
-		Ok::<_, Error>(TieBreaker { power_level, origin_server_ts, event_id })
+		Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
 	};
 
 	let max_edges = graph
@@ -99,74 +105,82 @@ where
 		.map(ReferencedIds::len)
 		.fold(graph.len(), |a, c| validated!(a + c));
 
-	// We consider that the DAG is directed from most recent events to oldest
-	// events, so an event is an incoming edge to its referenced events.
-	// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
-	// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
-	// of events that reference it in its referenced events.
-	let init = (
-		Vec::with_capacity(max_edges),
-		HashMap::<OwnedEventId, ReferencedIds>::with_capacity(max_edges),
-	);
-
-	// Populate the list of events with an outdegree of zero, and the map of
-	// incoming edges.
-	let (zero_outdeg, incoming_edges) = graph
+	let incoming = graph
 		.iter()
-		.try_stream()
-		.try_fold(
-			init,
-			async |(mut zero_outdeg, mut incoming_edges), (event_id, outgoing_edges)| {
-				if outgoing_edges.is_empty() {
-					// `Reverse` because `BinaryHeap` sorts largest -> smallest and we need
-					// smallest -> largest.
-					zero_outdeg.push(Reverse(query(event_id.clone()).await?));
-				}
+		.flat_map(|(event_id, out)| {
+			out.iter()
+				.map(move |reference| (event_id, reference))
+		})
+		.fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
+			let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
 
-				for referenced_event_id in outgoing_edges {
-					let references = incoming_edges
-						.entry(referenced_event_id.into())
-						.or_default();
-
-					if !references.contains(event_id) {
-						references.push(event_id.into());
-					}
-				}
-
-				Ok((zero_outdeg, incoming_edges))
-			},
-		)
-		.await?;
-
-	// Map of event to the list of events in its referenced events.
-	let mut outgoing_edges_map = graph.clone();
-
-	// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
-	let mut heap = BinaryHeap::from(zero_outdeg);
-	let mut sorted = Vec::with_capacity(max_edges);
-
-	// Apply Kahn's algorithm.
-	// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
-	while let Some(Reverse(item)) = heap.pop() {
-		for parent_id in incoming_edges
-			.get(&item.event_id)
-			.unwrap_or(&ReferencedIds::new())
-		{
-			let outgoing_edges = outgoing_edges_map
-				.get_mut(parent_id)
-				.expect("outgoing_edges should contain all event_ids");
-
-			outgoing_edges.retain(is_not_equal_to!(&item.event_id));
-			if !outgoing_edges.is_empty() {
-				continue;
+			if !references.contains(event_id) {
+				references.push(event_id.clone());
 			}
 
-			// Push on the heap once all the outgoing edges have been removed.
-			heap.push(Reverse(query(parent_id.clone()).await?));
-		}
+			incoming
+		});
 
-		sorted.push(item.event_id);
-	}
+	let horizon = graph
+		.iter()
+		.filter(|(_, references)| references.is_empty())
+		.try_stream()
+		.and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
+		.try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
+		.await?;
 
-	Ok(sorted)
+	kahn_sort(horizon, graph.clone(), &incoming, &query)
+		.try_collect()
+		.await
+}
+
+// Apply Kahn's algorithm.
+// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
+// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
+#[tracing::instrument(
+	level = "debug",
+	skip_all,
+	fields(
+		heap = %heap.len(),
+		graph = %graph.len(),
+	)
+)]
+fn kahn_sort<Query, Fut>(
+	heap: BinaryHeap<Reverse<TieBreaker>>,
+	graph: HashMap<OwnedEventId, ReferencedIds>,
+	incoming: &HashMap<OwnedEventId, ReferencedIds>,
+	query: &Query,
+) -> impl Stream<Item = Result<OwnedEventId, Error>> + Send
+where
+	Query: Fn(OwnedEventId) -> Fut + Sync,
+	Fut: Future<Output = Result<TieBreaker, Error>> + Send,
+{
+	try_unfold((heap, graph), move |(mut heap, graph)| async move {
+		let Some(Reverse(item)) = heap.pop() else {
+			return Ok(None);
+		};
+
+		let references = incoming.get(&item.event_id).cloned();
+		let state = (item.event_id, (heap, graph));
+		references
+			.into_iter()
+			.flatten()
+			.try_stream()
+			.try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
+				let out = graph
+					.get_mut(&parent_id)
+					.expect("contains all parent_ids");
+
+				out.retain(is_not_equal_to!(&event_id));
+
+				// Push on the heap once all the outgoing edges have been removed.
+				if out.is_empty() {
+					heap.push(Reverse(query(parent_id.clone()).await?));
+				}
+
+				Ok::<_, Error>((event_id, (heap, graph)))
+			})
+			.map_ok(Some)
+			.await
+	})
 }