Split topological_sort; semi try_unfold refactor.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-02-25 19:53:48 +00:00
parent 357a5b7a74
commit 63b0014f8f

View File

@@ -27,7 +27,7 @@ use std::{
collections::{BinaryHeap, HashMap}, collections::{BinaryHeap, HashMap},
}; };
use futures::TryStreamExt; use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
use ruma::{ use ruma::{
MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel, MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
}; };
@@ -40,9 +40,9 @@ type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
struct TieBreaker { struct TieBreaker {
event_id: OwnedEventId,
power_level: UserPowerLevel, power_level: UserPowerLevel,
origin_server_ts: MilliSecondsSinceUnixEpoch, origin_server_ts: MilliSecondsSinceUnixEpoch,
event_id: OwnedEventId,
} }
// NOTE: the power level comparison is "backwards" intentionally. // NOTE: the power level comparison is "backwards" intentionally.
@@ -73,6 +73,12 @@ impl PartialOrd for TieBreaker {
/// ## Returns /// ## Returns
/// ///
/// Returns the ordered list of event IDs from earliest to latest. /// Returns the ordered list of event IDs from earliest to latest.
///
/// We consider that the DAG is directed from most recent events to oldest
/// events, so an event is an incoming edge to its referenced events.
/// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
/// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
/// of events that reference it in its referenced events.
#[tracing::instrument( #[tracing::instrument(
level = "debug", level = "debug",
skip_all, skip_all,
@@ -80,7 +86,7 @@ impl PartialOrd for TieBreaker {
graph = graph.len(), graph = graph.len(),
) )
)] )]
#[expect(clippy::implicit_hasher, clippy::or_fun_call)] #[expect(clippy::implicit_hasher)]
pub async fn topological_sort<Query, Fut>( pub async fn topological_sort<Query, Fut>(
graph: &HashMap<OwnedEventId, ReferencedIds>, graph: &HashMap<OwnedEventId, ReferencedIds>,
query: &Query, query: &Query,
@@ -91,7 +97,7 @@ where
{ {
let query = async |event_id: OwnedEventId| { let query = async |event_id: OwnedEventId| {
let (power_level, origin_server_ts) = query(event_id.clone()).await?; let (power_level, origin_server_ts) = query(event_id.clone()).await?;
Ok::<_, Error>(TieBreaker { power_level, origin_server_ts, event_id }) Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
}; };
let max_edges = graph let max_edges = graph
@@ -99,74 +105,82 @@ where
.map(ReferencedIds::len) .map(ReferencedIds::len)
.fold(graph.len(), |a, c| validated!(a + c)); .fold(graph.len(), |a, c| validated!(a + c));
// We consider that the DAG is directed from most recent events to oldest let incoming = graph
// events, so an event is an incoming edge to its referenced events.
// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
// of events that reference it in its referenced events.
let init = (
Vec::with_capacity(max_edges),
HashMap::<OwnedEventId, ReferencedIds>::with_capacity(max_edges),
);
// Populate the list of events with an outdegree of zero, and the map of
// incoming edges.
let (zero_outdeg, incoming_edges) = graph
.iter() .iter()
.try_stream() .flat_map(|(event_id, out)| {
.try_fold( out.iter()
init, .map(move |reference| (event_id, reference))
async |(mut zero_outdeg, mut incoming_edges), (event_id, outgoing_edges)| { })
if outgoing_edges.is_empty() { .fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
// `Reverse` because `BinaryHeap` sorts largest -> smallest and we need let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
// smallest -> largest.
zero_outdeg.push(Reverse(query(event_id.clone()).await?));
}
for referenced_event_id in outgoing_edges { if !references.contains(event_id) {
let references = incoming_edges references.push(event_id.clone());
.entry(referenced_event_id.into())
.or_default();
if !references.contains(event_id) {
references.push(event_id.into());
}
}
Ok((zero_outdeg, incoming_edges))
},
)
.await?;
// Map of event to the list of events in its referenced events.
let mut outgoing_edges_map = graph.clone();
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
let mut heap = BinaryHeap::from(zero_outdeg);
let mut sorted = Vec::with_capacity(max_edges);
// Apply Kahn's algorithm.
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
while let Some(Reverse(item)) = heap.pop() {
for parent_id in incoming_edges
.get(&item.event_id)
.unwrap_or(&ReferencedIds::new())
{
let outgoing_edges = outgoing_edges_map
.get_mut(parent_id)
.expect("outgoing_edges should contain all event_ids");
outgoing_edges.retain(is_not_equal_to!(&item.event_id));
if !outgoing_edges.is_empty() {
continue;
} }
// Push on the heap once all the outgoing edges have been removed. incoming
heap.push(Reverse(query(parent_id.clone()).await?)); });
}
sorted.push(item.event_id); let horizon = graph
} .iter()
.filter(|(_, references)| references.is_empty())
.try_stream()
.and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
.try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
.await?;
Ok(sorted) kahn_sort(horizon, graph.clone(), &incoming, &query)
.try_collect()
.await
}
// Apply Kahn's algorithm.
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
#[tracing::instrument(
level = "debug",
skip_all,
fields(
heap = %heap.len(),
graph = %graph.len(),
)
)]
fn kahn_sort<Query, Fut>(
heap: BinaryHeap<Reverse<TieBreaker>>,
graph: HashMap<OwnedEventId, ReferencedIds>,
incoming: &HashMap<OwnedEventId, ReferencedIds>,
query: &Query,
) -> impl Stream<Item = Result<OwnedEventId>> + Send
where
Query: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<TieBreaker>> + Send,
{
try_unfold((heap, graph), move |(mut heap, graph)| async move {
let Some(Reverse(item)) = heap.pop() else {
return Ok(None);
};
let references = incoming.get(&item.event_id).cloned();
let state = (item.event_id, (heap, graph));
references
.into_iter()
.flatten()
.try_stream()
.try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
let out = graph
.get_mut(&parent_id)
.expect("contains all parent_ids");
out.retain(is_not_equal_to!(&event_id));
// Push on the heap once all the outgoing edges have been removed.
if out.is_empty() {
heap.push(Reverse(query(parent_id.clone()).await?));
}
Ok::<_, Error>((event_id, (heap, graph)))
})
.map_ok(Some)
.await
})
} }