Optimize get_auth_chain_inner concurrency.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -1,21 +1,26 @@
|
|||||||
mod data;
|
mod data;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeSet, HashSet, VecDeque},
|
collections::{BTreeSet, HashSet},
|
||||||
fmt::Debug,
|
fmt::Debug,
|
||||||
|
iter::once,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
time::Instant,
|
time::Instant,
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut};
|
use futures::{
|
||||||
|
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
|
||||||
|
stream::{FuturesUnordered, unfold},
|
||||||
|
};
|
||||||
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules};
|
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules};
|
||||||
use tuwunel_core::{
|
use tuwunel_core::{
|
||||||
Err, Result, at, debug, debug_error, implement,
|
Err, Result, at, debug, debug_error, implement,
|
||||||
matrix::Event,
|
matrix::{Event, PduEvent},
|
||||||
|
pdu::AuthEvents,
|
||||||
trace,
|
trace,
|
||||||
utils::{
|
utils::{
|
||||||
IterStream,
|
IterStream,
|
||||||
stream::{ReadyExt, TryBroadbandExt},
|
stream::{BroadbandExt, ReadyExt, TryBroadbandExt},
|
||||||
},
|
},
|
||||||
validated, warn,
|
validated, warn,
|
||||||
};
|
};
|
||||||
@@ -124,9 +129,9 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
async fn get_auth_chain_outer(
|
async fn get_auth_chain_outer<'a>(
|
||||||
&self,
|
&'a self,
|
||||||
room_id: &RoomId,
|
room_id: &'a RoomId,
|
||||||
started: Instant,
|
started: Instant,
|
||||||
chunk: Bucket<'_>,
|
chunk: Bucket<'_>,
|
||||||
room_rules: &RoomVersionRules,
|
room_rules: &RoomVersionRules,
|
||||||
@@ -144,20 +149,21 @@ async fn get_auth_chain_outer(
|
|||||||
return Ok(cached.to_vec());
|
return Ok(cached.to_vec());
|
||||||
}
|
}
|
||||||
|
|
||||||
let chunk_cache: Vec<_> = chunk
|
let chunk_cache = chunk
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.try_stream()
|
.stream()
|
||||||
.broad_and_then(async |(shortid, event_id)| {
|
.broad_then(async |(shortid, event_id)| {
|
||||||
if let Ok(cached) = self
|
if let Ok(cached) = self
|
||||||
.get_cached_eventid_authchain(&[shortid])
|
.get_cached_eventid_authchain(&[shortid])
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
return Ok(cached.to_vec());
|
return cached.to_vec();
|
||||||
}
|
}
|
||||||
|
|
||||||
let auth_chain = self
|
let auth_chain: Vec<_> = self
|
||||||
.get_auth_chain_inner(room_id, event_id, room_rules)
|
.get_auth_chain_inner(room_id, event_id, room_rules)
|
||||||
.await?;
|
.collect()
|
||||||
|
.await;
|
||||||
|
|
||||||
self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
|
self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
|
||||||
debug!(
|
debug!(
|
||||||
@@ -166,16 +172,16 @@ async fn get_auth_chain_outer(
|
|||||||
"Cache missed event"
|
"Cache missed event"
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(auth_chain)
|
auth_chain
|
||||||
})
|
})
|
||||||
.try_collect()
|
.collect()
|
||||||
.map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
|
.map(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
|
||||||
.map_ok(|mut chunk_cache: Vec<_>| {
|
.map(|mut chunk_cache: Vec<_>| {
|
||||||
chunk_cache.sort_unstable();
|
chunk_cache.sort_unstable();
|
||||||
chunk_cache.dedup();
|
chunk_cache.dedup();
|
||||||
chunk_cache
|
chunk_cache
|
||||||
})
|
})
|
||||||
.await?;
|
.await;
|
||||||
|
|
||||||
self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
|
self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
|
||||||
debug!(
|
debug!(
|
||||||
@@ -191,66 +197,93 @@ async fn get_auth_chain_outer(
|
|||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
name = "inner",
|
name = "inner",
|
||||||
level = "trace",
|
level = "trace",
|
||||||
skip(self, room_id, room_rules)
|
skip_all,
|
||||||
|
fields(%event_id),
|
||||||
)]
|
)]
|
||||||
async fn get_auth_chain_inner(
|
fn get_auth_chain_inner<'a>(
|
||||||
&self,
|
&'a self,
|
||||||
room_id: &RoomId,
|
room_id: &'a RoomId,
|
||||||
event_id: &EventId,
|
event_id: &'a EventId,
|
||||||
room_rules: &RoomVersionRules,
|
room_rules: &'a RoomVersionRules,
|
||||||
) -> Result<Vec<ShortEventId>> {
|
) -> impl Stream<Item = ShortEventId> + Send + 'a {
|
||||||
let mut found = HashSet::new();
|
struct State<Fut> {
|
||||||
let mut todo: VecDeque<_> = [event_id.to_owned()].into();
|
todo: FuturesUnordered<Fut>,
|
||||||
|
seen: HashSet<OwnedEventId>,
|
||||||
if room_rules
|
|
||||||
.authorization
|
|
||||||
.room_create_event_id_as_room_id
|
|
||||||
{
|
|
||||||
let create_id = room_id.as_event_id()?;
|
|
||||||
let sauthevent = self
|
|
||||||
.services
|
|
||||||
.short
|
|
||||||
.get_or_create_shorteventid(&create_id)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
found.insert(sauthevent);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
while let Some(event_id) = todo.pop_front() {
|
let state = State {
|
||||||
trace!(?event_id, "processing auth event");
|
todo: once(self.get_pdu(room_id, event_id.to_owned())).collect(),
|
||||||
|
seen: room_rules
|
||||||
|
.authorization
|
||||||
|
.room_create_event_id_as_room_id
|
||||||
|
.then_some(room_id.as_event_id().ok())
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.collect(),
|
||||||
|
};
|
||||||
|
|
||||||
match self.services.timeline.get_pdu(&event_id).await {
|
unfold(state, move |mut state| async move {
|
||||||
| Err(e) => {
|
match state.todo.next().await {
|
||||||
debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events");
|
| None => None,
|
||||||
},
|
| Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
|
||||||
| Ok(pdu) => {
|
| Some(Ok(pdu)) => {
|
||||||
if pdu.room_id != room_id {
|
let push = |auth_event: &OwnedEventId| {
|
||||||
return Err!(Request(Forbidden(error!(
|
trace!(?event_id, ?auth_event, "push");
|
||||||
?event_id,
|
state
|
||||||
?room_id,
|
.todo
|
||||||
wrong_room_id = ?pdu.room_id,
|
.push(self.get_pdu(room_id, auth_event.clone()));
|
||||||
"auth event for incorrect room"
|
};
|
||||||
))));
|
|
||||||
}
|
|
||||||
|
|
||||||
for auth_event in pdu.auth_events() {
|
let seen = |auth_event: OwnedEventId| {
|
||||||
let sauthevent = self
|
state
|
||||||
.services
|
.seen
|
||||||
.short
|
.insert(auth_event.clone())
|
||||||
.get_or_create_shorteventid(auth_event)
|
.then_some(auth_event)
|
||||||
.await;
|
};
|
||||||
|
|
||||||
if found.insert(sauthevent) {
|
let out = pdu
|
||||||
trace!(?event_id, ?auth_event, "adding auth event to processing queue");
|
.auth_events()
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.filter_map(seen)
|
||||||
|
.inspect(push)
|
||||||
|
.collect::<AuthEvents>()
|
||||||
|
.into_iter()
|
||||||
|
.stream();
|
||||||
|
|
||||||
todo.push_back(auth_event.to_owned());
|
Some((out, state))
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
.flatten()
|
||||||
|
.broad_then(async move |auth_event| {
|
||||||
|
self.services
|
||||||
|
.short
|
||||||
|
.get_or_create_shorteventid(&auth_event)
|
||||||
|
.await
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Service)]
|
||||||
|
async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> Result<PduEvent> {
|
||||||
|
let pdu = self
|
||||||
|
.services
|
||||||
|
.timeline
|
||||||
|
.get_pdu(&event_id)
|
||||||
|
.await
|
||||||
|
.inspect_err(|e| {
|
||||||
|
debug_error!(?event_id, ?room_id, "auth chain event: {e}");
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if pdu.room_id() != room_id {
|
||||||
|
return Err!(Request(Forbidden(error!(
|
||||||
|
?event_id,
|
||||||
|
?room_id,
|
||||||
|
wrong_room_id = ?pdu.room_id(),
|
||||||
|
"auth event for incorrect room",
|
||||||
|
))));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(found.into_iter().collect())
|
Ok(pdu)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
|
|||||||
Reference in New Issue
Block a user