From 944f165202344b18281b72ea6a56a47be39e8ab1 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 20 Jan 2026 13:01:58 +0000 Subject: [PATCH] Optimize get_auth_chain_inner concurrency. Signed-off-by: Jason Volk --- src/service/rooms/auth_chain/mod.rs | 167 +++++++++++++++++----------- 1 file changed, 100 insertions(+), 67 deletions(-) diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 738bddd4..f7911d97 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -1,21 +1,26 @@ mod data; use std::{ - collections::{BTreeSet, HashSet, VecDeque}, + collections::{BTreeSet, HashSet}, fmt::Debug, + iter::once, sync::Arc, time::Instant, }; -use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut}; +use futures::{ + FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, + stream::{FuturesUnordered, unfold}, +}; use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules}; use tuwunel_core::{ Err, Result, at, debug, debug_error, implement, - matrix::Event, + matrix::{Event, PduEvent}, + pdu::AuthEvents, trace, utils::{ IterStream, - stream::{ReadyExt, TryBroadbandExt}, + stream::{BroadbandExt, ReadyExt, TryBroadbandExt}, }, validated, warn, }; @@ -124,9 +129,9 @@ where } #[implement(Service)] -async fn get_auth_chain_outer( - &self, - room_id: &RoomId, +async fn get_auth_chain_outer<'a>( + &'a self, + room_id: &'a RoomId, started: Instant, chunk: Bucket<'_>, room_rules: &RoomVersionRules, @@ -144,20 +149,21 @@ async fn get_auth_chain_outer( return Ok(cached.to_vec()); } - let chunk_cache: Vec<_> = chunk + let chunk_cache = chunk .into_iter() - .try_stream() - .broad_and_then(async |(shortid, event_id)| { + .stream() + .broad_then(async |(shortid, event_id)| { if let Ok(cached) = self .get_cached_eventid_authchain(&[shortid]) .await { - return Ok(cached.to_vec()); + return cached.to_vec(); } - let auth_chain = self + let auth_chain: Vec<_> = self .get_auth_chain_inner(room_id, event_id, room_rules) - .await?; + .collect() + .await; self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice()); debug!( @@ -166,16 +172,16 @@ async fn get_auth_chain_outer( "Cache missed event" ); - Ok(auth_chain) + auth_chain }) - .try_collect() - .map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect()) - .map_ok(|mut chunk_cache: Vec<_>| { + .collect() + .map(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect()) + .map(|mut chunk_cache: Vec<_>| { chunk_cache.sort_unstable(); chunk_cache.dedup(); chunk_cache }) - .await?; + .await; self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice()); debug!( @@ -191,66 +197,93 @@ async fn get_auth_chain_outer( #[tracing::instrument( name = "inner", level = "trace", - skip(self, room_id, room_rules) + skip_all, + fields(%event_id), )] -async fn get_auth_chain_inner( - &self, - room_id: &RoomId, - event_id: &EventId, - room_rules: &RoomVersionRules, -) -> Result> { - let mut found = HashSet::new(); - let mut todo: VecDeque<_> = [event_id.to_owned()].into(); - - if room_rules - .authorization - .room_create_event_id_as_room_id - { - let create_id = room_id.as_event_id()?; - let sauthevent = self - .services - .short - .get_or_create_shorteventid(&create_id) - .await; - - found.insert(sauthevent); +fn get_auth_chain_inner<'a>( + &'a self, + room_id: &'a RoomId, + event_id: &'a EventId, + room_rules: &'a RoomVersionRules, +) -> impl Stream + Send + 'a { + struct State { + todo: FuturesUnordered, + seen: HashSet, } - while let Some(event_id) = todo.pop_front() { - trace!(?event_id, "processing auth event"); + let state = State { + todo: once(self.get_pdu(room_id, event_id.to_owned())).collect(), + seen: room_rules + .authorization + .room_create_event_id_as_room_id + .then_some(room_id.as_event_id().ok()) + .into_iter() + .flatten() + .collect(), + }; - match self.services.timeline.get_pdu(&event_id).await { - | Err(e) => { - debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"); - }, - | Ok(pdu) => { - if pdu.room_id != room_id { - return Err!(Request(Forbidden(error!( - ?event_id, - ?room_id, - wrong_room_id = ?pdu.room_id, - "auth event for incorrect room" - )))); - } + unfold(state, move |mut state| async move { + match state.todo.next().await { + | None => None, + | Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)), + | Some(Ok(pdu)) => { + let push = |auth_event: &OwnedEventId| { + trace!(?event_id, ?auth_event, "push"); + state + .todo + .push(self.get_pdu(room_id, auth_event.clone())); + }; - for auth_event in pdu.auth_events() { - let sauthevent = self - .services - .short - .get_or_create_shorteventid(auth_event) - .await; + let seen = |auth_event: OwnedEventId| { + state + .seen + .insert(auth_event.clone()) + .then_some(auth_event) + }; - if found.insert(sauthevent) { - trace!(?event_id, ?auth_event, "adding auth event to processing queue"); + let out = pdu + .auth_events() + .map(ToOwned::to_owned) + .filter_map(seen) + .inspect(push) + .collect::() + .into_iter() + .stream(); - todo.push_back(auth_event.to_owned()); - } - } + Some((out, state)) }, } + }) + .flatten() + .broad_then(async move |auth_event| { + self.services + .short + .get_or_create_shorteventid(&auth_event) + .await + }) +} + +#[implement(Service)] +async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> Result { + let pdu = self + .services + .timeline + .get_pdu(&event_id) + .await + .inspect_err(|e| { + debug_error!(?event_id, ?room_id, "auth chain event: {e}"); + })?; + + if pdu.room_id() != room_id { + return Err!(Request(Forbidden(error!( + ?event_id, + ?room_id, + wrong_room_id = ?pdu.room_id(), + "auth event for incorrect room", + )))); } - Ok(found.into_iter().collect()) + Ok(pdu) } #[implement(Service)]