Split topological_sort; semi try_unfold refactor.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -27,7 +27,7 @@ use std::{
|
|||||||
collections::{BinaryHeap, HashMap},
|
collections::{BinaryHeap, HashMap},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures::TryStreamExt;
|
use futures::{Stream, TryFutureExt, TryStreamExt, stream::try_unfold};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
|
MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
|
||||||
};
|
};
|
||||||
@@ -40,9 +40,9 @@ type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
|
|||||||
|
|
||||||
#[derive(PartialEq, Eq)]
|
#[derive(PartialEq, Eq)]
|
||||||
struct TieBreaker {
|
struct TieBreaker {
|
||||||
|
event_id: OwnedEventId,
|
||||||
power_level: UserPowerLevel,
|
power_level: UserPowerLevel,
|
||||||
origin_server_ts: MilliSecondsSinceUnixEpoch,
|
origin_server_ts: MilliSecondsSinceUnixEpoch,
|
||||||
event_id: OwnedEventId,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: the power level comparison is "backwards" intentionally.
|
// NOTE: the power level comparison is "backwards" intentionally.
|
||||||
@@ -73,6 +73,12 @@ impl PartialOrd for TieBreaker {
|
|||||||
/// ## Returns
|
/// ## Returns
|
||||||
///
|
///
|
||||||
/// Returns the ordered list of event IDs from earliest to latest.
|
/// Returns the ordered list of event IDs from earliest to latest.
|
||||||
|
///
|
||||||
|
/// We consider that the DAG is directed from most recent events to oldest
|
||||||
|
/// events, so an event is an incoming edge to its referenced events.
|
||||||
|
/// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
|
||||||
|
/// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
|
||||||
|
/// of events that reference it in its referenced events.
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
level = "debug",
|
level = "debug",
|
||||||
skip_all,
|
skip_all,
|
||||||
@@ -80,7 +86,7 @@ impl PartialOrd for TieBreaker {
|
|||||||
graph = graph.len(),
|
graph = graph.len(),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[expect(clippy::implicit_hasher, clippy::or_fun_call)]
|
#[expect(clippy::implicit_hasher)]
|
||||||
pub async fn topological_sort<Query, Fut>(
|
pub async fn topological_sort<Query, Fut>(
|
||||||
graph: &HashMap<OwnedEventId, ReferencedIds>,
|
graph: &HashMap<OwnedEventId, ReferencedIds>,
|
||||||
query: &Query,
|
query: &Query,
|
||||||
@@ -91,7 +97,7 @@ where
|
|||||||
{
|
{
|
||||||
let query = async |event_id: OwnedEventId| {
|
let query = async |event_id: OwnedEventId| {
|
||||||
let (power_level, origin_server_ts) = query(event_id.clone()).await?;
|
let (power_level, origin_server_ts) = query(event_id.clone()).await?;
|
||||||
Ok::<_, Error>(TieBreaker { power_level, origin_server_ts, event_id })
|
Ok::<_, Error>(TieBreaker { event_id, power_level, origin_server_ts })
|
||||||
};
|
};
|
||||||
|
|
||||||
let max_edges = graph
|
let max_edges = graph
|
||||||
@@ -99,74 +105,82 @@ where
|
|||||||
.map(ReferencedIds::len)
|
.map(ReferencedIds::len)
|
||||||
.fold(graph.len(), |a, c| validated!(a + c));
|
.fold(graph.len(), |a, c| validated!(a + c));
|
||||||
|
|
||||||
// We consider that the DAG is directed from most recent events to oldest
|
let incoming = graph
|
||||||
// events, so an event is an incoming edge to its referenced events.
|
|
||||||
// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
|
|
||||||
// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
|
|
||||||
// of events that reference it in its referenced events.
|
|
||||||
let init = (
|
|
||||||
Vec::with_capacity(max_edges),
|
|
||||||
HashMap::<OwnedEventId, ReferencedIds>::with_capacity(max_edges),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Populate the list of events with an outdegree of zero, and the map of
|
|
||||||
// incoming edges.
|
|
||||||
let (zero_outdeg, incoming_edges) = graph
|
|
||||||
.iter()
|
.iter()
|
||||||
.try_stream()
|
.flat_map(|(event_id, out)| {
|
||||||
.try_fold(
|
out.iter()
|
||||||
init,
|
.map(move |reference| (event_id, reference))
|
||||||
async |(mut zero_outdeg, mut incoming_edges), (event_id, outgoing_edges)| {
|
})
|
||||||
if outgoing_edges.is_empty() {
|
.fold(HashMap::with_capacity(max_edges), |mut incoming, (event_id, reference)| {
|
||||||
// `Reverse` because `BinaryHeap` sorts largest -> smallest and we need
|
let references: &mut ReferencedIds = incoming.entry(reference.clone()).or_default();
|
||||||
// smallest -> largest.
|
|
||||||
zero_outdeg.push(Reverse(query(event_id.clone()).await?));
|
|
||||||
}
|
|
||||||
|
|
||||||
for referenced_event_id in outgoing_edges {
|
if !references.contains(event_id) {
|
||||||
let references = incoming_edges
|
references.push(event_id.clone());
|
||||||
.entry(referenced_event_id.into())
|
|
||||||
.or_default();
|
|
||||||
|
|
||||||
if !references.contains(event_id) {
|
|
||||||
references.push(event_id.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok((zero_outdeg, incoming_edges))
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Map of event to the list of events in its referenced events.
|
|
||||||
let mut outgoing_edges_map = graph.clone();
|
|
||||||
|
|
||||||
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
|
|
||||||
let mut heap = BinaryHeap::from(zero_outdeg);
|
|
||||||
let mut sorted = Vec::with_capacity(max_edges);
|
|
||||||
|
|
||||||
// Apply Kahn's algorithm.
|
|
||||||
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
|
||||||
while let Some(Reverse(item)) = heap.pop() {
|
|
||||||
for parent_id in incoming_edges
|
|
||||||
.get(&item.event_id)
|
|
||||||
.unwrap_or(&ReferencedIds::new())
|
|
||||||
{
|
|
||||||
let outgoing_edges = outgoing_edges_map
|
|
||||||
.get_mut(parent_id)
|
|
||||||
.expect("outgoing_edges should contain all event_ids");
|
|
||||||
|
|
||||||
outgoing_edges.retain(is_not_equal_to!(&item.event_id));
|
|
||||||
if !outgoing_edges.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Push on the heap once all the outgoing edges have been removed.
|
incoming
|
||||||
heap.push(Reverse(query(parent_id.clone()).await?));
|
});
|
||||||
}
|
|
||||||
|
|
||||||
sorted.push(item.event_id);
|
let horizon = graph
|
||||||
}
|
.iter()
|
||||||
|
.filter(|(_, references)| references.is_empty())
|
||||||
|
.try_stream()
|
||||||
|
.and_then(async |(event_id, _)| Ok(Reverse(query(event_id.clone()).await?)))
|
||||||
|
.try_collect::<BinaryHeap<Reverse<TieBreaker>>>()
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(sorted)
|
kahn_sort(horizon, graph.clone(), &incoming, &query)
|
||||||
|
.try_collect()
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply Kahn's algorithm.
|
||||||
|
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
||||||
|
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
|
||||||
|
#[tracing::instrument(
|
||||||
|
level = "debug",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
heap = %heap.len(),
|
||||||
|
graph = %graph.len(),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
fn kahn_sort<Query, Fut>(
|
||||||
|
heap: BinaryHeap<Reverse<TieBreaker>>,
|
||||||
|
graph: HashMap<OwnedEventId, ReferencedIds>,
|
||||||
|
incoming: &HashMap<OwnedEventId, ReferencedIds>,
|
||||||
|
query: &Query,
|
||||||
|
) -> impl Stream<Item = Result<OwnedEventId>> + Send
|
||||||
|
where
|
||||||
|
Query: Fn(OwnedEventId) -> Fut + Sync,
|
||||||
|
Fut: Future<Output = Result<TieBreaker>> + Send,
|
||||||
|
{
|
||||||
|
try_unfold((heap, graph), move |(mut heap, graph)| async move {
|
||||||
|
let Some(Reverse(item)) = heap.pop() else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let references = incoming.get(&item.event_id).cloned();
|
||||||
|
let state = (item.event_id, (heap, graph));
|
||||||
|
references
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.try_stream()
|
||||||
|
.try_fold(state, |(event_id, (mut heap, mut graph)), parent_id| async move {
|
||||||
|
let out = graph
|
||||||
|
.get_mut(&parent_id)
|
||||||
|
.expect("contains all parent_ids");
|
||||||
|
|
||||||
|
out.retain(is_not_equal_to!(&event_id));
|
||||||
|
|
||||||
|
// Push on the heap once all the outgoing edges have been removed.
|
||||||
|
if out.is_empty() {
|
||||||
|
heap.push(Reverse(query(parent_id.clone()).await?));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok::<_, Error>((event_id, (heap, graph)))
|
||||||
|
})
|
||||||
|
.map_ok(Some)
|
||||||
|
.await
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user