Split topological_sort; semi try_unfold refactor.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -27,7 +27,7 @@ use std::{
|
||||
collections::{BinaryHeap, HashMap},
|
||||
};
|
||||
|
||||
use futures::TryStreamExt;
|
||||
use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
|
||||
use ruma::{
|
||||
MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
|
||||
};
|
||||
@@ -40,9 +40,9 @@ type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
struct TieBreaker {
|
||||
event_id: OwnedEventId,
|
||||
power_level: UserPowerLevel,
|
||||
origin_server_ts: MilliSecondsSinceUnixEpoch,
|
||||
event_id: OwnedEventId,
|
||||
}
|
||||
|
||||
// NOTE: the power level comparison is "backwards" intentionally.
|
||||
@@ -73,6 +73,12 @@ impl PartialOrd for TieBreaker {
|
||||
/// ## Returns
|
||||
///
|
||||
/// Returns the ordered list of event IDs from earliest to latest.
|
||||
///
|
||||
/// We consider that the DAG is directed from most recent events to oldest
|
||||
/// events, so an event is an incoming edge to its referenced events.
|
||||
/// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
|
||||
/// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
|
||||
/// of events that reference it in its referenced events.
|
||||
#[tracing::instrument(
|
||||
level = "debug",
|
||||
skip_all,
|
||||
@@ -80,7 +86,7 @@ impl PartialOrd for TieBreaker {
|
||||
graph = graph.len(),
|
||||
)
|
||||
)]
|
||||
#[expect(clippy::implicit_hasher, clippy::or_fun_call)]
|
||||
#[expect(clippy::implicit_hasher)]
|
||||
pub async fn topological_sort<Query, Fut>(
|
||||
graph: &HashMap<OwnedEventId, ReferencedIds>,
|
||||
query: &Query,
|
||||
@@ -91,7 +97,7 @@ where
|
||||
{
|
||||
let query = async |event_id: OwnedEventId| {
|
||||
let (power_level, origin_server_ts) = query(event_id.clone()).await?;
|
||||
Ok::<_, Error>(TieBreaker { power_level, origin_server_ts, event_id })
|
||||
Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
|
||||
};
|
||||
|
||||
let max_edges = graph
|
||||
@@ -99,74 +105,82 @@ where
|
||||
.map(ReferencedIds::len)
|
||||
.fold(graph.len(), |a, c| validated!(a + c));
|
||||
|
||||
// We consider that the DAG is directed from most recent events to oldest
|
||||
// events, so an event is an incoming edge to its referenced events.
|
||||
// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
|
||||
// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
|
||||
// of events that reference it in its referenced events.
|
||||
let init = (
|
||||
Vec::with_capacity(max_edges),
|
||||
HashMap::<OwnedEventId, ReferencedIds>::with_capacity(max_edges),
|
||||
);
|
||||
|
||||
// Populate the list of events with an outdegree of zero, and the map of
|
||||
// incoming edges.
|
||||
let (zero_outdeg, incoming_edges) = graph
|
||||
let incoming = graph
|
||||
.iter()
|
||||
.try_stream()
|
||||
.try_fold(
|
||||
init,
|
||||
async |(mut zero_outdeg, mut incoming_edges), (event_id, outgoing_edges)| {
|
||||
if outgoing_edges.is_empty() {
|
||||
// `Reverse` because `BinaryHeap` sorts largest -> smallest and we need
|
||||
// smallest -> largest.
|
||||
zero_outdeg.push(Reverse(query(event_id.clone()).await?));
|
||||
}
|
||||
.flat_map(|(event_id, out)| {
|
||||
out.iter()
|
||||
.map(move |reference| (event_id, reference))
|
||||
})
|
||||
.fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
|
||||
let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
|
||||
|
||||
for referenced_event_id in outgoing_edges {
|
||||
let references = incoming_edges
|
||||
.entry(referenced_event_id.into())
|
||||
.or_default();
|
||||
|
||||
if !references.contains(event_id) {
|
||||
references.push(event_id.into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok((zero_outdeg, incoming_edges))
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Map of event to the list of events in its referenced events.
|
||||
let mut outgoing_edges_map = graph.clone();
|
||||
|
||||
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
|
||||
let mut heap = BinaryHeap::from(zero_outdeg);
|
||||
let mut sorted = Vec::with_capacity(max_edges);
|
||||
|
||||
// Apply Kahn's algorithm.
|
||||
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
||||
while let Some(Reverse(item)) = heap.pop() {
|
||||
for parent_id in incoming_edges
|
||||
.get(&item.event_id)
|
||||
.unwrap_or(&ReferencedIds::new())
|
||||
{
|
||||
let outgoing_edges = outgoing_edges_map
|
||||
.get_mut(parent_id)
|
||||
.expect("outgoing_edges should contain all event_ids");
|
||||
|
||||
outgoing_edges.retain(is_not_equal_to!(&item.event_id));
|
||||
if !outgoing_edges.is_empty() {
|
||||
continue;
|
||||
if !references.contains(event_id) {
|
||||
references.push(event_id.clone());
|
||||
}
|
||||
|
||||
// Push on the heap once all the outgoing edges have been removed.
|
||||
heap.push(Reverse(query(parent_id.clone()).await?));
|
||||
}
|
||||
incoming
|
||||
});
|
||||
|
||||
sorted.push(item.event_id);
|
||||
}
|
||||
let horizon = graph
|
||||
.iter()
|
||||
.filter(|(_, references)| references.is_empty())
|
||||
.try_stream()
|
||||
.and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
|
||||
.try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
|
||||
.await?;
|
||||
|
||||
Ok(sorted)
|
||||
kahn_sort(horizon, graph.clone(), &incoming, &query)
|
||||
.try_collect()
|
||||
.await
|
||||
}
|
||||
|
||||
// Apply Kahn's algorithm.
|
||||
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
||||
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
|
||||
//
// Produces the sorted event ids lazily as a stream instead of collecting them
// into a `Vec` here; each stream item is one event id popped from the heap.
//
// Arguments:
// - `heap`: the initial set of ready events, each wrapped in `Reverse` so the
//   max-heap `BinaryHeap` pops the smallest `TieBreaker` first.
// - `graph`: owned map of event id to its remaining outgoing edges; entries
//   are pruned as events are emitted.
// - `incoming`: borrowed map of event id to the events that reference it;
//   only read, never modified here.
// - `query`: async callback resolving an event id to its `TieBreaker`; called
//   for every event that becomes ready after the initial heap.
//
// Errors: any error returned by `query` is propagated with `?` and becomes an
// `Err` item of the stream.
//
// Panics: if an id listed in `incoming` has no entry in `graph` (see the
// `expect` below).
//
// NOTE(review): this is a non-async fn returning a stream, so the tracing span
// presumably covers only the construction of the stream, not its polling —
// confirm against the `tracing::instrument` documentation.
#[tracing::instrument(
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
heap = %heap.len(),
|
||||
graph = %graph.len(),
|
||||
)
|
||||
)]
|
||||
fn kahn_sort<Query, Fut>(
|
||||
heap: BinaryHeap<Reverse<TieBreaker>>,
|
||||
graph: HashMap<OwnedEventId, ReferencedIds>,
|
||||
incoming: &HashMap<OwnedEventId, ReferencedIds>,
|
||||
query: &Query,
|
||||
) -> impl Stream<Item = Result<OwnedEventId>> + Send
|
||||
where
|
||||
Query: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<TieBreaker>> + Send,
|
||||
{
|
||||
// The unfold state is `(heap, graph)`; each step consumes it and hands back the
// updated pair together with the event id to yield.
try_unfold((heap, graph), move |(mut heap, graph)| async move {
|
||||
// An empty heap means no ready events remain: returning `None` ends the stream.
let Some(Reverse(item)) = heap.pop() else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Events referencing the popped item; cloned so the fold below owns its input.
// A missing entry yields `None`, which flattens to zero parents.
let references = incoming.get(&item.event_id).cloned();
|
||||
// Fold accumulator: the id being emitted plus the unfold state threaded through.
let state = (item.event_id, (heap, graph));
|
||||
references
|
||||
.into_iter()
|
||||
.flatten()
|
||||
// NOTE(review): `try_stream` is not defined in this view; presumably a project
// extension adapting an iterator into a fallible stream — confirm.
.try_stream()
|
||||
.try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
|
||||
let out = graph
|
||||
.get_mut(&parent_id)
|
||||
.expect("contains all parent_ids");
|
||||
|
||||
// Drop the edge from this parent to the emitted event.
// NOTE(review): `is_not_equal_to!` is defined elsewhere; presumably it expands to
// a predicate keeping elements that differ from `event_id` — confirm.
out.retain(is_not_equal_to!(&event_id));
|
||||
|
||||
// Push on the heap once all the outgoing edges have been removed.
|
||||
if out.is_empty() {
|
||||
heap.push(Reverse(query(parent_id.clone()).await?));
|
||||
}
|
||||
|
||||
Ok::<_, Error>((event_id, (heap, graph)))
|
||||
})
|
||||
// Wrap the folded `(event_id, state)` in `Some` so `try_unfold` yields the id
// and continues with the updated state.
.map_ok(Some)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user