Optimize get_auth_chain_inner concurrency.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-20 13:01:58 +00:00
parent a0b98fa575
commit 944f165202

View File

@@ -1,21 +1,26 @@
mod data;
use std::{
collections::{BTreeSet, HashSet, VecDeque},
collections::{BTreeSet, HashSet},
fmt::Debug,
iter::once,
sync::Arc,
time::Instant,
};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut};
use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
stream::{FuturesUnordered, unfold},
};
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules};
use tuwunel_core::{
Err, Result, at, debug, debug_error, implement,
matrix::Event,
matrix::{Event, PduEvent},
pdu::AuthEvents,
trace,
utils::{
IterStream,
stream::{ReadyExt, TryBroadbandExt},
stream::{BroadbandExt, ReadyExt, TryBroadbandExt},
},
validated, warn,
};
@@ -124,9 +129,9 @@ where
}
#[implement(Service)]
async fn get_auth_chain_outer(
&self,
room_id: &RoomId,
async fn get_auth_chain_outer<'a>(
&'a self,
room_id: &'a RoomId,
started: Instant,
chunk: Bucket<'_>,
room_rules: &RoomVersionRules,
@@ -144,20 +149,21 @@ async fn get_auth_chain_outer(
return Ok(cached.to_vec());
}
let chunk_cache: Vec<_> = chunk
let chunk_cache = chunk
.into_iter()
.try_stream()
.broad_and_then(async |(shortid, event_id)| {
.stream()
.broad_then(async |(shortid, event_id)| {
if let Ok(cached) = self
.get_cached_eventid_authchain(&[shortid])
.await
{
return Ok(cached.to_vec());
return cached.to_vec();
}
let auth_chain = self
let auth_chain: Vec<_> = self
.get_auth_chain_inner(room_id, event_id, room_rules)
.await?;
.collect()
.await;
self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
debug!(
@@ -166,16 +172,16 @@ async fn get_auth_chain_outer(
"Cache missed event"
);
Ok(auth_chain)
auth_chain
})
.try_collect()
.map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
.map_ok(|mut chunk_cache: Vec<_>| {
.collect()
.map(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
.map(|mut chunk_cache: Vec<_>| {
chunk_cache.sort_unstable();
chunk_cache.dedup();
chunk_cache
})
.await?;
.await;
self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
debug!(
@@ -191,66 +197,93 @@ async fn get_auth_chain_outer(
#[tracing::instrument(
name = "inner",
level = "trace",
skip(self, room_id, room_rules)
skip_all,
fields(%event_id),
)]
async fn get_auth_chain_inner(
&self,
room_id: &RoomId,
event_id: &EventId,
room_rules: &RoomVersionRules,
) -> Result<Vec<ShortEventId>> {
let mut found = HashSet::new();
let mut todo: VecDeque<_> = [event_id.to_owned()].into();
if room_rules
.authorization
.room_create_event_id_as_room_id
{
let create_id = room_id.as_event_id()?;
let sauthevent = self
.services
.short
.get_or_create_shorteventid(&create_id)
.await;
found.insert(sauthevent);
fn get_auth_chain_inner<'a>(
&'a self,
room_id: &'a RoomId,
event_id: &'a EventId,
room_rules: &'a RoomVersionRules,
) -> impl Stream<Item = ShortEventId> + Send + 'a {
struct State<Fut> {
todo: FuturesUnordered<Fut>,
seen: HashSet<OwnedEventId>,
}
while let Some(event_id) = todo.pop_front() {
trace!(?event_id, "processing auth event");
let state = State {
todo: once(self.get_pdu(room_id, event_id.to_owned())).collect(),
seen: room_rules
.authorization
.room_create_event_id_as_room_id
.then_some(room_id.as_event_id().ok())
.into_iter()
.flatten()
.collect(),
};
match self.services.timeline.get_pdu(&event_id).await {
| Err(e) => {
debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events");
},
| Ok(pdu) => {
if pdu.room_id != room_id {
return Err!(Request(Forbidden(error!(
?event_id,
?room_id,
wrong_room_id = ?pdu.room_id,
"auth event for incorrect room"
))));
}
unfold(state, move |mut state| async move {
match state.todo.next().await {
| None => None,
| Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
| Some(Ok(pdu)) => {
let push = |auth_event: &OwnedEventId| {
trace!(?event_id, ?auth_event, "push");
state
.todo
.push(self.get_pdu(room_id, auth_event.clone()));
};
for auth_event in pdu.auth_events() {
let sauthevent = self
.services
.short
.get_or_create_shorteventid(auth_event)
.await;
let seen = |auth_event: OwnedEventId| {
state
.seen
.insert(auth_event.clone())
.then_some(auth_event)
};
if found.insert(sauthevent) {
trace!(?event_id, ?auth_event, "adding auth event to processing queue");
let out = pdu
.auth_events()
.map(ToOwned::to_owned)
.filter_map(seen)
.inspect(push)
.collect::<AuthEvents>()
.into_iter()
.stream();
todo.push_back(auth_event.to_owned());
}
}
Some((out, state))
},
}
})
.flatten()
.broad_then(async move |auth_event| {
self.services
.short
.get_or_create_shorteventid(&auth_event)
.await
})
}
#[implement(Service)]
async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> Result<PduEvent> {
let pdu = self
.services
.timeline
.get_pdu(&event_id)
.await
.inspect_err(|e| {
debug_error!(?event_id, ?room_id, "auth chain event: {e}");
})?;
if pdu.room_id() != room_id {
return Err!(Request(Forbidden(error!(
?event_id,
?room_id,
wrong_room_id = ?pdu.room_id(),
"auth event for incorrect room",
))));
}
Ok(found.into_iter().collect())
Ok(pdu)
}
#[implement(Service)]