Split auth_chain shortid and eventid gathering callstacks.

Optimize event parse for auth_chain auth_events fetch.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-02-27 11:01:12 +00:00
parent 42570a5a7c
commit 254b53adf4
2 changed files with 206 additions and 117 deletions

View File

@@ -1,6 +1,5 @@
use std::{
collections::{BTreeSet, HashSet},
fmt::Debug,
iter::once,
sync::Arc,
time::Instant,
@@ -8,18 +7,24 @@ use std::{
use async_trait::async_trait;
use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
FutureExt, Stream, StreamExt, TryFutureExt, pin_mut,
stream::{FuturesUnordered, unfold},
};
use ruma::{EventId, OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
use ruma::{
EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
room_version_rules::RoomVersionRules,
};
use serde::Deserialize;
use tuwunel_core::{
Err, Result, at, debug, debug_error, err, implement,
matrix::{Event, PduEvent, room_version},
itertools::Itertools,
matrix::room_version,
pdu::AuthEvents,
smallvec::SmallVec,
trace, utils,
utils::{
IterStream,
stream::{BroadbandExt, ReadyExt, TryBroadbandExt},
stream::{BroadbandExt, ReadyExt, TryExpect, automatic_width},
},
validated, warn,
};
@@ -37,7 +42,8 @@ struct Data {
shorteventid_authchain: Arc<Map>,
}
type Bucket<'a> = BTreeSet<(u64, &'a EventId)>;
type Bucket<'a> = BTreeSet<(ShortEventId, &'a EventId)>;
type CacheKey = SmallVec<[ShortEventId; 1]>;
#[async_trait]
impl crate::Service for Service {
@@ -64,7 +70,7 @@ pub fn event_ids_iter<'a, I>(
starting_events: I,
) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
{
self.get_auth_chain(room_id, room_version, starting_events)
.map_ok(|chain| {
@@ -81,7 +87,10 @@ where
name = "auth_chain",
level = "debug",
skip_all,
fields(%room_id),
fields(
%room_id,
starting_events = %starting_events.clone().count(),
)
)]
pub async fn get_auth_chain<'a, I>(
&'a self,
@@ -90,18 +99,19 @@ pub async fn get_auth_chain<'a, I>(
starting_events: I,
) -> Result<Vec<ShortEventId>>
where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
I: Iterator<Item = &'a EventId> + Clone + ExactSizeIterator + Send + 'a,
{
const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
const BUCKET: Bucket<'_> = BTreeSet::new();
let started = Instant::now();
let room_rules = room_version::rules(room_version)?;
let starting_events_count = starting_events.clone().count();
let starting_ids = self
.services
.short
.multi_get_or_create_shorteventid(starting_events.clone())
.zip(starting_events.clone().stream());
.zip(starting_events.stream());
pin_mut!(starting_ids);
let mut buckets = [BUCKET; NUM_BUCKETS];
@@ -112,24 +122,30 @@ where
}
debug!(
starting_events = ?starting_events.count(),
starting_events = starting_events_count,
elapsed = ?started.elapsed(),
"start",
);
let full_auth_chain: Vec<ShortEventId> = buckets
.into_iter()
.try_stream()
.broad_and_then(|chunk| self.get_auth_chain_outer(room_id, started, chunk, &room_rules))
.try_collect()
.map_ok(|auth_chain: Vec<_>| auth_chain.into_iter().flatten().collect())
.map_ok(|mut full_auth_chain: Vec<_>| {
full_auth_chain.sort_unstable();
full_auth_chain.dedup();
full_auth_chain
.iter()
.stream()
.flat_map_unordered(automatic_width(), |starting_events| {
self.get_chunk_auth_chain(
room_id,
&started,
starting_events.iter().copied(),
&room_rules,
)
.boxed()
})
.collect::<Vec<_>>()
.map(IntoIterator::into_iter)
.map(Itertools::sorted_unstable)
.map(Itertools::dedup)
.map(Iterator::collect)
.boxed()
.await?;
.await;
debug!(
chain_length = ?full_auth_chain.len(),
@@ -141,80 +157,118 @@ where
}
#[implement(Service)]
async fn get_auth_chain_outer<'a>(
#[tracing::instrument(
name = "outer",
level = "trace",
skip_all,
fields(
starting_events = %starting_events.clone().count(),
)
)]
pub fn get_chunk_auth_chain<'a, I>(
&'a self,
room_id: &'a RoomId,
started: Instant,
chunk: Bucket<'_>,
room_rules: &RoomVersionRules,
) -> Result<Vec<ShortEventId>> {
let chunk_key: Vec<ShortEventId> = chunk.iter().map(at!(0)).collect();
started: &'a Instant,
starting_events: I,
room_rules: &'a RoomVersionRules,
) -> impl Stream<Item = ShortEventId> + Send + 'a
where
I: Iterator<Item = (ShortEventId, &'a EventId)> + Clone + Send + Sync + 'a,
{
let starting_shortids = starting_events.clone().map(at!(0));
if let Ok(cached) = self.get_cached_auth_chain(&chunk_key).await {
return Ok(cached);
}
let build_chain = async |(shortid, event_id): (ShortEventId, &'a EventId)| {
if let Ok(cached) = self.get_cached_auth_chain(once(shortid)).await {
return cached;
}
let chunk_cache = chunk
.into_iter()
.stream()
.broad_then(async |(shortid, event_id)| {
if let Ok(cached) = self.get_cached_auth_chain(&[shortid]).await {
return cached;
}
let auth_chain: Vec<_> = self
.get_event_auth_chain(room_id, event_id, room_rules)
.collect()
.await;
let auth_chain: Vec<_> = self
.get_auth_chain_inner(room_id, event_id, room_rules)
.collect()
.await;
self.put_cached_auth_chain(once(shortid), auth_chain.as_slice());
debug!(
?event_id,
elapsed = ?started.elapsed(),
"Cache missed event"
);
self.put_cached_auth_chain(&[shortid], auth_chain.as_slice());
debug!(
?event_id,
elapsed = ?started.elapsed(),
"Cache missed event"
);
auth_chain
};
auth_chain
let cache_chain = move |chunk_cache: &Vec<_>| {
self.put_cached_auth_chain(starting_shortids, chunk_cache.as_slice());
debug!(
chunk_cache_length = ?chunk_cache.len(),
elapsed = ?started.elapsed(),
"Cache missed chunk",
);
};
self.get_cached_auth_chain(starting_events.clone().map(at!(0)))
.map_ok(IntoIterator::into_iter)
.map_ok(IterStream::try_stream)
.or_else(move |_| async move {
starting_events
.clone()
.stream()
.broad_then(build_chain)
.collect::<Vec<_>>()
.map(IntoIterator::into_iter)
.map(Iterator::flatten)
.map(Itertools::sorted_unstable)
.map(Itertools::dedup)
.map(Iterator::collect)
.inspect(cache_chain)
.map(IntoIterator::into_iter)
.map(IterStream::try_stream)
.map(Ok)
.await
})
.collect()
.map(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
.map(|mut chunk_cache: Vec<_>| {
chunk_cache.sort_unstable();
chunk_cache.dedup();
chunk_cache
})
.await;
self.put_cached_auth_chain(&chunk_key, chunk_cache.as_slice());
debug!(
chunk_cache_length = ?chunk_cache.len(),
elapsed = ?started.elapsed(),
"Cache missed chunk",
);
Ok(chunk_cache)
.try_flatten_stream()
.map_expect("either cache hit or cache miss yields a chain")
}
#[implement(Service)]
#[tracing::instrument(
name = "inner",
level = "trace",
skip_all,
fields(%event_id),
)]
fn get_auth_chain_inner<'a>(
#[tracing::instrument(name = "inner", level = "trace", skip_all)]
pub fn get_event_auth_chain<'a>(
&'a self,
room_id: &'a RoomId,
event_id: &'a EventId,
room_rules: &'a RoomVersionRules,
) -> impl Stream<Item = ShortEventId> + Send + 'a {
self.get_event_auth_chain_ids(room_id, event_id, room_rules)
.broad_then(async move |auth_event| {
self.services
.short
.get_or_create_shorteventid(&auth_event)
.await
})
}
#[implement(Service)]
#[tracing::instrument(
name = "inner_ids",
level = "trace",
skip_all,
fields(%event_id)
)]
pub fn get_event_auth_chain_ids<'a>(
&'a self,
room_id: &'a RoomId,
event_id: &'a EventId,
room_rules: &'a RoomVersionRules,
) -> impl Stream<Item = OwnedEventId> + Send + 'a {
struct State<Fut> {
todo: FuturesUnordered<Fut>,
seen: HashSet<OwnedEventId>,
}
let starting_events = self.get_event_auth_event_ids(room_id, event_id.to_owned());
let state = State {
todo: once(self.get_pdu(room_id, event_id.to_owned())).collect(),
todo: once(starting_events).collect(),
seen: room_rules
.authorization
.room_create_event_id_as_room_id
@@ -228,12 +282,12 @@ fn get_auth_chain_inner<'a>(
match state.todo.next().await {
| None => None,
| Some(Err(_)) => Some((AuthEvents::new().into_iter().stream(), state)),
| Some(Ok(pdu)) => {
| Some(Ok(auth_events)) => {
let push = |auth_event: &OwnedEventId| {
trace!(?event_id, ?auth_event, "push");
state
.todo
.push(self.get_pdu(room_id, auth_event.clone()));
.push(self.get_event_auth_event_ids(room_id, auth_event.clone()));
};
let seen = |auth_event: OwnedEventId| {
@@ -243,9 +297,8 @@ fn get_auth_chain_inner<'a>(
.then_some(auth_event)
};
let out = pdu
.auth_events()
.map(ToOwned::to_owned)
let out = auth_events
.into_iter()
.filter_map(seen)
.inspect(push)
.collect::<AuthEvents>()
@@ -257,45 +310,29 @@ fn get_auth_chain_inner<'a>(
}
})
.flatten()
.broad_then(async move |auth_event| {
self.services
.short
.get_or_create_shorteventid(&auth_event)
.await
})
}
#[implement(Service)]
async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> Result<PduEvent> {
let pdu = self
.services
.timeline
.get_pdu(&event_id)
.await
.inspect_err(|e| {
debug_error!(?event_id, ?room_id, "auth chain event: {e}");
})?;
#[tracing::instrument(
name = "cache_put",
level = "debug",
skip_all,
fields(
key_len = key.clone().count(),
chain_len = auth_chain.len(),
)
)]
fn put_cached_auth_chain<I>(&self, key: I, auth_chain: &[ShortEventId])
where
I: Iterator<Item = ShortEventId> + Clone + Send,
{
let key = key.collect::<CacheKey>();
if pdu.room_id() != room_id {
return Err!(Request(Forbidden(error!(
?event_id,
?room_id,
wrong_room_id = ?pdu.room_id(),
"auth event for incorrect room",
))));
}
Ok(pdu)
}
#[implement(Service)]
#[tracing::instrument(skip_all, level = "debug")]
fn put_cached_auth_chain(&self, key: &[ShortEventId], auth_chain: &[ShortEventId]) {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
self.db
.authchainkey_authchain
.put(key, auth_chain);
.put(key.as_slice(), auth_chain);
if key.len() == 1 {
self.db
@@ -305,8 +342,21 @@ fn put_cached_auth_chain(&self, key: &[ShortEventId], auth_chain: &[ShortEventId
}
#[implement(Service)]
#[tracing::instrument(skip_all, level = "trace")]
async fn get_cached_auth_chain(&self, key: &[u64]) -> Result<Vec<ShortEventId>> {
#[tracing::instrument(
name = "cache_get",
level = "trace",
err(level = "trace"),
skip_all,
fields(
key_len = %key.clone().count()
),
)]
async fn get_cached_auth_chain<I>(&self, key: I) -> Result<Vec<ShortEventId>>
where
I: Iterator<Item = ShortEventId> + Clone + Send,
{
let key = key.collect::<CacheKey>();
if key.is_empty() {
return Ok(Vec::new());
}
@@ -315,7 +365,7 @@ async fn get_cached_auth_chain(&self, key: &[u64]) -> Result<Vec<ShortEventId>>
let chain = self
.db
.authchainkey_authchain
.qry(key)
.qry(key.as_slice())
.map_err(|_| err!(Request(NotFound("auth_chain not cached"))))
.or_else(async |e| {
if key.len() > 1 {
@@ -328,12 +378,51 @@ async fn get_cached_auth_chain(&self, key: &[u64]) -> Result<Vec<ShortEventId>>
.map_err(|_| err!(Request(NotFound("auth_chain not found"))))
.await
})
.await?;
let chain = chain
.await?
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect();
Ok(chain)
}
#[implement(Service)]
#[tracing::instrument(
name = "auth_events",
level = "trace",
ret(level = "trace"),
err(level = "trace"),
skip_all,
fields(%event_id)
)]
async fn get_event_auth_event_ids<'a>(
&'a self,
room_id: &'a RoomId,
event_id: OwnedEventId,
) -> Result<AuthEvents> {
#[derive(Deserialize)]
struct Pdu {
auth_events: AuthEvents,
room_id: OwnedRoomId,
}
let pdu: Pdu = self
.services
.timeline
.get(&event_id)
.inspect_err(|e| {
debug_error!(?event_id, ?room_id, "auth chain event: {e}");
})
.await?;
if pdu.room_id != room_id {
return Err!(Request(Forbidden(error!(
?event_id,
?room_id,
wrong_room_id = ?pdu.room_id,
"auth event for incorrect room",
))));
}
Ok(pdu.auth_events)
}

View File

@@ -1,4 +1,4 @@
use std::{borrow::Borrow, fmt::Debug, mem::size_of_val, sync::Arc};
use std::{borrow::Borrow, mem::size_of_val, sync::Arc};
use futures::{FutureExt, Stream, StreamExt, pin_mut};
use ruma::{EventId, OwnedRoomId, RoomId, events::StateEventType};
@@ -61,7 +61,7 @@ pub fn multi_get_or_create_shorteventid<'a, I>(
event_ids: I,
) -> impl Stream<Item = ShortEventId> + Send + '_
where
I: Iterator<Item = &'a EventId> + Clone + Debug + Send + 'a,
I: Iterator<Item = &'a EventId> + Clone + Send + 'a,
{
event_ids
.clone()