Move auth_chain cache to db.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-21 16:24:53 +00:00
parent 7b22e7930a
commit afcb2315ee
4 changed files with 87 additions and 149 deletions

View File

@@ -202,11 +202,12 @@ fn get_cache(ctx: &Context, desc: &Descriptor) -> Option<Cache> {
| "eventid_pduid" => Some(config.eventid_pdu_cache_capacity),
| "eventid_shorteventid" => Some(config.eventidshort_cache_capacity),
| "shorteventid_eventid" => Some(config.shorteventid_cache_capacity),
| "shorteventid_authchain" => Some(config.auth_chain_cache_capacity),
| "shortstatekey_statekey" => Some(config.shortstatekey_cache_capacity),
| "statekey_shortstatekey" => Some(config.statekeyshort_cache_capacity),
| "servernameevent_data" => Some(config.servernameevent_data_cache_capacity),
| "pduid_pdu" | "eventid_outlierpdu" => Some(config.pdu_cache_capacity),
| "shorteventid_authchain" | "authchainkey_authchain" =>
Some(config.auth_chain_cache_capacity),
| _ => None,
}
.map(TryInto::try_into)

View File

@@ -34,6 +34,15 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "aliasid_alias",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "authchainkey_authchain",
cache_disp: CacheDisp::SharedWith("shorteventid_authchain"),
index_size: 512,
block_size: 4096,
key_size_hint: Some(8),
val_size_hint: Some(1024),
..descriptor::RANDOM_CACHE
},
Descriptor {
name: "backupid_algorithm",
..descriptor::RANDOM_SMALL
@@ -279,8 +288,11 @@ pub(super) static MAPS: &[Descriptor] = &[
},
Descriptor {
name: "shorteventid_authchain",
cache_disp: CacheDisp::Unique,
cache_disp: CacheDisp::SharedWith("authchainkey_authchain"),
key_size_hint: Some(8),
val_size_hint: Some(1024),
index_size: 512,
block_size: 4096,
..descriptor::SEQUENTIAL
},
Descriptor {

View File

@@ -1,92 +0,0 @@
use std::{
mem::size_of,
sync::{Arc, Mutex},
};
use lru_cache::LruCache;
use tuwunel_core::{Err, Result, err, utils, utils::math::usize_from_f64};
use tuwunel_database::Map;
use crate::rooms::short::ShortEventId;
pub(super) struct Data {
shorteventid_authchain: Arc<Map>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>,
}
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
let config = &args.server.config;
let cache_size = f64::from(config.auth_chain_cache_capacity);
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier)
.expect("valid cache size");
Self {
shorteventid_authchain: db["shorteventid_authchain"].clone(),
auth_chain_cache: Mutex::new(LruCache::new(cache_size)),
}
}
pub(super) async fn get_cached_eventid_authchain(
&self,
key: &[u64],
) -> Result<Arc<[ShortEventId]>> {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Check RAM cache
if let Some(result) = self
.auth_chain_cache
.lock()
.expect("cache locked")
.get_mut(key)
{
return Ok(Arc::clone(result));
}
// We only save auth chains for single events in the db
if key.len() != 1 {
return Err!(Request(NotFound("auth_chain not cached")));
}
// Check database
let chain = self
.shorteventid_authchain
.qry(&key[0])
.await
.map_err(|_| err!(Request(NotFound("auth_chain not found"))))?;
let chain = chain
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect::<Arc<[u64]>>();
// Cache in RAM
self.auth_chain_cache
.lock()
.expect("cache locked")
.insert(vec![key[0]], Arc::clone(&chain));
Ok(chain)
}
pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[ShortEventId]>) {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Only persist single events in db
if key.len() == 1 {
let key = key[0].to_be_bytes();
let val = auth_chain
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>();
self.shorteventid_authchain.insert(&key, &val);
}
// Cache in RAM
self.auth_chain_cache
.lock()
.expect("cache locked")
.insert(key, auth_chain);
}
}

View File

@@ -1,5 +1,3 @@
mod data;
use std::{
collections::{BTreeSet, HashSet},
fmt::Debug,
@@ -8,24 +6,25 @@ use std::{
time::Instant,
};
use async_trait::async_trait;
use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
stream::{FuturesUnordered, unfold},
};
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules};
use tuwunel_core::{
Err, Result, at, debug, debug_error, implement,
Err, Result, at, debug, debug_error, err, implement,
matrix::{Event, PduEvent},
pdu::AuthEvents,
trace,
trace, utils,
utils::{
IterStream,
stream::{BroadbandExt, ReadyExt, TryBroadbandExt},
},
validated, warn,
};
use tuwunel_database::Map;
use self::data::Data;
use crate::rooms::short::ShortEventId;
pub struct Service {
@@ -33,16 +32,27 @@ pub struct Service {
db: Data,
}
struct Data {
authchainkey_authchain: Arc<Map>,
shorteventid_authchain: Arc<Map>,
}
type Bucket<'a> = BTreeSet<(u64, &'a EventId)>;
#[async_trait]
impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: args.services.clone(),
db: Data::new(args),
db: Data {
authchainkey_authchain: args.db["authchainkey_authchain"].clone(),
shorteventid_authchain: args.db["shorteventid_authchain"].clone(),
},
}))
}
async fn clear_cache(&self) { self.db.authchainkey_authchain.clear().await; }
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
@@ -66,7 +76,12 @@ where
}
#[implement(Service)]
#[tracing::instrument(name = "auth_chain", level = "debug", skip_all)]
#[tracing::instrument(
name = "auth_chain",
level = "debug",
skip_all,
fields(%room_id),
)]
pub async fn get_auth_chain<'a, I>(
&'a self,
room_id: &RoomId,
@@ -138,26 +153,16 @@ async fn get_auth_chain_outer<'a>(
) -> Result<Vec<ShortEventId>> {
let chunk_key: Vec<ShortEventId> = chunk.iter().map(at!(0)).collect();
if chunk_key.is_empty() {
return Ok(Vec::new());
}
if let Ok(cached) = self
.get_cached_eventid_authchain(&chunk_key)
.await
{
return Ok(cached.to_vec());
if let Ok(cached) = self.get_cached_auth_chain(&chunk_key).await {
return Ok(cached);
}
let chunk_cache = chunk
.into_iter()
.stream()
.broad_then(async |(shortid, event_id)| {
if let Ok(cached) = self
.get_cached_eventid_authchain(&[shortid])
.await
{
return cached.to_vec();
if let Ok(cached) = self.get_cached_auth_chain(&[shortid]).await {
return cached;
}
let auth_chain: Vec<_> = self
@@ -165,7 +170,7 @@ async fn get_auth_chain_outer<'a>(
.collect()
.await;
self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
self.put_cached_auth_chain(&[shortid], auth_chain.as_slice());
debug!(
?event_id,
elapsed = ?started.elapsed(),
@@ -183,7 +188,7 @@ async fn get_auth_chain_outer<'a>(
})
.await;
self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
self.put_cached_auth_chain(&chunk_key, chunk_cache.as_slice());
debug!(
chunk_cache_length = ?chunk_cache.len(),
elapsed = ?started.elapsed(),
@@ -286,40 +291,52 @@ async fn get_pdu<'a>(&'a self, room_id: &'a RoomId, event_id: OwnedEventId) -> R
Ok(pdu)
}
#[implement(Service)]
#[inline]
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> {
self.db.get_cached_eventid_authchain(key).await
}
#[implement(Service)]
#[tracing::instrument(skip_all, level = "debug")]
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<ShortEventId>) {
let val: Arc<[ShortEventId]> = auth_chain.iter().copied().collect();
fn put_cached_auth_chain(&self, key: &[ShortEventId], auth_chain: &[ShortEventId]) {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
self.db.cache_auth_chain(key, val);
}
#[implement(Service)]
#[tracing::instrument(skip_all, level = "debug")]
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &[ShortEventId]) {
let val: Arc<[ShortEventId]> = auth_chain.iter().copied().collect();
self.db.cache_auth_chain(key, val);
}
#[implement(Service)]
pub fn get_cache_usage(&self) -> (usize, usize) {
let cache = self.db.auth_chain_cache.lock().expect("locked");
(cache.len(), cache.capacity())
}
#[implement(Service)]
pub fn clear_cache(&self) {
self.db
.auth_chain_cache
.lock()
.expect("locked")
.clear();
.authchainkey_authchain
.put(key, auth_chain);
if key.len() == 1 {
self.db
.shorteventid_authchain
.put(key, auth_chain);
}
}
#[implement(Service)]
#[tracing::instrument(skip_all, level = "trace")]
async fn get_cached_auth_chain(&self, key: &[u64]) -> Result<Vec<ShortEventId>> {
if key.is_empty() {
return Ok(Vec::new());
}
// Check cache. On miss, check first-order table for single-event keys.
let chain = self
.db
.authchainkey_authchain
.qry(key)
.map_err(|_| err!(Request(NotFound("auth_chain not cached"))))
.or_else(async |e| {
if key.len() > 1 {
return Err(e);
}
self.db
.shorteventid_authchain
.qry(&key[0])
.map_err(|_| err!(Request(NotFound("auth_chain not found"))))
.await
})
.await?;
let chain = chain
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect();
Ok(chain)
}