Split topological_sort; semi try_unfold refactor.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-02-25 19:53:48 +00:00
parent 357a5b7a74
commit 63b0014f8f

View File

@@ -27,7 +27,7 @@ use std::{
collections::{BinaryHeap, HashMap},
};
use futures::TryStreamExt;
use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
use ruma::{
MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
};
@@ -40,9 +40,9 @@ type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
#[derive(PartialEq, Eq)]
struct TieBreaker {
event_id: OwnedEventId,
power_level: UserPowerLevel,
origin_server_ts: MilliSecondsSinceUnixEpoch,
event_id: OwnedEventId,
}
// NOTE: the power level comparison is "backwards" intentionally.
@@ -73,6 +73,12 @@ impl PartialOrd for TieBreaker {
/// ## Returns
///
/// Returns the ordered list of event IDs from earliest to latest.
///
/// We consider that the DAG is directed from most recent events to oldest
/// events, so an event is an incoming edge to its referenced events. The
/// events with an outdegree of zero (no outgoing edges) are the oldest
/// events; the incoming-edges map records, for each event, the list of
/// events that reference it.
#[tracing::instrument(
level = "debug",
skip_all,
@@ -80,7 +86,7 @@ impl PartialOrd for TieBreaker {
graph = graph.len(),
)
)]
#[expect(clippy::implicit_hasher, clippy::or_fun_call)]
#[expect(clippy::implicit_hasher)]
pub async fn topological_sort<Query, Fut>(
graph: &HashMap<OwnedEventId, ReferencedIds>,
query: &Query,
@@ -91,7 +97,7 @@ where
{
let query = async |event_id: OwnedEventId| {
let (power_level, origin_server_ts) = query(event_id.clone()).await?;
Ok::<_, Error>(TieBreaker { power_level, origin_server_ts, event_id })
Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
};
let max_edges = graph
@@ -99,74 +105,82 @@ where
.map(ReferencedIds::len)
.fold(graph.len(), |a, c| validated!(a + c));
// We consider that the DAG is directed from most recent events to oldest
// events, so an event is an incoming edge to its referenced events.
// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
// of events that reference it in its referenced events.
let init = (
Vec::with_capacity(max_edges),
HashMap::<OwnedEventId, ReferencedIds>::with_capacity(max_edges),
);
// Populate the list of events with an outdegree of zero, and the map of
// incoming edges.
let (zero_outdeg, incoming_edges) = graph
let incoming = graph
.iter()
.try_stream()
.try_fold(
init,
async |(mut zero_outdeg, mut incoming_edges), (event_id, outgoing_edges)| {
if outgoing_edges.is_empty() {
// `Reverse` because `BinaryHeap` sorts largest -> smallest and we need
// smallest -> largest.
zero_outdeg.push(Reverse(query(event_id.clone()).await?));
}
.flat_map(|(event_id, out)| {
out.iter()
.map(move |reference| (event_id, reference))
})
.fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
for referenced_event_id in outgoing_edges {
let references = incoming_edges
.entry(referenced_event_id.into())
.or_default();
if !references.contains(event_id) {
references.push(event_id.into());
}
}
Ok((zero_outdeg, incoming_edges))
},
)
.await?;
// Map of event to the list of events in its referenced events.
let mut outgoing_edges_map = graph.clone();
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
let mut heap = BinaryHeap::from(zero_outdeg);
let mut sorted = Vec::with_capacity(max_edges);
// Apply Kahn's algorithm.
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
while let Some(Reverse(item)) = heap.pop() {
for parent_id in incoming_edges
.get(&item.event_id)
.unwrap_or(&ReferencedIds::new())
{
let outgoing_edges = outgoing_edges_map
.get_mut(parent_id)
.expect("outgoing_edges should contain all event_ids");
outgoing_edges.retain(is_not_equal_to!(&item.event_id));
if !outgoing_edges.is_empty() {
continue;
if !references.contains(event_id) {
references.push(event_id.clone());
}
// Push on the heap once all the outgoing edges have been removed.
heap.push(Reverse(query(parent_id.clone()).await?));
}
incoming
});
sorted.push(item.event_id);
}
let horizon = graph
.iter()
.filter(|(_, references)| references.is_empty())
.try_stream()
.and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
.try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
.await?;
Ok(sorted)
kahn_sort(horizon, graph.clone(), &incoming, &query)
.try_collect()
.await
}
// Apply Kahn's algorithm.
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
#[tracing::instrument(
level = "debug",
skip_all,
fields(
heap = %heap.len(),
graph = %graph.len(),
)
)]
// Lazily yields event IDs in topological order as a fallible stream.
//
// `heap`: initial frontier — events with zero outgoing edges, wrapped in
//     `Reverse` so the max-heap pops the smallest `TieBreaker` first.
// `graph`: owned copy of the outgoing-edges map; edges are removed in place
//     as events are emitted.
// `incoming`: map from an event to the events that reference it (computed by
//     the caller from `graph`).
// `query`: async lookup producing the `TieBreaker` ordering key for an event.
fn kahn_sort<Query, Fut>(
heap: BinaryHeap<Reverse<TieBreaker>>,
graph: HashMap<OwnedEventId, ReferencedIds>,
incoming: &HashMap<OwnedEventId, ReferencedIds>,
query: &Query,
) -> impl Stream<Item = Result<OwnedEventId>> + Send
where
Query: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<TieBreaker>> + Send,
{
// Each unfold step emits one event and threads (heap, graph) forward.
try_unfold((heap, graph), move |(mut heap, graph)| async move {
// Pop the smallest ready event; an empty heap ends the stream.
let Some(Reverse(item)) = heap.pop() else {
return Ok(None);
};
// Events that reference the emitted event; cloned so the fold below can
// consume them while `incoming` stays borrowed. NOTE(review): events
// absent from `incoming` simply yield no parents here.
let references = incoming.get(&item.event_id).cloned();
let state = (item.event_id, (heap, graph));
// Remove the edge from each parent to the emitted event; a parent whose
// outgoing set becomes empty is ready and joins the heap.
references
.into_iter()
.flatten()
.try_stream()
.try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
let out = graph
.get_mut(&parent_id)
.expect("contains all parent_ids");
out.retain(is_not_equal_to!(&event_id));
// Push on the heap once all the outgoing edges have been removed.
if out.is_empty() {
heap.push(Reverse(query(parent_id.clone()).await?));
}
Ok::<_, Error>((event_id, (heap, graph)))
})
// `Some((emitted_id, next_state))` continues the unfold.
.map_ok(Some)
.await
})
}