Split auth_chain shortid and eventid gathering callstacks.
Optimize event parse for auth_chain auth_events fetch. Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
use std::{
|
||||
collections::{BTreeSet, HashSet},
|
||||
fmt::Debug,
|
||||
iter::once,
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
@@ -8,18 +7,24 @@ use std::{
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{
|
||||
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
|
||||
FutureExt, Stream, StreamExt, TryFutureExt, pin_mut,
|
||||
stream::{FuturesUnordered, unfold},
|
||||
};
|
||||
use ruma::{EventId, OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
|
||||
use ruma::{
|
||||
EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
|
||||
room_version_rules::RoomVersionRules,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tuwunel_core::{
|
||||
Err, Result, at, debug, debug_error, err, implement,
|
||||
matrix::{Event, PduEvent, room_version},
|
||||
itertools::Itertools,
|
||||
matrix::room_version,
|
||||
pdu::AuthEvents,
|
||||
smallvec::SmallVec,
|
||||
trace, utils,
|
||||
utils::{
|
||||
IterStream,
|
||||
stream::{BroadbandExt, ReadyExt, TryBroadbandExt},
|
||||
stream::{BroadbandExt, ReadyExt, TryExpect, automatic_width},
|
||||
},
|
||||
validated, warn,
|
||||
};
|
||||
@@ -37,7 +42,8 @@ struct Data {
|
||||
shorteventid_authchain: Arc<Map>,
|
||||
}
|
||||
|
||||
type Bucket<'a> = BTreeSet<(u64, &'a EventId)>;
|
||||
type Bucket<'a> = BTreeSet<(ShortEventId, &'a EventId)>;
|
||||
type CacheKey = SmallVec<[ShortEventId; 1]>;
|
||||
|
||||
#[async_trait]
|
||||
impl crate::Service for Service {
|
||||
@@ -64,7 +70,7 @@ pub fn event_ids_iter<'a, I>(
|
||||
starting_events: I,
|
||||
) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
|
||||
where
|
||||
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||
I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
|
||||
{
|
||||
self.get_auth_chain(room_id, room_version, starting_events)
|
||||
.map_ok(|chain| {
|
||||
@@ -81,7 +87,10 @@ where
|
||||
name = "auth_chain",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(%room_id),
|
||||
fields(
|
||||
%room_id,
|
||||
starting_events = %starting_events.clone().count(),
|
||||
)
|
||||
)]
|
||||
pub async fn get_auth_chain<'a, I>(
|
||||
&'a self,
|
||||
@@ -90,18 +99,19 @@ pub async fn get_auth_chain<'a, I>(
|
||||
starting_events: I,
|
||||
) -> Result<Vec<ShortEventId>>
|
||||
where
|
||||
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||
I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
|
||||
{
|
||||
const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
|
||||
const BUCKET: Bucket<'_> = BTreeSet::new();
|
||||
|
||||
let started = Instant::now();
|
||||
let room_rules = room_version::rules(room_version)?;
|
||||
let starting_events_count = starting_events.clone().count();
|
||||
let starting_ids = self
|
||||
.services
|
||||
.short
|
||||
.multi_get_or_create_shorteventid(starting_events.clone())
|
||||
.zip(starting_events.clone().stream());
|
||||
.zip(starting_events.stream());
|
||||
|
||||
pin_mut!(starting_ids);
|
||||
let mut buckets = [BUCKET; NUM_BUCKETS];
|
||||
@@ -112,24 +122,30 @@ where
|
||||
}
|
||||
|
||||
debug!(
|
||||
starting_events = ?starting_events.count(),
|
||||
starting_events = starting_events_count,
|
||||
elapsed = ?started.elapsed(),
|
||||
"start",
|
||||
);
|
||||
|
||||
let full_auth_chain: Vec<ShortEventId> = buckets
|
||||
.into_iter()
|
||||
.try_stream()
|
||||
.broad_and_then(|chunk| self.get_auth_chain_outer(room_id, started, chunk, &room_rules))
|
||||
.try_collect()
|
||||
.map_ok(|auth_chain: Vec<_>| auth_chain.into_iter().flatten().collect())
|
||||
.map_ok(|mut full_auth_chain: Vec<_>| {
|
||||
full_auth_chain.sort_unstable();
|
||||
full_auth_chain.dedup();
|
||||
full_auth_chain
|
||||
.iter()
|
||||
.stream()
|
||||
.flat_map_unordered(automatic_width(), |starting_events| {
|
||||
self.get_chunk_auth_chain(
|
||||
room_id,
|
||||
&started,
|
||||
starting_events.iter().copied(),
|
||||
&room_rules,
|
||||
)
|
||||
.boxed()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.map(IntoIterator::into_iter)
|
||||
.map(Itertools::sorted_unstable)
|
||||
.map(Itertools::dedup)
|
||||
.map(Iterator::collect)
|
||||
.boxed()
|
||||
.await?;
|
||||
.await;
|
||||
|
||||
debug!(
|
||||
chain_length = ?full_auth_chain.len(),
|
||||
@@ -141,80 +157,118 @@ where
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
async fn get_auth_chain_outer<'a>(
|
||||
#[tracing::instrument(
|
||||
name = "outer",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
starting_events = %starting_events.clone().count(),
|
||||
)
|
||||
)]
|
||||
pub fn get_chunk_auth_chain<'a, I>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
started: Instant,
|
||||
chunk: Bucket<'_>,
|
||||
room_rules: &RoomVersionRules,
|
||||
) -> Result<Vec<ShortEventId>> {
|
||||
let chunk_key: Vec<ShortEventId> = chunk.iter().map(at!(0)).collect();
|
||||
started: &'a Instant,
|
||||
starting_events: I,
|
||||
room_rules: &'a RoomVersionRules,
|
||||
) -> impl Stream<Item = ShortEventId> + Send + 'a
|
||||
where
|
||||
I: Iterator<Item = (ShortEventId, &'a EventId)> + Clone + Send + Sync + 'a,
|
||||
{
|
||||
let starting_shortids = starting_events.clone().map(at!(0));
|
||||
|
||||
if let Ok(cached) = self.get_cached_auth_chain(&chunk_key).await {
|
||||
return Ok(cached);
|
||||
}
|
||||
let build_chain = async |(shortid, event_id): (ShortEventId, &'a EventId)| {
|
||||
if let Ok(cached) = self.get_cached_auth_chain(once(shortid)).await {
|
||||
return cached;
|
||||
}
|
||||
|
||||
let chunk_cache = chunk
|
||||
.into_iter()
|
||||
.stream()
|
||||
.broad_then(async |(shortid, event_id)| {
|
||||
if let Ok(cached) = self.get_cached_auth_chain(&[shortid]).await {
|
||||
return cached;
|
||||
}
|
||||
let auth_chain: Vec<_> = self
|
||||
.get_event_auth_chain(room_id, event_id, room_rules)
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let auth_chain: Vec<_> = self
|
||||
.get_auth_chain_inner(room_id, event_id, room_rules)
|
||||
.collect()
|
||||
.await;
|
||||
self.put_cached_auth_chain(once(shortid), auth_chain.as_slice());
|
||||
debug!(
|
||||
?event_id,
|
||||
elapsed = ?started.elapsed(),
|
||||
"Cache missed event"
|
||||
);
|
||||
|
||||
self.put_cached_auth_chain(&[shortid], auth_chain.as_slice());
|
||||
debug!(
|
||||
?event_id,
|
||||
elapsed = ?started.elapsed(),
|
||||
"Cache missed event"
|
||||
);
|
||||
auth_chain
|
||||
};
|
||||
|
||||
auth_chain
|
||||
let cache_chain = move |chunk_cache: &Vec<_>| {
|
||||
self.put_cached_auth_chain(starting_shortids, chunk_cache.as_slice());
|
||||
debug!(
|
||||
chunk_cache_length = ?chunk_cache.len(),
|
||||
elapsed = ?started.elapsed(),
|
||||
"Cache missed chunk",
|
||||
);
|
||||
};
|
||||
|
||||
self.get_cached_auth_chain(starting_events.clone().map(at!(0)))
|
||||
.map_ok(IntoIterator::into_iter)
|
||||
.map_ok(IterStream::try_stream)
|
||||
.or_else(move |_| async move {
|
||||
starting_events
|
||||
.clone()
|
||||
.stream()
|
||||
.broad_then(build_chain)
|
||||
.collect::<Vec<_>>()
|
||||
.map(IntoIterator::into_iter)
|
||||
.map(Iterator::flatten)
|
||||
.map(Itertools::sorted_unstable)
|
||||
.map(Itertools::dedup)
|
||||
.map(Iterator::collect)
|
||||
.inspect(cache_chain)
|
||||
.map(IntoIterator::into_iter)
|
||||
.map(IterStream::try_stream)
|
||||
.map(Ok)
|
||||
.await
|
||||
})
|
||||
.collect()
|
||||
.map(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
|
||||
.map(|mut chunk_cache: Vec<_>| {
|
||||
chunk_cache.sort_unstable();
|
||||
chunk_cache.dedup();
|
||||
chunk_cache
|
||||
})
|
||||
.await;
|
||||
|
||||
self.put_cached_auth_chain(&chunk_key, chunk_cache.as_slice());
|
||||
debug!(
|
||||
chunk_cache_length = ?chunk_cache.len(),
|
||||
elapsed = ?started.elapsed(),
|
||||
"Cache missed chunk",
|
||||
);
|
||||
|
||||
Ok(chunk_cache)
|
||||
.try_flatten_stream()
|
||||
.map_expect("either cache hit or cache miss yields a chain")
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(
|
||||
name = "inner",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(%event_id),
|
||||
)]
|
||||
fn get_auth_chain_inner<'a>(
|
||||
#[tracing::instrument(name = "inner", level = "trace", skip_all)]
|
||||
pub fn get_event_auth_chain<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
event_id: &'a EventId,
|
||||
room_rules: &'a RoomVersionRules,
|
||||
) -> impl Stream<Item = ShortEventId> + Send + 'a {
|
||||
self.get_event_auth_chain_ids(room_id, event_id, room_rules)
|
||||
.broad_then(async move |auth_event| {
|
||||
self.services
|
||||
.short
|
||||
.get_or_create_shorteventid(&auth_event)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(
|
||||
name = "inner_ids",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(%event_id)
|
||||
)]
|
||||
pub fn get_event_auth_chain_ids<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
event_id: &'a EventId,
|
||||
room_rules: &'a RoomVersionRules,
|
||||
) -> impl Stream<Item = OwnedEventId> + Send + 'a {
|
||||
struct State<Fut> {
|
||||
todo: FuturesUnordered<Fut>,
|
||||
seen: HashSet<OwnedEventId>,
|
||||
}
|
||||
|
||||
let starting_events = self.get_event_auth_event_ids(room_id, event_id.to_owned());
|
||||
|
||||
let state = State {
|
||||
todo: once(self.get_pdu(room_id, event_id.to_owned())).collect(),
|
||||
todo: once(starting_events).collect(),
|
||||
seen: room_rules
|
||||
.authorization
|
||||
.room_create_event_id_as_room_id
|
||||
@@ -228,12 +282,12 @@ fn get_auth_chain_inner<'a>(
|
||||
match state.todo.next().await {
|
||||
| None => None,
|
||||
| Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
|
||||
| Some(Ok(pdu)) => {
|
||||
| Some(Ok(auth_events)) => {
|
||||
let push = |auth_event: &OwnedEventId| {
|
||||
trace!(?event_id, ?auth_event, "push");
|
||||
state
|
||||
.todo
|
||||
.push(self.get_pdu(room_id, auth_event.clone()));
|
||||
.push(self.get_event_auth_event_ids(room_id, auth_event.clone()));
|
||||
};
|
||||
|
||||
let seen = |auth_event: OwnedEventId| {
|
||||
@@ -243,9 +297,8 @@ fn get_auth_chain_inner<'a>(
|
||||
.then_some(auth_event)
|
||||
};
|
||||
|
||||
let out = pdu
|
||||
.auth_events()
|
||||
.map(ToOwned::to_owned)
|
||||
let out = auth_events
|
||||
.into_iter()
|
||||
.filter_map(seen)
|
||||
.inspect(push)
|
||||
.collect::<AuthEvents>()
|
||||
@@ -257,45 +310,29 @@ fn get_auth_chain_inner<'a>(
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.broad_then(async move |auth_event| {
|
||||
self.services
|
||||
.short
|
||||
.get_or_create_shorteventid(&auth_event)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> Result<PduEvent> {
|
||||
let pdu = self
|
||||
.services
|
||||
.timeline
|
||||
.get_pdu(&event_id)
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
debug_error!(?event_id, ?room_id, "auth chain event: {e}");
|
||||
})?;
|
||||
#[tracing::instrument(
|
||||
name = "cache_put",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
key_len = key.clone().count(),
|
||||
chain_len = auth_chain.len(),
|
||||
)
|
||||
)]
|
||||
fn put_cached_auth_chain<I>(&self, key: I, auth_chain: &[ShortEventId])
|
||||
where
|
||||
I: Iterator<Item = ShortEventId> + Clone + Send,
|
||||
{
|
||||
let key = key.collect::<CacheKey>();
|
||||
|
||||
if pdu.room_id() != room_id {
|
||||
return Err!(Request(Forbidden(error!(
|
||||
?event_id,
|
||||
?room_id,
|
||||
wrong_room_id = ?pdu.room_id(),
|
||||
"auth event for incorrect room",
|
||||
))));
|
||||
}
|
||||
|
||||
Ok(pdu)
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(skip_all, level = "debug")]
|
||||
fn put_cached_auth_chain(&self, key: &[ShortEventId], auth_chain: &[ShortEventId]) {
|
||||
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
|
||||
|
||||
self.db
|
||||
.authchainkey_authchain
|
||||
.put(key, auth_chain);
|
||||
.put(key.as_slice(), auth_chain);
|
||||
|
||||
if key.len() == 1 {
|
||||
self.db
|
||||
@@ -305,8 +342,21 @@ fn put_cached_auth_chain(&self, key: &[ShortEventId], auth_chain: &[ShortEventId
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(skip_all, level = "trace")]
|
||||
async fn get_cached_auth_chain(&self, key: &[u64]) -> Result<Vec<ShortEventId>> {
|
||||
#[tracing::instrument(
|
||||
name = "cache_get",
|
||||
level = "trace",
|
||||
err(level = "trace"),
|
||||
skip_all,
|
||||
fields(
|
||||
key_len = %key.clone().count()
|
||||
),
|
||||
)]
|
||||
async fn get_cached_auth_chain<I>(&self, key: I) -> Result<Vec<ShortEventId>>
|
||||
where
|
||||
I: Iterator<Item = ShortEventId> + Clone + Send,
|
||||
{
|
||||
let key = key.collect::<CacheKey>();
|
||||
|
||||
if key.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
@@ -315,7 +365,7 @@ async fn get_cached_auth_chain(&self, key: &[u64]) -> Result<Vec<ShortEventId>>
|
||||
let chain = self
|
||||
.db
|
||||
.authchainkey_authchain
|
||||
.qry(key)
|
||||
.qry(key.as_slice())
|
||||
.map_err(|_| err!(Request(NotFound("auth_chain not cached"))))
|
||||
.or_else(async |e| {
|
||||
if key.len() > 1 {
|
||||
@@ -328,12 +378,51 @@ async fn get_cached_auth_chain(&self, key: &[u64]) -> Result<Vec<ShortEventId>>
|
||||
.map_err(|_| err!(Request(NotFound("auth_chain not found"))))
|
||||
.await
|
||||
})
|
||||
.await?;
|
||||
|
||||
let chain = chain
|
||||
.await?
|
||||
.chunks_exact(size_of::<u64>())
|
||||
.map(utils::u64_from_u8)
|
||||
.collect();
|
||||
|
||||
Ok(chain)
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(
|
||||
name = "auth_events",
|
||||
level = "trace",
|
||||
ret(level = "trace"),
|
||||
err(level = "trace"),
|
||||
skip_all,
|
||||
fields(%event_id)
|
||||
)]
|
||||
async fn get_event_auth_event_ids<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
event_id: OwnedEventId,
|
||||
) -> Result<AuthEvents> {
|
||||
#[derive(Deserialize)]
|
||||
struct Pdu {
|
||||
auth_events: AuthEvents,
|
||||
room_id: OwnedRoomId,
|
||||
}
|
||||
|
||||
let pdu: Pdu = self
|
||||
.services
|
||||
.timeline
|
||||
.get(&event_id)
|
||||
.inspect_err(|e| {
|
||||
debug_error!(?event_id, ?room_id, "auth chain event: {e}");
|
||||
})
|
||||
.await?;
|
||||
|
||||
if pdu.room_id != room_id {
|
||||
return Err!(Request(Forbidden(error!(
|
||||
?event_id,
|
||||
?room_id,
|
||||
wrong_room_id = ?pdu.room_id,
|
||||
"auth event for incorrect room",
|
||||
))));
|
||||
}
|
||||
|
||||
Ok(pdu.auth_events)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{borrow::Borrow, fmt::Debug, mem::size_of_val, sync::Arc};
|
||||
use std::{borrow::Borrow, mem::size_of_val, sync::Arc};
|
||||
|
||||
use futures::{FutureExt, Stream, StreamExt, pin_mut};
|
||||
use ruma::{EventId, OwnedRoomId, RoomId, events::StateEventType};
|
||||
@@ -61,7 +61,7 @@ pub fn multi_get_or_create_shorteventid<'a, I>(
|
||||
event_ids: I,
|
||||
) -> impl Stream<Item = ShortEventId> + Send + '_
|
||||
where
|
||||
I: Iterator<Item = &'a EventId> + Clone + Debug + Send + 'a,
|
||||
I: Iterator<Item = &'a EventId> + Clone + Send + 'a,
|
||||
{
|
||||
event_ids
|
||||
.clone()
|
||||
|
||||
Reference in New Issue
Block a user