From 254b53adf4e20a91dbff94da2cd3b7aa79b46cce Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 27 Feb 2026 11:01:12 +0000 Subject: [PATCH] Split auth_chain shortid and eventid gathering callstacks. Optimize event parse for auth_chain auth_events fetch. Signed-off-by: Jason Volk --- src/service/rooms/auth_chain/mod.rs | 319 ++++++++++++++++++---------- src/service/rooms/short/mod.rs | 4 +- 2 files changed, 206 insertions(+), 117 deletions(-) diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index f8982d23..7851e087 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -1,6 +1,5 @@ use std::{ collections::{BTreeSet, HashSet}, - fmt::Debug, iter::once, sync::Arc, time::Instant, @@ -8,18 +7,24 @@ use std::{ use async_trait::async_trait; use futures::{ - FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, + FutureExt, Stream, StreamExt, TryFutureExt, pin_mut, stream::{FuturesUnordered, unfold}, }; -use ruma::{EventId, OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules}; +use ruma::{ + EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, + room_version_rules::RoomVersionRules, +}; +use serde::Deserialize; use tuwunel_core::{ Err, Result, at, debug, debug_error, err, implement, - matrix::{Event, PduEvent, room_version}, + itertools::Itertools, + matrix::room_version, pdu::AuthEvents, + smallvec::SmallVec, trace, utils, utils::{ IterStream, - stream::{BroadbandExt, ReadyExt, TryBroadbandExt}, + stream::{BroadbandExt, ReadyExt, TryExpect, automatic_width}, }, validated, warn, }; @@ -37,7 +42,8 @@ struct Data { shorteventid_authchain: Arc, } -type Bucket<'a> = BTreeSet<(u64, &'a EventId)>; +type Bucket<'a> = BTreeSet<(ShortEventId, &'a EventId)>; +type CacheKey = SmallVec<[ShortEventId; 1]>; #[async_trait] impl crate::Service for Service { @@ -64,7 +70,7 @@ pub fn event_ids_iter<'a, I>( starting_events: I, ) -> impl Stream> + Send + 'a where - I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, + I: Iterator + Clone + ExactSizeIterator + Send + 'a, { self.get_auth_chain(room_id, room_version, starting_events) .map_ok(|chain| { @@ -81,7 +87,10 @@ where name = "auth_chain", level = "debug", skip_all, - fields(%room_id), + fields( + %room_id, + starting_events = %starting_events.clone().count(), + ) )] pub async fn get_auth_chain<'a, I>( &'a self, @@ -90,18 +99,19 @@ pub async fn get_auth_chain<'a, I>( starting_events: I, ) -> Result> where - I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, + I: Iterator + Clone + ExactSizeIterator + Send + 'a, { const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? const BUCKET: Bucket<'_> = BTreeSet::new(); let started = Instant::now(); let room_rules = room_version::rules(room_version)?; + let starting_events_count = starting_events.clone().count(); let starting_ids = self .services .short .multi_get_or_create_shorteventid(starting_events.clone()) - .zip(starting_events.clone().stream()); + .zip(starting_events.stream()); pin_mut!(starting_ids); let mut buckets = [BUCKET; NUM_BUCKETS]; @@ -112,24 +122,30 @@ where } debug!( - starting_events = ?starting_events.count(), + starting_events = starting_events_count, elapsed = ?started.elapsed(), "start", ); let full_auth_chain: Vec = buckets - .into_iter() - .try_stream() - .broad_and_then(|chunk| self.get_auth_chain_outer(room_id, started, chunk, &room_rules)) - .try_collect() - .map_ok(|auth_chain: Vec<_>| auth_chain.into_iter().flatten().collect()) - .map_ok(|mut full_auth_chain: Vec<_>| { - full_auth_chain.sort_unstable(); - full_auth_chain.dedup(); - full_auth_chain + .iter() + .stream() + .flat_map_unordered(automatic_width(), |starting_events| { + self.get_chunk_auth_chain( + room_id, + &started, + starting_events.iter().copied(), + &room_rules, + ) + .boxed() }) + .collect::>() + .map(IntoIterator::into_iter) + .map(Itertools::sorted_unstable) + .map(Itertools::dedup) + .map(Iterator::collect) .boxed() - .await?; + .await; debug!( chain_length = ?full_auth_chain.len(), @@ -141,80 +157,118 @@ where } #[implement(Service)] -async fn get_auth_chain_outer<'a>( +#[tracing::instrument( + name = "outer", + level = "trace", + skip_all, + fields( + starting_events = %starting_events.clone().count(), + ) +)] +pub fn get_chunk_auth_chain<'a, I>( &'a self, room_id: &'a RoomId, - started: Instant, - chunk: Bucket<'_>, - room_rules: &RoomVersionRules, -) -> Result> { - let chunk_key: Vec = chunk.iter().map(at!(0)).collect(); + started: &'a Instant, + starting_events: I, + room_rules: &'a RoomVersionRules, +) -> impl Stream + Send + 'a +where + I: Iterator + Clone + Send + Sync + 'a, +{ + let starting_shortids = starting_events.clone().map(at!(0)); - if let Ok(cached) = self.get_cached_auth_chain(&chunk_key).await { - return Ok(cached); - } + let build_chain = async |(shortid, event_id): (ShortEventId, &'a EventId)| { + if let Ok(cached) = self.get_cached_auth_chain(once(shortid)).await { + return cached; + } - let chunk_cache = chunk - .into_iter() - .stream() - .broad_then(async |(shortid, event_id)| { - if let Ok(cached) = self.get_cached_auth_chain(&[shortid]).await { - return cached; - } + let auth_chain: Vec<_> = self + .get_event_auth_chain(room_id, event_id, room_rules) + .collect() + .await; - let auth_chain: Vec<_> = self - .get_auth_chain_inner(room_id, event_id, room_rules) - .collect() - .await; + self.put_cached_auth_chain(once(shortid), auth_chain.as_slice()); + debug!( + ?event_id, + elapsed = ?started.elapsed(), + "Cache missed event" + ); - self.put_cached_auth_chain(&[shortid], auth_chain.as_slice()); - debug!( - ?event_id, - elapsed = ?started.elapsed(), - "Cache missed event" - ); + auth_chain + }; - auth_chain + let cache_chain = move |chunk_cache: &Vec<_>| { + self.put_cached_auth_chain(starting_shortids, chunk_cache.as_slice()); + debug!( + chunk_cache_length = ?chunk_cache.len(), + elapsed = ?started.elapsed(), + "Cache missed chunk", + ); + }; + + self.get_cached_auth_chain(starting_events.clone().map(at!(0))) + .map_ok(IntoIterator::into_iter) + .map_ok(IterStream::try_stream) + .or_else(move |_| async move { + starting_events + .clone() + .stream() + .broad_then(build_chain) + .collect::>() + .map(IntoIterator::into_iter) + .map(Iterator::flatten) + .map(Itertools::sorted_unstable) + .map(Itertools::dedup) + .map(Iterator::collect) + .inspect(cache_chain) + .map(IntoIterator::into_iter) + .map(IterStream::try_stream) + .map(Ok) + .await }) - .collect() - .map(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect()) - .map(|mut chunk_cache: Vec<_>| { - chunk_cache.sort_unstable(); - chunk_cache.dedup(); - chunk_cache - }) - .await; - - self.put_cached_auth_chain(&chunk_key, chunk_cache.as_slice()); - debug!( - chunk_cache_length = ?chunk_cache.len(), - elapsed = ?started.elapsed(), - "Cache missed chunk", - ); - - Ok(chunk_cache) + .try_flatten_stream() + .map_expect("either cache hit or cache miss yields a chain") } #[implement(Service)] -#[tracing::instrument( - name = "inner", - level = "trace", - skip_all, - fields(%event_id), -)] -fn get_auth_chain_inner<'a>( +#[tracing::instrument(name = "inner", level = "trace", skip_all)] +pub fn get_event_auth_chain<'a>( &'a self, room_id: &'a RoomId, event_id: &'a EventId, room_rules: &'a RoomVersionRules, ) -> impl Stream + Send + 'a { + self.get_event_auth_chain_ids(room_id, event_id, room_rules) + .broad_then(async move |auth_event| { + self.services + .short + .get_or_create_shorteventid(&auth_event) + .await + }) +} + +#[implement(Service)] +#[tracing::instrument( + name = "inner_ids", + level = "trace", + skip_all, + fields(%event_id) +)] +pub fn get_event_auth_chain_ids<'a>( + &'a self, + room_id: &'a RoomId, + event_id: &'a EventId, + room_rules: &'a RoomVersionRules, +) -> impl Stream + Send + 'a { struct State { todo: FuturesUnordered, seen: HashSet, } + let starting_events = self.get_event_auth_event_ids(room_id, event_id.to_owned()); + let state = State { - todo: once(self.get_pdu(room_id, event_id.to_owned())).collect(), + todo: once(starting_events).collect(), seen: room_rules .authorization .room_create_event_id_as_room_id @@ -228,12 +282,12 @@ fn get_auth_chain_inner<'a>( match state.todo.next().await { | None => None, | Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)), - | Some(Ok(pdu)) => { + | Some(Ok(auth_events)) => { let push = |auth_event: &OwnedEventId| { trace!(?event_id, ?auth_event, "push"); state .todo - .push(self.get_pdu(room_id, auth_event.clone())); + .push(self.get_event_auth_event_ids(room_id, auth_event.clone())); }; let seen = |auth_event: OwnedEventId| { @@ -243,9 +297,8 @@ fn get_auth_chain_inner<'a>( .then_some(auth_event) }; - let out = pdu - .auth_events() - .map(ToOwned::to_owned) + let out = auth_events + .into_iter() .filter_map(seen) .inspect(push) .collect::() @@ -257,45 +310,29 @@ fn get_auth_chain_inner<'a>( } }) .flatten() - .broad_then(async move |auth_event| { - self.services - .short - .get_or_create_shorteventid(&auth_event) - .await - }) } #[implement(Service)] -async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> Result { - let pdu = self - .services - .timeline - .get_pdu(&event_id) - .await - .inspect_err(|e| { - debug_error!(?event_id, ?room_id, "auth chain event: {e}"); - })?; +#[tracing::instrument( + name = "cache_put", + level = "debug", + skip_all, + fields( + key_len = key.clone().count(), + chain_len = auth_chain.len(), + ) +)] +fn put_cached_auth_chain(&self, key: I, auth_chain: &[ShortEventId]) +where + I: Iterator + Clone + Send, +{ + let key = key.collect::(); - if pdu.room_id() != room_id { - return Err!(Request(Forbidden(error!( - ?event_id, - ?room_id, - wrong_room_id = ?pdu.room_id(), - "auth event for incorrect room", - )))); - } - - Ok(pdu) -} - -#[implement(Service)] -#[tracing::instrument(skip_all, level = "debug")] -fn put_cached_auth_chain(&self, key: &[ShortEventId], auth_chain: &[ShortEventId]) { debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); self.db .authchainkey_authchain - .put(key, auth_chain); + .put(key.as_slice(), auth_chain); if key.len() == 1 { self.db @@ -305,8 +342,21 @@ fn put_cached_auth_chain(&self, key: &[ShortEventId], auth_chain: &[ShortEventId } #[implement(Service)] -#[tracing::instrument(skip_all, level = "trace")] -async fn get_cached_auth_chain(&self, key: &[u64]) -> Result> { +#[tracing::instrument( + name = "cache_get", + level = "trace", + err(level = "trace"), + skip_all, + fields( + key_len = %key.clone().count() + ), +)] +async fn get_cached_auth_chain(&self, key: I) -> Result> +where + I: Iterator + Clone + Send, +{ + let key = key.collect::(); + if key.is_empty() { return Ok(Vec::new()); } @@ -315,7 +365,7 @@ async fn get_cached_auth_chain(&self, key: &[u64]) -> Result> let chain = self .db .authchainkey_authchain - .qry(key) + .qry(key.as_slice()) .map_err(|_| err!(Request(NotFound("auth_chain not cached")))) .or_else(async |e| { if key.len() > 1 { @@ -328,12 +378,51 @@ async fn get_cached_auth_chain(&self, key: &[u64]) -> Result> .map_err(|_| err!(Request(NotFound("auth_chain not found")))) .await }) - .await?; - - let chain = chain + .await? .chunks_exact(size_of::()) .map(utils::u64_from_u8) .collect(); Ok(chain) } + +#[implement(Service)] +#[tracing::instrument( + name = "auth_events", + level = "trace", + ret(level = "trace"), + err(level = "trace"), + skip_all, + fields(%event_id) +)] +async fn get_event_auth_event_ids<'a>( + &'a self, + room_id: &'a RoomId, + event_id: OwnedEventId, +) -> Result { + #[derive(Deserialize)] + struct Pdu { + auth_events: AuthEvents, + room_id: OwnedRoomId, + } + + let pdu: Pdu = self + .services + .timeline + .get(&event_id) + .inspect_err(|e| { + debug_error!(?event_id, ?room_id, "auth chain event: {e}"); + }) + .await?; + + if pdu.room_id != room_id { + return Err!(Request(Forbidden(error!( + ?event_id, + ?room_id, + wrong_room_id = ?pdu.room_id, + "auth event for incorrect room", + )))); + } + + Ok(pdu.auth_events) +} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index f5562868..59d2f5d1 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, fmt::Debug, mem::size_of_val, sync::Arc}; +use std::{borrow::Borrow, mem::size_of_val, sync::Arc}; use futures::{FutureExt, Stream, StreamExt, pin_mut}; use ruma::{EventId, OwnedRoomId, RoomId, events::StateEventType}; @@ -61,7 +61,7 @@ pub fn multi_get_or_create_shorteventid<'a, I>( event_ids: I, ) -> impl Stream + Send + '_ where - I: Iterator + Clone + Debug + Send + 'a, + I: Iterator + Clone + Send + 'a, { event_ids .clone()