Optimize reference graph container value type for topological_sort.
Optimize initial container capacity estimates. Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -53,9 +53,9 @@ type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
fields(
|
||||
%client,
|
||||
origin = body.origin().as_str(),
|
||||
txn = str_truncated(body.transaction_id.as_str(), 20),
|
||||
origin = body.origin().as_str(),
|
||||
%client,
|
||||
),
|
||||
)]
|
||||
pub(crate) async fn send_transaction_message_route(
|
||||
@@ -189,7 +189,7 @@ async fn handle_pdus(
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "room",
|
||||
level = "debug",
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
fields(%room_id)
|
||||
)]
|
||||
|
||||
@@ -41,16 +41,20 @@ criterion_main!(benches);
|
||||
|
||||
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[expect(
|
||||
clippy::iter_on_single_items,
|
||||
clippy::iter_on_empty_collections
|
||||
)]
|
||||
fn lexico_topo_sort(c: &mut Criterion) {
|
||||
c.bench_function("lexico_topo_sort", |c| {
|
||||
use maplit::{hashmap, hashset};
|
||||
use maplit::hashmap;
|
||||
|
||||
let graph = hashmap! {
|
||||
event_id("l") => hashset![event_id("o")],
|
||||
event_id("m") => hashset![event_id("n"), event_id("o")],
|
||||
event_id("n") => hashset![event_id("o")],
|
||||
event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
|
||||
event_id("p") => hashset![event_id("o")],
|
||||
event_id("l") => [event_id("o")].into_iter().collect(),
|
||||
event_id("m") => [event_id("n"), event_id("o")].into_iter().collect(),
|
||||
event_id("n") => [event_id("o")].into_iter().collect(),
|
||||
event_id("o") => [].into_iter().collect(), // "o" has zero outgoing edges but 4 incoming edges
|
||||
event_id("p") => [event_id("o")].into_iter().collect(),
|
||||
};
|
||||
|
||||
c.to_async(FuturesExecutor).iter(async || {
|
||||
|
||||
@@ -65,7 +65,7 @@ mod fetch_state;
|
||||
mod resolve;
|
||||
#[cfg(test)]
|
||||
mod test_utils;
|
||||
mod topological_sort;
|
||||
pub mod topological_sort;
|
||||
|
||||
use self::{event_auth::check_state_dependent_auth_rules, fetch_state::FetchStateExt};
|
||||
pub use self::{
|
||||
|
||||
@@ -8,7 +8,10 @@ mod mainline_sort;
|
||||
mod power_sort;
|
||||
mod split_conflicted;
|
||||
|
||||
use std::collections::{BTreeMap, BTreeSet, HashSet};
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet, HashSet},
|
||||
ops::Deref,
|
||||
};
|
||||
|
||||
use futures::{FutureExt, Stream, StreamExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
@@ -28,7 +31,6 @@ use crate::{
|
||||
trace,
|
||||
utils::{
|
||||
BoolExt,
|
||||
option::OptionExt,
|
||||
stream::{BroadbandExt, IterStream},
|
||||
},
|
||||
};
|
||||
@@ -242,20 +244,25 @@ where
|
||||
.is_some_and(|rules| rules.consider_conflicted_state_subgraph)
|
||||
|| backport_css;
|
||||
|
||||
let conflicted_states = conflicted_states.values().flatten().cloned();
|
||||
let conflicted_state_set: HashSet<_> = conflicted_states.values().flatten().collect();
|
||||
|
||||
// Since `org.matrix.hydra.11`, fetch the conflicted state subgraph.
|
||||
let conflicted_subgraph = consider_conflicted_subgraph
|
||||
.then(|| conflicted_states.clone().stream())
|
||||
.map_async(async |ids| conflicted_subgraph_dfs(ids, fetch))
|
||||
.then_async(async || conflicted_subgraph_dfs(&conflicted_state_set, fetch))
|
||||
.map(Option::into_iter)
|
||||
.map(IterStream::stream)
|
||||
.flatten_stream()
|
||||
.flatten()
|
||||
.boxed();
|
||||
|
||||
let conflicted_state_ids = conflicted_state_set
|
||||
.iter()
|
||||
.map(Deref::deref)
|
||||
.cloned()
|
||||
.stream();
|
||||
|
||||
auth_difference(auth_sets)
|
||||
.chain(conflicted_states.stream())
|
||||
.chain(conflicted_state_ids)
|
||||
.broad_filter_map(async |id| exists(id.clone()).await.then_some(id))
|
||||
.chain(conflicted_subgraph)
|
||||
.collect::<HashSet<_>>()
|
||||
|
||||
@@ -35,32 +35,28 @@ const PATH_INLINE: usize = 48;
|
||||
const STACK_INLINE: usize = 48;
|
||||
|
||||
#[tracing::instrument(name = "subgraph_dfs", level = "debug", skip_all)]
|
||||
pub(super) fn conflicted_subgraph_dfs<ConflictedEventIds, Fetch, Fut, Pdu>(
|
||||
conflicted_event_ids: ConflictedEventIds,
|
||||
pub(super) fn conflicted_subgraph_dfs<Fetch, Fut, Pdu>(
|
||||
conflicted_event_ids: &Set<&OwnedEventId>,
|
||||
fetch: &Fetch,
|
||||
) -> impl Stream<Item = OwnedEventId> + Send
|
||||
where
|
||||
ConflictedEventIds: Stream<Item = OwnedEventId> + Send,
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
conflicted_event_ids
|
||||
.collect::<Set<_>>()
|
||||
.map(|ids| (Arc::new(Global::default()), ids))
|
||||
.then(async |(state, conflicted_event_ids)| {
|
||||
let state = Arc::new(Global::default());
|
||||
let state_ = state.clone();
|
||||
conflicted_event_ids
|
||||
.iter()
|
||||
.stream()
|
||||
.map(ToOwned::to_owned)
|
||||
.map(|event_id| (state.clone(), event_id))
|
||||
.for_each_concurrent(automatic_width(), async |(state, event_id)| {
|
||||
subgraph_descent(state, event_id, &conflicted_event_ids, fetch)
|
||||
.enumerate()
|
||||
.map(move |(i, event_id)| (state_.clone(), event_id, i))
|
||||
.for_each_concurrent(automatic_width(), async |(state, event_id, i)| {
|
||||
subgraph_descent(conflicted_event_ids, fetch, &state, event_id, i)
|
||||
.await
|
||||
.expect("only mutex errors expected");
|
||||
})
|
||||
.await;
|
||||
|
||||
.map(move |()| {
|
||||
let seen = state.seen.lock().expect("locked");
|
||||
let mut state = state.subgraph.lock().expect("locked");
|
||||
debug!(
|
||||
@@ -70,10 +66,8 @@ where
|
||||
"conflicted subgraph state"
|
||||
);
|
||||
|
||||
take(&mut *state)
|
||||
take(&mut *state).into_iter().stream()
|
||||
})
|
||||
.map(Set::into_iter)
|
||||
.map(IterStream::stream)
|
||||
.flatten_stream()
|
||||
}
|
||||
|
||||
@@ -84,24 +78,26 @@ where
|
||||
fields(
|
||||
event_ids = conflicted_event_ids.len(),
|
||||
event_id = %conflicted_event_id,
|
||||
%i,
|
||||
)
|
||||
)]
|
||||
async fn subgraph_descent<Fetch, Fut, Pdu>(
|
||||
state: Arc<Global>,
|
||||
conflicted_event_id: OwnedEventId,
|
||||
conflicted_event_ids: &Set<OwnedEventId>,
|
||||
conflicted_event_ids: &Set<&OwnedEventId>,
|
||||
fetch: &Fetch,
|
||||
) -> Result<Arc<Global>>
|
||||
state: &Arc<Global>,
|
||||
conflicted_event_id: &OwnedEventId,
|
||||
i: usize,
|
||||
) -> Result
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
let Global { subgraph, seen } = &*state;
|
||||
let Global { subgraph, seen } = &**state;
|
||||
|
||||
let mut local = Local {
|
||||
path: once(conflicted_event_id.clone()).collect(),
|
||||
stack: once(once(conflicted_event_id).collect()).collect(),
|
||||
stack: once(once(conflicted_event_id.clone()).collect()).collect(),
|
||||
};
|
||||
|
||||
while let Some(event_id) = pop(&mut local) {
|
||||
@@ -133,7 +129,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
Ok(state)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn pop(local: &mut Local) -> Option<OwnedEventId> {
|
||||
|
||||
@@ -193,7 +193,12 @@ where
|
||||
// Add authentic event to the partially resolved state.
|
||||
if check_state_dependent_auth_rules(rules, &event, &fetch_state)
|
||||
.await
|
||||
.inspect_err(|e| debug_warn!("event failed auth check: {e}"))
|
||||
.inspect_err(|e| {
|
||||
debug_warn!(
|
||||
event_type = ?event.event_type(), ?state_key, %event_id,
|
||||
"event failed auth check: {e}"
|
||||
);
|
||||
})
|
||||
.is_ok()
|
||||
{
|
||||
let key = event.event_type().with_state_key(state_key);
|
||||
|
||||
@@ -16,6 +16,7 @@ use super::super::{
|
||||
power_levels::RoomPowerLevelsEventOptionExt,
|
||||
},
|
||||
topological_sort,
|
||||
topological_sort::ReferencedIds,
|
||||
};
|
||||
use crate::{
|
||||
Result, err,
|
||||
@@ -107,11 +108,11 @@ where
|
||||
)
|
||||
)]
|
||||
async fn add_event_auth_chain<Fetch, Fut, Pdu>(
|
||||
mut graph: HashMap<OwnedEventId, HashSet<OwnedEventId>>,
|
||||
mut graph: HashMap<OwnedEventId, ReferencedIds>,
|
||||
full_conflicted_set: &HashSet<OwnedEventId>,
|
||||
event_id: &EventId,
|
||||
fetch: &Fetch,
|
||||
) -> HashMap<OwnedEventId, HashSet<OwnedEventId>>
|
||||
) -> HashMap<OwnedEventId, ReferencedIds>
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
@@ -127,22 +128,17 @@ where
|
||||
// Add the current event to the graph.
|
||||
graph.entry(event_id).or_default();
|
||||
|
||||
let auth_events = event
|
||||
event
|
||||
.as_ref()
|
||||
.map(Event::auth_events)
|
||||
.into_iter()
|
||||
.flatten();
|
||||
|
||||
for auth_event_id in auth_events {
|
||||
// If the auth event ID is in the full conflicted set…
|
||||
if !full_conflicted_set.contains(auth_event_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the auth event ID is not in the graph, we need to check its auth events
|
||||
// later.
|
||||
if !graph.contains_key(auth_event_id) {
|
||||
state.push(auth_event_id.to_owned());
|
||||
.flatten()
|
||||
.map(ToOwned::to_owned)
|
||||
.filter(|auth_event_id| full_conflicted_set.contains(auth_event_id))
|
||||
.for_each(|auth_event_id| {
|
||||
// If the auth event ID is not in the graph, check its auth events later.
|
||||
if !graph.contains_key(&auth_event_id) {
|
||||
state.push(auth_event_id.clone());
|
||||
}
|
||||
|
||||
let event_id = event
|
||||
@@ -151,11 +147,14 @@ where
|
||||
.event_id();
|
||||
|
||||
// Add the auth event ID to the list of incoming edges.
|
||||
graph
|
||||
let references = graph
|
||||
.get_mut(event_id)
|
||||
.expect("event_id must be added to graph")
|
||||
.insert(auth_event_id.to_owned());
|
||||
.expect("event_id must be added to graph");
|
||||
|
||||
if !references.contains(&auth_event_id) {
|
||||
references.push(auth_event_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
graph
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{collections::HashMap, hash::Hash};
|
||||
use std::{collections::HashMap, hash::Hash, iter::IntoIterator};
|
||||
|
||||
use futures::{Stream, StreamExt};
|
||||
|
||||
@@ -35,13 +35,18 @@ where
|
||||
{
|
||||
let state_maps: Vec<_> = state_maps.collect().await;
|
||||
|
||||
let mut state_set_count = 0_usize;
|
||||
let mut occurrences = HashMap::<_, HashMap<_, usize>>::new();
|
||||
let state_maps = state_maps.iter().inspect(|_state| {
|
||||
state_set_count = validated!(state_set_count + 1);
|
||||
});
|
||||
let state_ids_est = state_maps.iter().flatten().count();
|
||||
|
||||
for (k, v) in state_maps.into_iter().flat_map(|s| s.iter()) {
|
||||
let state_set_count = state_maps
|
||||
.iter()
|
||||
.fold(0_usize, |acc, _| validated!(acc + 1));
|
||||
|
||||
let mut occurrences = HashMap::<_, HashMap<_, usize>>::with_capacity(state_ids_est);
|
||||
|
||||
for (k, v) in state_maps
|
||||
.into_iter()
|
||||
.flat_map(IntoIterator::into_iter)
|
||||
{
|
||||
let acc = occurrences
|
||||
.entry(k.clone())
|
||||
.or_default()
|
||||
@@ -59,10 +64,16 @@ where
|
||||
if occurrence_count == state_set_count {
|
||||
unconflicted_state_map.insert((k.0.clone(), k.1.clone()), id.clone());
|
||||
} else {
|
||||
conflicted_state_set
|
||||
let conflicts = conflicted_state_set
|
||||
.entry((k.0.clone(), k.1.clone()))
|
||||
.or_default()
|
||||
.push(id.clone());
|
||||
.or_default();
|
||||
|
||||
debug_assert!(
|
||||
!conflicts.contains(&id),
|
||||
"Unexpected duplicate conflicted state event"
|
||||
);
|
||||
|
||||
conflicts.push(id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
iter::once,
|
||||
};
|
||||
use std::{collections::HashMap, iter::once};
|
||||
|
||||
use futures::StreamExt;
|
||||
use maplit::{hashmap, hashset};
|
||||
use maplit::hashmap;
|
||||
use rand::seq::SliceRandom;
|
||||
use ruma::{
|
||||
MilliSecondsSinceUnixEpoch, OwnedEventId,
|
||||
@@ -42,7 +39,7 @@ async fn test_event_sort() {
|
||||
let rules = RoomVersionRules::V6;
|
||||
let events = INITIAL_EVENTS();
|
||||
|
||||
let auth_chain: HashSet<OwnedEventId> = HashSet::new();
|
||||
let auth_chain = Default::default();
|
||||
|
||||
let sorted_power_events = super::power_sort(&rules, &auth_chain, &async |id| {
|
||||
events.get(&id).cloned().ok_or_else(not_found)
|
||||
@@ -484,6 +481,10 @@ async fn test_event_map_none() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[expect(
|
||||
clippy::iter_on_single_items,
|
||||
clippy::iter_on_empty_collections
|
||||
)]
|
||||
async fn test_reverse_topological_power_sort() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
@@ -492,11 +493,11 @@ async fn test_reverse_topological_power_sort() {
|
||||
);
|
||||
|
||||
let graph = hashmap! {
|
||||
event_id("l") => hashset![event_id("o")],
|
||||
event_id("m") => hashset![event_id("n"), event_id("o")],
|
||||
event_id("n") => hashset![event_id("o")],
|
||||
event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges
|
||||
event_id("p") => hashset![event_id("o")],
|
||||
event_id("l") => [event_id("o")].into_iter().collect(),
|
||||
event_id("m") => [event_id("n"), event_id("o")].into_iter().collect(),
|
||||
event_id("n") => [event_id("o")].into_iter().collect(),
|
||||
event_id("o") => [].into_iter().collect(), // "o" has zero outgoing edges but 4 incoming edges
|
||||
event_id("p") => [event_id("o")].into_iter().collect(),
|
||||
};
|
||||
|
||||
let res = super::super::topological_sort(&graph, &async |_id| {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
collections::{HashMap, HashSet},
|
||||
collections::HashMap,
|
||||
pin::Pin,
|
||||
slice,
|
||||
sync::{
|
||||
@@ -30,7 +30,10 @@ use serde_json::{
|
||||
use super::{AuthSet, StateMap, auth_types_for_event, events::RoomCreateEvent};
|
||||
use crate::{
|
||||
Error, Result, err, info,
|
||||
matrix::{Event, EventHash, EventTypeExt, PduEvent, StateKey},
|
||||
matrix::{
|
||||
Event, EventHash, EventTypeExt, PduEvent, StateKey,
|
||||
state_res::topological_sort::ReferencedIds,
|
||||
},
|
||||
utils::stream::IterStream,
|
||||
};
|
||||
|
||||
@@ -64,14 +67,14 @@ pub(super) async fn do_check(
|
||||
);
|
||||
|
||||
// This will be lexi_topo_sorted for resolution
|
||||
let mut graph = HashMap::new();
|
||||
let mut graph = HashMap::<OwnedEventId, ReferencedIds>::new();
|
||||
// This is the same as in `resolve` event_id -> OriginalStateEvent
|
||||
let mut fake_event_map = HashMap::new();
|
||||
|
||||
// Create the DB of events that led up to this point
|
||||
// TODO maybe clean up some of these clones it is just tests but...
|
||||
for ev in init_events.values().chain(events) {
|
||||
graph.insert(ev.event_id().to_owned(), HashSet::new());
|
||||
graph.insert(ev.event_id().to_owned(), Default::default());
|
||||
fake_event_map.insert(ev.event_id().to_owned(), ev.clone());
|
||||
}
|
||||
|
||||
@@ -79,8 +82,8 @@ pub(super) async fn do_check(
|
||||
if let [a, b] = &pair {
|
||||
graph
|
||||
.entry(a.to_owned())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(b.clone());
|
||||
.or_default()
|
||||
.push(b.clone());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,8 +92,8 @@ pub(super) async fn do_check(
|
||||
if let [a, b] = &pair {
|
||||
graph
|
||||
.entry(a.to_owned())
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(b.clone());
|
||||
.or_default()
|
||||
.push(b.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,37 +24,40 @@
|
||||
|
||||
use std::{
|
||||
cmp::{Ordering, Reverse},
|
||||
collections::{BinaryHeap, HashMap, HashSet},
|
||||
collections::{BinaryHeap, HashMap},
|
||||
};
|
||||
|
||||
use futures::TryStreamExt;
|
||||
use ruma::{
|
||||
EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
|
||||
MilliSecondsSinceUnixEpoch, OwnedEventId, events::room::power_levels::UserPowerLevel,
|
||||
};
|
||||
|
||||
use crate::{Result, utils::stream::IterStream};
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
struct TieBreaker<'a> {
|
||||
power_level: UserPowerLevel,
|
||||
origin_server_ts: MilliSecondsSinceUnixEpoch,
|
||||
event_id: &'a EventId,
|
||||
}
|
||||
use crate::{
|
||||
Error, Result, is_not_equal_to, smallvec::SmallVec, utils::stream::IterStream, validated,
|
||||
};
|
||||
|
||||
pub type ReferencedIds = SmallVec<[OwnedEventId; 3]>;
|
||||
type PduInfo = (UserPowerLevel, MilliSecondsSinceUnixEpoch);
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
struct TieBreaker {
|
||||
power_level: UserPowerLevel,
|
||||
origin_server_ts: MilliSecondsSinceUnixEpoch,
|
||||
event_id: OwnedEventId,
|
||||
}
|
||||
|
||||
// NOTE: the power level comparison is "backwards" intentionally.
|
||||
impl Ord for TieBreaker<'_> {
|
||||
impl Ord for TieBreaker {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
other
|
||||
.power_level
|
||||
.cmp(&self.power_level)
|
||||
.then(self.origin_server_ts.cmp(&other.origin_server_ts))
|
||||
.then(self.event_id.cmp(other.event_id))
|
||||
.then(self.event_id.cmp(&other.event_id))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for TieBreaker<'_> {
|
||||
impl PartialOrd for TieBreaker {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
|
||||
}
|
||||
|
||||
@@ -78,21 +81,34 @@ impl PartialOrd for TieBreaker<'_> {
|
||||
graph = graph.len(),
|
||||
)
|
||||
)]
|
||||
#[expect(clippy::implicit_hasher)]
|
||||
#[expect(clippy::implicit_hasher, clippy::or_fun_call)]
|
||||
pub async fn topological_sort<Query, Fut>(
|
||||
graph: &HashMap<OwnedEventId, HashSet<OwnedEventId>>,
|
||||
graph: &HashMap<OwnedEventId, ReferencedIds>,
|
||||
query: &Query,
|
||||
) -> Result<Vec<OwnedEventId>>
|
||||
where
|
||||
Query: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<PduInfo>> + Send,
|
||||
{
|
||||
let query = async |event_id: OwnedEventId| {
|
||||
let (power_level, origin_server_ts) = query(event_id.clone()).await?;
|
||||
Ok::<_, Error>(TieBreaker { power_level, origin_server_ts, event_id })
|
||||
};
|
||||
|
||||
let max_edges = graph
|
||||
.values()
|
||||
.map(ReferencedIds::len)
|
||||
.fold(graph.len(), |a, c| validated!(a + c));
|
||||
|
||||
// We consider that the DAG is directed from most recent events to oldest
|
||||
// events, so an event is an incoming edge to its referenced events.
|
||||
// zero_outdegs: Vec of events that have an outdegree of zero (no outgoing
|
||||
// edges), i.e. the oldest events. incoming_edges_map: Map of event to the list
|
||||
// of events that reference it in its referenced events.
|
||||
let init = (Vec::new(), HashMap::<OwnedEventId, HashSet<OwnedEventId>>::new());
|
||||
let init = (
|
||||
Vec::with_capacity(max_edges),
|
||||
HashMap::<OwnedEventId, ReferencedIds>::with_capacity(max_edges),
|
||||
);
|
||||
|
||||
// Populate the list of events with an outdegree of zero, and the map of
|
||||
// incoming edges.
|
||||
@@ -103,24 +119,19 @@ where
|
||||
init,
|
||||
async |(mut zero_outdeg, mut incoming_edges), (event_id, outgoing_edges)| {
|
||||
if outgoing_edges.is_empty() {
|
||||
let (power_level, origin_server_ts) = query(event_id.clone()).await?;
|
||||
|
||||
// `Reverse` because `BinaryHeap` sorts largest -> smallest and we need
|
||||
// smallest -> largest.
|
||||
zero_outdeg.push(Reverse(TieBreaker {
|
||||
power_level,
|
||||
origin_server_ts,
|
||||
event_id,
|
||||
}));
|
||||
zero_outdeg.push(Reverse(query(event_id.clone()).await?));
|
||||
}
|
||||
|
||||
incoming_edges.entry(event_id.into()).or_default();
|
||||
|
||||
for referenced_event_id in outgoing_edges {
|
||||
incoming_edges
|
||||
let references = incoming_edges
|
||||
.entry(referenced_event_id.into())
|
||||
.or_default()
|
||||
.insert(event_id.into());
|
||||
.or_default();
|
||||
|
||||
if !references.contains(event_id) {
|
||||
references.push(event_id.into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok((zero_outdeg, incoming_edges))
|
||||
@@ -133,34 +144,29 @@ where
|
||||
|
||||
// Use a BinaryHeap to keep the events with an outdegree of zero sorted.
|
||||
let mut heap = BinaryHeap::from(zero_outdeg);
|
||||
let mut sorted = vec![];
|
||||
let mut sorted = Vec::with_capacity(max_edges);
|
||||
|
||||
// Apply Kahn's algorithm.
|
||||
// https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
||||
while let Some(Reverse(item)) = heap.pop() {
|
||||
for parent_id in incoming_edges
|
||||
.get(item.event_id)
|
||||
.expect("event_id in heap should also be in incoming_edges")
|
||||
.get(&item.event_id)
|
||||
.unwrap_or(&ReferencedIds::new())
|
||||
{
|
||||
let outgoing_edges = outgoing_edges_map
|
||||
.get_mut(parent_id)
|
||||
.expect("outgoing_edges should contain all event_ids");
|
||||
|
||||
outgoing_edges.remove(item.event_id);
|
||||
outgoing_edges.retain(is_not_equal_to!(&item.event_id));
|
||||
if !outgoing_edges.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Push on the heap once all the outgoing edges have been removed.
|
||||
let (power_level, origin_server_ts) = query(parent_id.clone()).await?;
|
||||
heap.push(Reverse(TieBreaker {
|
||||
power_level,
|
||||
origin_server_ts,
|
||||
event_id: parent_id.as_ref(),
|
||||
}));
|
||||
heap.push(Reverse(query(parent_id.clone()).await?));
|
||||
}
|
||||
|
||||
sorted.push(item.event_id.into());
|
||||
sorted.push(item.event_id);
|
||||
}
|
||||
|
||||
Ok(sorted)
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
iter::once,
|
||||
};
|
||||
use std::{collections::HashMap, iter::once};
|
||||
|
||||
use futures::{FutureExt, StreamExt, stream::FuturesOrdered};
|
||||
use ruma::{
|
||||
@@ -50,7 +47,7 @@ where
|
||||
while let Some((prev_event_id, mut outlier)) = todo_outlier_stack.next().await {
|
||||
let Some((pdu, mut json_opt)) = outlier.pop() else {
|
||||
// Fetch and handle failed
|
||||
graph.insert(prev_event_id.clone(), HashSet::new());
|
||||
graph.insert(prev_event_id.clone(), Default::default());
|
||||
continue;
|
||||
};
|
||||
|
||||
@@ -59,7 +56,7 @@ where
|
||||
let limit = self.services.server.config.max_fetch_prev_events;
|
||||
if amount > limit {
|
||||
debug_warn!(?limit, "Max prev event limit reached!");
|
||||
graph.insert(prev_event_id.clone(), HashSet::new());
|
||||
graph.insert(prev_event_id.clone(), Default::default());
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -74,7 +71,7 @@ where
|
||||
|
||||
let Some(json) = json_opt else {
|
||||
// Get json failed, so this was not fetched over federation
|
||||
graph.insert(prev_event_id.clone(), HashSet::new());
|
||||
graph.insert(prev_event_id.clone(), Default::default());
|
||||
continue;
|
||||
};
|
||||
|
||||
@@ -104,7 +101,7 @@ where
|
||||
);
|
||||
} else {
|
||||
// Time based check failed
|
||||
graph.insert(prev_event_id.clone(), HashSet::new());
|
||||
graph.insert(prev_event_id.clone(), Default::default());
|
||||
}
|
||||
|
||||
eventid_info.insert(prev_event_id.clone(), (pdu, json));
|
||||
|
||||
@@ -110,11 +110,24 @@ fn is_backed_off(&self, event_id: &EventId, range: Range<Duration>) -> bool {
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(
|
||||
name = "exists",
|
||||
level = "trace",
|
||||
ret(level = "trace"),
|
||||
skip_all,
|
||||
fields(%event_id)
|
||||
)]
|
||||
async fn event_exists(&self, event_id: &EventId) -> bool {
|
||||
self.services.timeline.pdu_exists(event_id).await
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(
|
||||
name = "fetch",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(%event_id)
|
||||
)]
|
||||
async fn event_fetch(&self, event_id: &EventId) -> Result<PduEvent> {
|
||||
self.services.timeline.get_pdu(event_id).await
|
||||
}
|
||||
|
||||
@@ -103,15 +103,10 @@ pub async fn resolve_state(
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
#[tracing::instrument(
|
||||
name = "resolve",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(%room_id),
|
||||
)]
|
||||
#[tracing::instrument(name = "resolve", level = "debug", skip_all)]
|
||||
pub(super) async fn state_resolution<StateSets, AuthSets>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
_room_id: &RoomId,
|
||||
room_version: &RoomVersionId,
|
||||
state_sets: StateSets,
|
||||
auth_chains: AuthSets,
|
||||
|
||||
@@ -22,8 +22,8 @@ use crate::rooms::{
|
||||
#[tracing::instrument(
|
||||
name = "upgrade",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
ret(level = "debug")
|
||||
ret(level = "debug"),
|
||||
skip_all
|
||||
)]
|
||||
pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||
&self,
|
||||
|
||||
Reference in New Issue
Block a user