Optimize get_auth_chain_inner concurrency.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-20 13:01:58 +00:00
parent a0b98fa575
commit 944f165202

View File

@@ -1,21 +1,26 @@
mod data; mod data;
use std::{ use std::{
collections::{BTreeSet, HashSet, VecDeque}, collections::{BTreeSet, HashSet},
fmt::Debug, fmt::Debug,
iter::once,
sync::Arc, sync::Arc,
time::Instant, time::Instant,
}; };
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut}; use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
stream::{FuturesUnordered, unfold},
};
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules}; use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules};
use tuwunel_core::{ use tuwunel_core::{
Err, Result, at, debug, debug_error, implement, Err, Result, at, debug, debug_error, implement,
matrix::Event, matrix::{Event, PduEvent},
pdu::AuthEvents,
trace, trace,
utils::{ utils::{
IterStream, IterStream,
stream::{ReadyExt, TryBroadbandExt}, stream::{BroadbandExt, ReadyExt, TryBroadbandExt},
}, },
validated, warn, validated, warn,
}; };
@@ -124,9 +129,9 @@ where
} }
#[implement(Service)] #[implement(Service)]
async fn get_auth_chain_outer( async fn get_auth_chain_outer<'a>(
&self, &'a self,
room_id: &RoomId, room_id: &'a RoomId,
started: Instant, started: Instant,
chunk: Bucket<'_>, chunk: Bucket<'_>,
room_rules: &RoomVersionRules, room_rules: &RoomVersionRules,
@@ -144,20 +149,21 @@ async fn get_auth_chain_outer(
return Ok(cached.to_vec()); return Ok(cached.to_vec());
} }
let chunk_cache: Vec<_> = chunk let chunk_cache = chunk
.into_iter() .into_iter()
.try_stream() .stream()
.broad_and_then(async |(shortid, event_id)| { .broad_then(async |(shortid, event_id)| {
if let Ok(cached) = self if let Ok(cached) = self
.get_cached_eventid_authchain(&[shortid]) .get_cached_eventid_authchain(&[shortid])
.await .await
{ {
return Ok(cached.to_vec()); return cached.to_vec();
} }
let auth_chain = self let auth_chain: Vec<_> = self
.get_auth_chain_inner(room_id, event_id, room_rules) .get_auth_chain_inner(room_id, event_id, room_rules)
.await?; .collect()
.await;
self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice()); self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
debug!( debug!(
@@ -166,16 +172,16 @@ async fn get_auth_chain_outer(
"Cache missed event" "Cache missed event"
); );
Ok(auth_chain) auth_chain
}) })
.try_collect() .collect()
.map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect()) .map(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
.map_ok(|mut chunk_cache: Vec<_>| { .map(|mut chunk_cache: Vec<_>| {
chunk_cache.sort_unstable(); chunk_cache.sort_unstable();
chunk_cache.dedup(); chunk_cache.dedup();
chunk_cache chunk_cache
}) })
.await?; .await;
self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice()); self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
debug!( debug!(
@@ -191,66 +197,93 @@ async fn get_auth_chain_outer(
#[tracing::instrument( #[tracing::instrument(
name = "inner", name = "inner",
level = "trace", level = "trace",
skip(self, room_id, room_rules) skip_all,
fields(%event_id),
)] )]
async fn get_auth_chain_inner( fn get_auth_chain_inner<'a>(
&self, &'a self,
room_id: &RoomId, room_id: &'a RoomId,
event_id: &EventId, event_id: &'a EventId,
room_rules: &RoomVersionRules, room_rules: &'a RoomVersionRules,
) -> Result<Vec<ShortEventId>> { ) -> impl Stream<Item = ShortEventId> + Send + 'a {
let mut found = HashSet::new(); struct State<Fut> {
let mut todo: VecDeque<_> = [event_id.to_owned()].into(); todo: FuturesUnordered<Fut>,
seen: HashSet<OwnedEventId>,
if room_rules
.authorization
.room_create_event_id_as_room_id
{
let create_id = room_id.as_event_id()?;
let sauthevent = self
.services
.short
.get_or_create_shorteventid(&create_id)
.await;
found.insert(sauthevent);
} }
while let Some(event_id) = todo.pop_front() { let state = State {
trace!(?event_id, "processing auth event"); todo: once(self.get_pdu(room_id, event_id.to_owned())).collect(),
seen: room_rules
.authorization
.room_create_event_id_as_room_id
.then_some(room_id.as_event_id().ok())
.into_iter()
.flatten()
.collect(),
};
match self.services.timeline.get_pdu(&event_id).await { unfold(state, move |mut state| async move {
| Err(e) => { match state.todo.next().await {
debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"); | None => None,
}, | Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
| Ok(pdu) => { | Some(Ok(pdu)) => {
if pdu.room_id != room_id { let push = |auth_event: &OwnedEventId| {
return Err!(Request(Forbidden(error!( trace!(?event_id, ?auth_event, "push");
?event_id, state
?room_id, .todo
wrong_room_id = ?pdu.room_id, .push(self.get_pdu(room_id, auth_event.clone()));
"auth event for incorrect room" };
))));
}
for auth_event in pdu.auth_events() { let seen = |auth_event: OwnedEventId| {
let sauthevent = self state
.services .seen
.short .insert(auth_event.clone())
.get_or_create_shorteventid(auth_event) .then_some(auth_event)
.await; };
if found.insert(sauthevent) { let out = pdu
trace!(?event_id, ?auth_event, "adding auth event to processing queue"); .auth_events()
.map(ToOwned::to_owned)
.filter_map(seen)
.inspect(push)
.collect::<AuthEvents>()
.into_iter()
.stream();
todo.push_back(auth_event.to_owned()); Some((out, state))
}
}
}, },
} }
})
.flatten()
.broad_then(async move |auth_event| {
self.services
.short
.get_or_create_shorteventid(&auth_event)
.await
})
}
#[implement(Service)]
async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> Result<PduEvent> {
let pdu = self
.services
.timeline
.get_pdu(&event_id)
.await
.inspect_err(|e| {
debug_error!(?event_id, ?room_id, "auth chain event: {e}");
})?;
if pdu.room_id() != room_id {
return Err!(Request(Forbidden(error!(
?event_id,
?room_id,
wrong_room_id = ?pdu.room_id(),
"auth event for incorrect room",
))));
} }
Ok(found.into_iter().collect()) Ok(pdu)
} }
#[implement(Service)] #[implement(Service)]