From 39cf12481320f75f1868e336ad66dfdd52a1bd70 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 25 Feb 2026 07:38:35 +0000 Subject: [PATCH] Refactor conflicted_subgraph into stream::unfold() pattern. Signed-off-by: Jason Volk --- src/api/mod.rs | 1 - src/service/rooms/state_res/resolve.rs | 14 +- .../state_res/resolve/conflicted_subgraph.rs | 224 ++++++++++-------- .../state_res/resolve/iterative_auth_check.rs | 5 +- 4 files changed, 146 insertions(+), 98 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index 46db297e..c4bef8dc 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,3 @@ -#![type_length_limit = "262144"] //TODO: REDUCE ME #![expect(clippy::toplevel_ref_arg)] #![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91 diff --git a/src/service/rooms/state_res/resolve.rs b/src/service/rooms/state_res/resolve.rs index 4f38ec5c..62c56666 100644 --- a/src/service/rooms/state_res/resolve.rs +++ b/src/service/rooms/state_res/resolve.rs @@ -131,7 +131,10 @@ where .boxed() .await?; - let power_set_event_ids: Vec<_> = sorted_power_set.iter().sorted().collect(); + let power_set_event_ids: Vec<_> = sorted_power_set + .iter() + .sorted_unstable() + .collect(); let sorted_power_set = sorted_power_set .iter() @@ -236,7 +239,7 @@ where ExistsFut: Future + Send, FetchEvent: Fn(OwnedEventId) -> EventFut + Sync, EventFut: Future> + Send, - Pdu: Event + Clone, + Pdu: Event, { let consider_conflicted_subgraph = rules .state_res @@ -244,7 +247,12 @@ where .is_some_and(|rules| rules.consider_conflicted_state_subgraph) || backport_css; - let conflicted_state_set: HashSet<_> = conflicted_states.values().flatten().collect(); + let conflicted_state_set: Vec<_> = conflicted_states + .values() + .flatten() + .sorted_unstable() + .dedup() + .collect(); // Since `org.matrix.hydra.11`, fetch the conflicted state subgraph. let conflicted_subgraph = consider_conflicted_subgraph diff --git a/src/service/rooms/state_res/resolve/conflicted_subgraph.rs b/src/service/rooms/state_res/resolve/conflicted_subgraph.rs index 8fc8cdd3..1d39b7a3 100644 --- a/src/service/rooms/state_res/resolve/conflicted_subgraph.rs +++ b/src/service/rooms/state_res/resolve/conflicted_subgraph.rs @@ -1,23 +1,22 @@ -use std::{ - collections::HashSet as Set, - iter::once, - mem::take, - sync::{Arc, Mutex}, -}; +use std::{collections::HashSet as Set, iter::once, ops::Deref}; -use futures::{Future, FutureExt, Stream, StreamExt}; +use futures::{ + Future, Stream, StreamExt, + stream::{FuturesUnordered, unfold}, +}; use ruma::OwnedEventId; use tuwunel_core::{ - Result, debug, + Result, implement, matrix::{Event, pdu::AuthEvents}, smallvec::SmallVec, - utils::stream::{IterStream, automatic_width}, + utils::stream::IterStream, }; #[derive(Default)] -struct Global { - subgraph: Mutex>, - seen: Mutex>, +struct Global { + subgraph: Set, + seen: Set, + todo: Todo, } #[derive(Default, Debug)] @@ -26,16 +25,24 @@ struct Local { stack: Stack, } +type Todo = FuturesUnordered; type Path = SmallVec<[OwnedEventId; PATH_INLINE]>; type Stack = SmallVec<[Frame; STACK_INLINE]>; type Frame = AuthEvents; -const PATH_INLINE: usize = 48; -const STACK_INLINE: usize = 48; +const PATH_INLINE: usize = 16; +const STACK_INLINE: usize = 16; -#[tracing::instrument(name = "subgraph_dfs", level = "debug", skip_all)] +#[tracing::instrument( + name = "subgraph_dfs", + level = "debug", + skip_all, + fields( + starting_events = %conflicted_set.len(), + ) +)] pub(super) fn conflicted_subgraph_dfs( - conflicted_event_ids: &Set<&OwnedEventId>, + conflicted_set: &Vec<&OwnedEventId>, fetch: &Fetch, ) -> impl Stream + Send where @@ -43,104 +50,135 @@ where Fut: Future> + Send, Pdu: Event, { - let state = Arc::new(Global::default()); - let state_ = state.clone(); - conflicted_event_ids - .iter() - .stream() - .enumerate() - .map(move |(i, event_id)| (state_.clone(), event_id, i)) - .for_each_concurrent(automatic_width(), async |(state, event_id, i)| { - subgraph_descent(conflicted_event_ids, fetch, &state, event_id, i) - .await - .expect("only mutex errors expected"); - }) - .map(move |()| { - let seen = state.seen.lock().expect("locked"); - let mut state = state.subgraph.lock().expect("locked"); - debug!( - input_events = conflicted_event_ids.len(), - seen_events = seen.len(), - output_events = state.len(), - "conflicted subgraph state" - ); + let state = Global { + subgraph: Default::default(), + seen: Default::default(), + todo: conflicted_set + .iter() + .map(Deref::deref) + .cloned() + .map(Local::new) + .filter_map(Local::pop) + .map(|(local, event_id)| local.push(fetch, Some(event_id))) + .collect(), + }; - take(&mut *state).into_iter().stream() - }) - .flatten_stream() + unfold(state, |mut state| async { + let outputs = state + .todo + .next() + .await? + .pop() + .map(|(local, event_id)| local.eval(&mut state, conflicted_set, event_id)) + .map(|(local, next_id, outputs)| { + if !local.stack.is_empty() { + state.todo.push(local.push(fetch, next_id)); + } + + outputs + }) + .into_iter() + .flatten() + .stream(); + + Some((outputs, state)) + }) + .flatten() } +#[implement(Local)] #[tracing::instrument( name = "descent", level = "trace", skip_all, fields( - event_ids = conflicted_event_ids.len(), - event_id = %conflicted_event_id, - %i, + ?event_id, + path = self.path.len(), + stack = self.stack.len(), + subgraph = state.subgraph.len(), + seen = state.seen.len(), ) )] -async fn subgraph_descent( - conflicted_event_ids: &Set<&OwnedEventId>, - fetch: &Fetch, - state: &Arc, - conflicted_event_id: &OwnedEventId, - i: usize, -) -> Result +fn eval( + mut self, + state: &mut Global, + conflicted_event_ids: &Vec<&OwnedEventId>, + event_id: OwnedEventId, +) -> (Self, Option, Path) { + let Global { subgraph, seen, .. } = state; + + let insert_path = |global: &mut Global, local: &Local| { + local + .path + .iter() + .filter(|&event_id| global.subgraph.insert(event_id.clone())) + .cloned() + .collect() + }; + + if subgraph.contains(&event_id) { + if self.path.len() <= 1 { + self.path.pop(); + return (self, None, Path::new()); + } + + let path = insert_path(state, &self); + self.path.pop(); + return (self, None, path); + } + + if !seen.insert(event_id.clone()) { + return (self, None, Path::new()); + } + + if self.path.len() > 1 + && conflicted_event_ids + .binary_search(&&event_id) + .is_ok() + { + let path = insert_path(state, &self); + return (self, Some(event_id), path); + } + + (self, Some(event_id), Path::new()) +} + +#[implement(Local)] +async fn push(mut self, fetch: &Fetch, event_id: Option) -> Self where Fetch: Fn(OwnedEventId) -> Fut + Sync, Fut: Future> + Send, Pdu: Event, { - let Global { subgraph, seen } = &**state; - - let mut local = Local { - path: once(conflicted_event_id.clone()).collect(), - stack: once(once(conflicted_event_id.clone()).collect()).collect(), - }; - - while let Some(event_id) = pop(&mut local) { - if subgraph.lock()?.contains(&event_id) { - if local.path.len() > 1 { - subgraph - .lock()? - .extend(local.path.iter().cloned()); - } - - local.path.pop(); - continue; - } - - if !seen.lock()?.insert(event_id.clone()) { - continue; - } - - if local.path.len() > 1 && conflicted_event_ids.contains(&event_id) { - subgraph - .lock()? - .extend(local.path.iter().cloned()); - } - - if let Ok(event) = fetch(event_id).await { - local - .stack - .push(event.auth_events_into().into_iter().collect()); - } + if let Some(event_id) = event_id + && let Ok(event) = fetch(event_id).await + { + self.stack + .push(event.auth_events_into().into_iter().collect()); } - Ok(()) + self } -fn pop(local: &mut Local) -> Option { - let Local { path, stack } = local; - - while stack.last().is_some_and(Frame::is_empty) { - stack.pop(); - path.pop(); +#[implement(Local)] +fn pop(mut self) -> Option<(Self, OwnedEventId)> { + while self.stack.last().is_some_and(Frame::is_empty) { + self.stack.pop(); + self.path.pop(); } - stack + self.stack .last_mut() .and_then(Frame::pop) - .inspect(|event_id| path.push(event_id.clone())) + .inspect(|event_id| self.path.push(event_id.clone())) + .map(move |event_id| (self, event_id)) +} + +#[implement(Local)] +#[allow(clippy::redundant_clone)] // buggy, nursery +fn new(conflicted_event_id: OwnedEventId) -> Self { + Self { + path: once(conflicted_event_id.clone()).collect(), + stack: once(once(conflicted_event_id).collect()).collect(), + } } diff --git a/src/service/rooms/state_res/resolve/iterative_auth_check.rs b/src/service/rooms/state_res/resolve/iterative_auth_check.rs index 6870dc52..50abe023 100644 --- a/src/service/rooms/state_res/resolve/iterative_auth_check.rs +++ b/src/service/rooms/state_res/resolve/iterative_auth_check.rs @@ -195,7 +195,10 @@ where .await .inspect_err(|e| { debug_warn!( - event_type = ?event.event_type(), ?state_key, %event_id, + %event_id, + sender = %event.sender(), + event_type = ?event.event_type(), + ?state_key, "event failed auth check: {e}" ); })