Refactor conflicted_subgraph into stream::unfold() pattern.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-02-25 07:38:35 +00:00
parent 9fb6594975
commit 39cf124813
4 changed files with 146 additions and 98 deletions

View File

@@ -1,4 +1,3 @@
#![type_length_limit = "262144"] //TODO: REDUCE ME
#![expect(clippy::toplevel_ref_arg)] #![expect(clippy::toplevel_ref_arg)]
#![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91 #![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91

View File

@@ -131,7 +131,10 @@ where
.boxed() .boxed()
.await?; .await?;
let power_set_event_ids: Vec<_> = sorted_power_set.iter().sorted().collect(); let power_set_event_ids: Vec<_> = sorted_power_set
.iter()
.sorted_unstable()
.collect();
let sorted_power_set = sorted_power_set let sorted_power_set = sorted_power_set
.iter() .iter()
@@ -236,7 +239,7 @@ where
ExistsFut: Future<Output = bool> + Send, ExistsFut: Future<Output = bool> + Send,
FetchEvent: Fn(OwnedEventId) -> EventFut + Sync, FetchEvent: Fn(OwnedEventId) -> EventFut + Sync,
EventFut: Future<Output = Result<Pdu>> + Send, EventFut: Future<Output = Result<Pdu>> + Send,
Pdu: Event + Clone, Pdu: Event,
{ {
let consider_conflicted_subgraph = rules let consider_conflicted_subgraph = rules
.state_res .state_res
@@ -244,7 +247,12 @@ where
.is_some_and(|rules| rules.consider_conflicted_state_subgraph) .is_some_and(|rules| rules.consider_conflicted_state_subgraph)
|| backport_css; || backport_css;
let conflicted_state_set: HashSet<_> = conflicted_states.values().flatten().collect(); let conflicted_state_set: Vec<_> = conflicted_states
.values()
.flatten()
.sorted_unstable()
.dedup()
.collect();
// Since `org.matrix.hydra.11`, fetch the conflicted state subgraph. // Since `org.matrix.hydra.11`, fetch the conflicted state subgraph.
let conflicted_subgraph = consider_conflicted_subgraph let conflicted_subgraph = consider_conflicted_subgraph

View File

@@ -1,23 +1,22 @@
use std::{ use std::{collections::HashSet as Set, iter::once, ops::Deref};
collections::HashSet as Set,
iter::once,
mem::take,
sync::{Arc, Mutex},
};
use futures::{Future, FutureExt, Stream, StreamExt}; use futures::{
Future, Stream, StreamExt,
stream::{FuturesUnordered, unfold},
};
use ruma::OwnedEventId; use ruma::OwnedEventId;
use tuwunel_core::{ use tuwunel_core::{
Result, debug, Result, implement,
matrix::{Event, pdu::AuthEvents}, matrix::{Event, pdu::AuthEvents},
smallvec::SmallVec, smallvec::SmallVec,
utils::stream::{IterStream, automatic_width}, utils::stream::IterStream,
}; };
#[derive(Default)] #[derive(Default)]
struct Global { struct Global<Fut: Future + Send> {
subgraph: Mutex<Set<OwnedEventId>>, subgraph: Set<OwnedEventId>,
seen: Mutex<Set<OwnedEventId>>, seen: Set<OwnedEventId>,
todo: Todo<Fut>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
@@ -26,16 +25,24 @@ struct Local {
stack: Stack, stack: Stack,
} }
type Todo<Fut> = FuturesUnordered<Fut>;
type Path = SmallVec<[OwnedEventId; PATH_INLINE]>; type Path = SmallVec<[OwnedEventId; PATH_INLINE]>;
type Stack = SmallVec<[Frame; STACK_INLINE]>; type Stack = SmallVec<[Frame; STACK_INLINE]>;
type Frame = AuthEvents; type Frame = AuthEvents;
const PATH_INLINE: usize = 48; const PATH_INLINE: usize = 16;
const STACK_INLINE: usize = 48; const STACK_INLINE: usize = 16;
#[tracing::instrument(name = "subgraph_dfs", level = "debug", skip_all)] #[tracing::instrument(
name = "subgraph_dfs",
level = "debug",
skip_all,
fields(
starting_events = %conflicted_set.len(),
)
)]
pub(super) fn conflicted_subgraph_dfs<Fetch, Fut, Pdu>( pub(super) fn conflicted_subgraph_dfs<Fetch, Fut, Pdu>(
conflicted_event_ids: &Set<&OwnedEventId>, conflicted_set: &Vec<&OwnedEventId>,
fetch: &Fetch, fetch: &Fetch,
) -> impl Stream<Item = OwnedEventId> + Send ) -> impl Stream<Item = OwnedEventId> + Send
where where
@@ -43,104 +50,135 @@ where
Fut: Future<Output = Result<Pdu>> + Send, Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event, Pdu: Event,
{ {
let state = Arc::new(Global::default()); let state = Global {
let state_ = state.clone(); subgraph: Default::default(),
conflicted_event_ids seen: Default::default(),
.iter() todo: conflicted_set
.stream() .iter()
.enumerate() .map(Deref::deref)
.map(move |(i, event_id)| (state_.clone(), event_id, i)) .cloned()
.for_each_concurrent(automatic_width(), async |(state, event_id, i)| { .map(Local::new)
subgraph_descent(conflicted_event_ids, fetch, &state, event_id, i) .filter_map(Local::pop)
.await .map(|(local, event_id)| local.push(fetch, Some(event_id)))
.expect("only mutex errors expected"); .collect(),
}) };
.map(move |()| {
let seen = state.seen.lock().expect("locked");
let mut state = state.subgraph.lock().expect("locked");
debug!(
input_events = conflicted_event_ids.len(),
seen_events = seen.len(),
output_events = state.len(),
"conflicted subgraph state"
);
take(&mut *state).into_iter().stream() unfold(state, |mut state| async {
}) let outputs = state
.flatten_stream() .todo
.next()
.await?
.pop()
.map(|(local, event_id)| local.eval(&mut state, conflicted_set, event_id))
.map(|(local, next_id, outputs)| {
if !local.stack.is_empty() {
state.todo.push(local.push(fetch, next_id));
}
outputs
})
.into_iter()
.flatten()
.stream();
Some((outputs, state))
})
.flatten()
} }
#[implement(Local)]
#[tracing::instrument( #[tracing::instrument(
name = "descent", name = "descent",
level = "trace", level = "trace",
skip_all, skip_all,
fields( fields(
event_ids = conflicted_event_ids.len(), ?event_id,
event_id = %conflicted_event_id, path = self.path.len(),
%i, stack = self.stack.len(),
subgraph = state.subgraph.len(),
seen = state.seen.len(),
) )
)] )]
async fn subgraph_descent<Fetch, Fut, Pdu>( fn eval<Fut: Future + Send>(
conflicted_event_ids: &Set<&OwnedEventId>, mut self,
fetch: &Fetch, state: &mut Global<Fut>,
state: &Arc<Global>, conflicted_event_ids: &Vec<&OwnedEventId>,
conflicted_event_id: &OwnedEventId, event_id: OwnedEventId,
i: usize, ) -> (Self, Option<OwnedEventId>, Path) {
) -> Result let Global { subgraph, seen, .. } = state;
let insert_path = |global: &mut Global<Fut>, local: &Local| {
local
.path
.iter()
.filter(|&event_id| global.subgraph.insert(event_id.clone()))
.cloned()
.collect()
};
if subgraph.contains(&event_id) {
if self.path.len() <= 1 {
self.path.pop();
return (self, None, Path::new());
}
let path = insert_path(state, &self);
self.path.pop();
return (self, None, path);
}
if !seen.insert(event_id.clone()) {
return (self, None, Path::new());
}
if self.path.len() > 1
&& conflicted_event_ids
.binary_search(&&event_id)
.is_ok()
{
let path = insert_path(state, &self);
return (self, Some(event_id), path);
}
(self, Some(event_id), Path::new())
}
#[implement(Local)]
async fn push<Fetch, Fut, Pdu>(mut self, fetch: &Fetch, event_id: Option<OwnedEventId>) -> Self
where where
Fetch: Fn(OwnedEventId) -> Fut + Sync, Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send, Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event, Pdu: Event,
{ {
let Global { subgraph, seen } = &**state; if let Some(event_id) = event_id
&& let Ok(event) = fetch(event_id).await
let mut local = Local { {
path: once(conflicted_event_id.clone()).collect(), self.stack
stack: once(once(conflicted_event_id.clone()).collect()).collect(), .push(event.auth_events_into().into_iter().collect());
};
while let Some(event_id) = pop(&mut local) {
if subgraph.lock()?.contains(&event_id) {
if local.path.len() > 1 {
subgraph
.lock()?
.extend(local.path.iter().cloned());
}
local.path.pop();
continue;
}
if !seen.lock()?.insert(event_id.clone()) {
continue;
}
if local.path.len() > 1 && conflicted_event_ids.contains(&event_id) {
subgraph
.lock()?
.extend(local.path.iter().cloned());
}
if let Ok(event) = fetch(event_id).await {
local
.stack
.push(event.auth_events_into().into_iter().collect());
}
} }
Ok(()) self
} }
fn pop(local: &mut Local) -> Option<OwnedEventId> { #[implement(Local)]
let Local { path, stack } = local; fn pop(mut self) -> Option<(Self, OwnedEventId)> {
while self.stack.last().is_some_and(Frame::is_empty) {
while stack.last().is_some_and(Frame::is_empty) { self.stack.pop();
stack.pop(); self.path.pop();
path.pop();
} }
stack self.stack
.last_mut() .last_mut()
.and_then(Frame::pop) .and_then(Frame::pop)
.inspect(|event_id| path.push(event_id.clone())) .inspect(|event_id| self.path.push(event_id.clone()))
.map(move |event_id| (self, event_id))
}
#[implement(Local)]
#[allow(clippy::redundant_clone)] // buggy, nursery
fn new(conflicted_event_id: OwnedEventId) -> Self {
Self {
path: once(conflicted_event_id.clone()).collect(),
stack: once(once(conflicted_event_id).collect()).collect(),
}
} }

View File

@@ -195,7 +195,10 @@ where
.await .await
.inspect_err(|e| { .inspect_err(|e| {
debug_warn!( debug_warn!(
event_type = ?event.event_type(), ?state_key, %event_id, %event_id,
sender = %event.sender(),
event_type = ?event.event_type(),
?state_key,
"event failed auth check: {e}" "event failed auth check: {e}"
); );
}) })