Optimize get_auth_chain_inner concurrency.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -1,21 +1,26 @@
|
||||
mod data;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeSet, HashSet, VecDeque},
|
||||
collections::{BTreeSet, HashSet},
|
||||
fmt::Debug,
|
||||
iter::once,
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut};
|
||||
use futures::{
|
||||
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
|
||||
stream::{FuturesUnordered, unfold},
|
||||
};
|
||||
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules};
|
||||
use tuwunel_core::{
|
||||
Err, Result, at, debug, debug_error, implement,
|
||||
matrix::Event,
|
||||
matrix::{Event, PduEvent},
|
||||
pdu::AuthEvents,
|
||||
trace,
|
||||
utils::{
|
||||
IterStream,
|
||||
stream::{ReadyExt, TryBroadbandExt},
|
||||
stream::{BroadbandExt, ReadyExt, TryBroadbandExt},
|
||||
},
|
||||
validated, warn,
|
||||
};
|
||||
@@ -124,9 +129,9 @@ where
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
async fn get_auth_chain_outer(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
async fn get_auth_chain_outer<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
started: Instant,
|
||||
chunk: Bucket<'_>,
|
||||
room_rules: &RoomVersionRules,
|
||||
@@ -144,20 +149,21 @@ async fn get_auth_chain_outer(
|
||||
return Ok(cached.to_vec());
|
||||
}
|
||||
|
||||
let chunk_cache: Vec<_> = chunk
|
||||
let chunk_cache = chunk
|
||||
.into_iter()
|
||||
.try_stream()
|
||||
.broad_and_then(async |(shortid, event_id)| {
|
||||
.stream()
|
||||
.broad_then(async |(shortid, event_id)| {
|
||||
if let Ok(cached) = self
|
||||
.get_cached_eventid_authchain(&[shortid])
|
||||
.await
|
||||
{
|
||||
return Ok(cached.to_vec());
|
||||
return cached.to_vec();
|
||||
}
|
||||
|
||||
let auth_chain = self
|
||||
let auth_chain: Vec<_> = self
|
||||
.get_auth_chain_inner(room_id, event_id, room_rules)
|
||||
.await?;
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
|
||||
debug!(
|
||||
@@ -166,16 +172,16 @@ async fn get_auth_chain_outer(
|
||||
"Cache missed event"
|
||||
);
|
||||
|
||||
Ok(auth_chain)
|
||||
auth_chain
|
||||
})
|
||||
.try_collect()
|
||||
.map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
|
||||
.map_ok(|mut chunk_cache: Vec<_>| {
|
||||
.collect()
|
||||
.map(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
|
||||
.map(|mut chunk_cache: Vec<_>| {
|
||||
chunk_cache.sort_unstable();
|
||||
chunk_cache.dedup();
|
||||
chunk_cache
|
||||
})
|
||||
.await?;
|
||||
.await;
|
||||
|
||||
self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
|
||||
debug!(
|
||||
@@ -191,66 +197,93 @@ async fn get_auth_chain_outer(
|
||||
#[tracing::instrument(
|
||||
name = "inner",
|
||||
level = "trace",
|
||||
skip(self, room_id, room_rules)
|
||||
skip_all,
|
||||
fields(%event_id),
|
||||
)]
|
||||
async fn get_auth_chain_inner(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
room_rules: &RoomVersionRules,
|
||||
) -> Result<Vec<ShortEventId>> {
|
||||
let mut found = HashSet::new();
|
||||
let mut todo: VecDeque<_> = [event_id.to_owned()].into();
|
||||
|
||||
if room_rules
|
||||
.authorization
|
||||
.room_create_event_id_as_room_id
|
||||
{
|
||||
let create_id = room_id.as_event_id()?;
|
||||
let sauthevent = self
|
||||
.services
|
||||
.short
|
||||
.get_or_create_shorteventid(&create_id)
|
||||
.await;
|
||||
|
||||
found.insert(sauthevent);
|
||||
fn get_auth_chain_inner<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
event_id: &'a EventId,
|
||||
room_rules: &'a RoomVersionRules,
|
||||
) -> impl Stream<Item = ShortEventId> + Send + 'a {
|
||||
struct State<Fut> {
|
||||
todo: FuturesUnordered<Fut>,
|
||||
seen: HashSet<OwnedEventId>,
|
||||
}
|
||||
|
||||
while let Some(event_id) = todo.pop_front() {
|
||||
trace!(?event_id, "processing auth event");
|
||||
let state = State {
|
||||
todo: once(self.get_pdu(room_id, event_id.to_owned())).collect(),
|
||||
seen: room_rules
|
||||
.authorization
|
||||
.room_create_event_id_as_room_id
|
||||
.then_some(room_id.as_event_id().ok())
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect(),
|
||||
};
|
||||
|
||||
match self.services.timeline.get_pdu(&event_id).await {
|
||||
| Err(e) => {
|
||||
debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events");
|
||||
},
|
||||
| Ok(pdu) => {
|
||||
if pdu.room_id != room_id {
|
||||
return Err!(Request(Forbidden(error!(
|
||||
?event_id,
|
||||
?room_id,
|
||||
wrong_room_id = ?pdu.room_id,
|
||||
"auth event for incorrect room"
|
||||
))));
|
||||
}
|
||||
unfold(state, move |mut state| async move {
|
||||
match state.todo.next().await {
|
||||
| None => None,
|
||||
| Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
|
||||
| Some(Ok(pdu)) => {
|
||||
let push = |auth_event: &OwnedEventId| {
|
||||
trace!(?event_id, ?auth_event, "push");
|
||||
state
|
||||
.todo
|
||||
.push(self.get_pdu(room_id, auth_event.clone()));
|
||||
};
|
||||
|
||||
for auth_event in pdu.auth_events() {
|
||||
let sauthevent = self
|
||||
.services
|
||||
.short
|
||||
.get_or_create_shorteventid(auth_event)
|
||||
.await;
|
||||
let seen = |auth_event: OwnedEventId| {
|
||||
state
|
||||
.seen
|
||||
.insert(auth_event.clone())
|
||||
.then_some(auth_event)
|
||||
};
|
||||
|
||||
if found.insert(sauthevent) {
|
||||
trace!(?event_id, ?auth_event, "adding auth event to processing queue");
|
||||
let out = pdu
|
||||
.auth_events()
|
||||
.map(ToOwned::to_owned)
|
||||
.filter_map(seen)
|
||||
.inspect(push)
|
||||
.collect::<AuthEvents>()
|
||||
.into_iter()
|
||||
.stream();
|
||||
|
||||
todo.push_back(auth_event.to_owned());
|
||||
}
|
||||
}
|
||||
Some((out, state))
|
||||
},
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.broad_then(async move |auth_event| {
|
||||
self.services
|
||||
.short
|
||||
.get_or_create_shorteventid(&auth_event)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> Result<PduEvent> {
|
||||
let pdu = self
|
||||
.services
|
||||
.timeline
|
||||
.get_pdu(&event_id)
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
debug_error!(?event_id, ?room_id, "auth chain event: {e}");
|
||||
})?;
|
||||
|
||||
if pdu.room_id() != room_id {
|
||||
return Err!(Request(Forbidden(error!(
|
||||
?event_id,
|
||||
?room_id,
|
||||
wrong_room_id = ?pdu.room_id(),
|
||||
"auth event for incorrect room",
|
||||
))));
|
||||
}
|
||||
|
||||
Ok(found.into_iter().collect())
|
||||
Ok(pdu)
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
|
||||
Reference in New Issue
Block a user