Simplify database queries.

This commit is contained in:
Jason Volk
2026-03-09 07:14:05 +00:00
parent dfedef4d19
commit d15b30de64
5 changed files with 71 additions and 121 deletions

View File

@@ -426,14 +426,14 @@ pub struct Config {
/// single user. /// single user.
/// ///
/// default: 10 /// default: 10
#[serde(default = "default_rc_media_create_per_second")] #[serde(default = "default_media_rc_create_per_second")]
pub rc_media_create_per_second: u32, pub media_rc_create_per_second: u32,
/// The maximum burst count for media create requests from a single user. /// The maximum burst count for media create requests from a single user.
/// ///
/// default: 50 /// default: 50
#[serde(default = "default_rc_media_create_burst_count")] #[serde(default = "default_media_rc_create_burst_count")]
pub rc_media_create_burst_count: u32, pub media_rc_create_burst_count: u32,
/// default: 192 /// default: 192
#[serde(default = "default_max_fetch_prev_events")] #[serde(default = "default_max_fetch_prev_events")]
@@ -3274,8 +3274,8 @@ fn default_ip_lookup_strategy() -> u8 { 5 }
fn default_max_request_size() -> usize { 24 * 1024 * 1024 } fn default_max_request_size() -> usize { 24 * 1024 * 1024 }
fn default_max_pending_media_uploads() -> usize { 5 } fn default_max_pending_media_uploads() -> usize { 5 }
fn default_media_create_unused_expiration_time() -> u64 { 86400 } fn default_media_create_unused_expiration_time() -> u64 { 86400 }
fn default_rc_media_create_per_second() -> u32 { 10 } fn default_media_rc_create_per_second() -> u32 { 10 }
fn default_rc_media_create_burst_count() -> u32 { 50 } fn default_media_rc_create_burst_count() -> u32 { 50 }
fn default_request_conn_timeout() -> u64 { 10 } fn default_request_conn_timeout() -> u64 { 10 }

View File

@@ -130,7 +130,8 @@ pub(super) static MAPS: &[Descriptor] = &[
}, },
Descriptor { Descriptor {
name: "mediaid_pending", name: "mediaid_pending",
..descriptor::RANDOM_SMALL ttl: 60 * 60 * 24 * 7,
..descriptor::RANDOM_SMALL_CACHE
}, },
Descriptor { Descriptor {
name: "mediaid_user", name: "mediaid_user",

View File

@@ -1,12 +1,16 @@
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use futures::{StreamExt, pin_mut}; use futures::{StreamExt, pin_mut};
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; use ruma::{Mxc, OwnedMxcUri, OwnedUserId, UserId, http_headers::ContentDisposition};
use tuwunel_core::{ use tuwunel_core::{
Err, Result, debug, debug_error, debug_info, err, error, Err, Result, debug, debug_info, err,
utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes}, utils::{
ReadyExt, str_from_bytes,
stream::{TryExpect, TryIgnore},
string_from_bytes,
},
}; };
use tuwunel_database::{Database, Interfix, Map, serialize_key}; use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Map, serialize_key};
use super::{preview::UrlPreviewData, thumbnail::Dim}; use super::{preview::UrlPreviewData, thumbnail::Dim};
@@ -61,102 +65,43 @@ impl Data {
user: &UserId, user: &UserId,
unused_expires_at: u64, unused_expires_at: u64,
) { ) {
debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending MXC"); let value = (unused_expires_at, user);
debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending");
let key = mxc.to_string(); self.mediaid_pending.put(mxc, value);
// 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and
// user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user]
let mut value = Vec::with_capacity(
user.as_bytes()
.len()
.checked_add(9)
.expect("User len too large,capacity overflow!"),
);
value.extend_from_slice(&unused_expires_at.to_be_bytes());
value.push(0xFF);
value.extend_from_slice(user.as_bytes());
self.mediaid_pending
.insert(key.as_bytes(), &value);
} }
/// Remove a pending MXC URI from the database
pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { self.mediaid_pending.del(mxc); }
/// Count the number of pending MXC URIs for a specific user /// Count the number of pending MXC URIs for a specific user
pub(super) async fn count_pending_mxc_for_user(&self, user: &UserId) -> (usize, u64) { pub(super) async fn count_pending_mxc_for_user(&self, user_id: &UserId) -> (usize, u64) {
let mut count: usize = 0; type KeyVal<'a> = (Ignore, (u64, &'a UserId));
let mut earliest_expiration = u64::MAX;
let user_bytes = user.as_bytes();
self.mediaid_pending self.mediaid_pending
.raw_stream() .stream()
.ignore_err() .expect_ok()
.ready_for_each(|(_key, value)| { .ready_filter(|(_, (_, pending_user_id)): &KeyVal<'_>| user_id == *pending_user_id)
let mut parts = value.splitn(2, |&b| b == 0xFF); .ready_fold(
match (parts.next(), parts.next()) { (0_usize, u64::MAX),
| (Some(expires_at_bytes), Some(user_id_bytes)) |(count, earliest_expiration), (_, (expires_at, _))| {
if user_id_bytes == user_bytes => (count.saturating_add(1), earliest_expiration.min(expires_at))
{ },
// safe to add 1 even if count = usize::MAX it also > max_uploads )
count = count.saturating_add(1_usize); .await
let expires_at =
u64::from_be_bytes(expires_at_bytes.try_into().unwrap_or([0_u8; 8]));
if expires_at < earliest_expiration {
earliest_expiration = expires_at;
}
},
| _ => {},
}
})
.await;
(count, earliest_expiration)
} }
/// Search for a pending MXC URI in the database /// Search for a pending MXC URI in the database
pub(super) async fn search_pending_mxc( pub(super) async fn search_pending_mxc(&self, mxc: &Mxc<'_>) -> Result<(OwnedUserId, u64)> {
&self, type Value<'a> = (u64, OwnedUserId);
mxc: &Mxc<'_>,
) -> Option<(ruma::OwnedUserId, u64)> {
debug!(?mxc, "Searching for pending MXC");
let key = mxc.to_string();
let value = match self.mediaid_pending.get(key.as_bytes()).await {
| Ok(v) => v,
| Err(e) => {
debug_error!(?key, "pending MXC not found or error: {e}");
return None;
},
};
let mut parts = value.splitn(2, |&b| b == 0xFF);
// value: [expires_at, 0xFF, user]
let expires_at_bytes = parts.next()?;
let expires_at = match expires_at_bytes.try_into() {
| Ok(v) => u64::from_be_bytes(v),
| Err(_) => {
debug_error!(?key, "Failed to parse expires_at for key");
return None;
},
};
let user_id_bytes = parts.next()?;
let Ok(user_str) = str_from_bytes(user_id_bytes) else {
error!(?key, "Failed to parse user_str for key");
return None;
};
let user_id = match user_str.try_into() {
| Ok(v) => v,
| Err(e) => {
error!(?key, "Failed to parse user_id for key: {e}");
return None;
},
};
debug!(?key, ?user_id, ?expires_at, "Found pending MXC");
Some((user_id, expires_at))
}
pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) {
self.mediaid_pending self.mediaid_pending
.remove(mxc.to_string().as_bytes()); .qry(mxc)
.await
.deserialized()
.map(|(expires_at, user_id): Value<'_>| (user_id, expires_at))
.inspect(|(user_id, expires_at)| debug!(?mxc, ?user_id, ?expires_at, "Found pending"))
.map_err(|e| err!(Request(NotFound("Pending not found or error: {e}"))))
} }
pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) {

View File

@@ -44,7 +44,7 @@ pub struct FileMeta {
/// For MSC2246 /// For MSC2246
struct MXCState { struct MXCState {
/// Save the notifier for each pending media upload /// Save the notifier for each pending media upload
notifiers: Mutex<HashMap<String, Arc<Notify>>>, notifiers: Mutex<HashMap<OwnedMxcUri, Arc<Notify>>>,
/// Save the ratelimiter for each user /// Save the ratelimiter for each user
ratelimiter: Mutex<HashMap<OwnedUserId, (Instant, f64)>>, ratelimiter: Mutex<HashMap<OwnedUserId, (Instant, f64)>>,
} }
@@ -90,22 +90,23 @@ impl crate::Service for Service {
impl Service { impl Service {
/// Create a pending media upload ID. /// Create a pending media upload ID.
#[tracing::instrument(level = "debug", skip(self))]
pub async fn create_pending( pub async fn create_pending(
&self, &self,
mxc: &Mxc<'_>, mxc: &Mxc<'_>,
user: &UserId, user: &UserId,
unused_expires_at: u64, unused_expires_at: u64,
) -> Result<()> { ) -> Result {
let config = &self.services.server.config; let config = &self.services.server.config;
// Rate limiting (rc_media_create) // Rate limiting (rc_media_create)
let rate = f64::from(config.rc_media_create_per_second); let rate = f64::from(config.media_rc_create_per_second);
let burst = f64::from(config.rc_media_create_burst_count); let burst = f64::from(config.media_rc_create_burst_count);
// Check rate limiting // Check rate limiting
if rate > 0.0 && burst > 0.0 { if rate > 0.0 && burst > 0.0 {
let now = Instant::now(); let now = Instant::now();
let mut ratelimiter = self.mxc_state.ratelimiter.lock().unwrap(); let mut ratelimiter = self.mxc_state.ratelimiter.lock()?;
let (last_time, tokens) = ratelimiter let (last_time, tokens) = ratelimiter
.entry(user.to_owned()) .entry(user.to_owned())
@@ -149,6 +150,7 @@ impl Service {
} }
/// Uploads content to a pending media ID. /// Uploads content to a pending media ID.
#[tracing::instrument(level = "debug", skip(self))]
pub async fn upload_pending( pub async fn upload_pending(
&self, &self,
mxc: &Mxc<'_>, mxc: &Mxc<'_>,
@@ -157,7 +159,7 @@ impl Service {
content_type: Option<&str>, content_type: Option<&str>,
file: &[u8], file: &[u8],
) -> Result { ) -> Result {
let Some((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else { let Ok((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else {
if self.get_metadata(mxc).await.is_some() { if self.get_metadata(mxc).await.is_some() {
return Err!(Request(CannotOverwriteMedia("Media ID already has content"))); return Err!(Request(CannotOverwriteMedia("Media ID already has content")));
} }
@@ -179,16 +181,16 @@ impl Service {
self.db.remove_pending_mxc(mxc); self.db.remove_pending_mxc(mxc);
let mxc_uri: OwnedMxcUri = mxc.to_string().into();
if let Some(notifier) = self if let Some(notifier) = self
.mxc_state .mxc_state
.notifiers .notifiers
.lock()? .lock()?
.get(&mxc.to_string()) .get(&mxc_uri)
.cloned() .cloned()
{ {
notifier.notify_waiters(); notifier.notify_waiters();
let mut map = self.mxc_state.notifiers.lock().unwrap(); self.mxc_state.notifiers.lock()?.remove(&mxc_uri);
map.remove(&mxc.to_string());
} }
Ok(()) Ok(())
@@ -307,16 +309,17 @@ impl Service {
return Ok(Some(meta)); return Ok(Some(meta));
} }
let Some(_pending) = self.db.search_pending_mxc(mxc).await else { let Ok(_pending) = self.db.search_pending_mxc(mxc).await else {
return Ok(None); return Ok(None);
}; };
let notifier = { let notifier = self
let mut map = self.mxc_state.notifiers.lock().unwrap(); .mxc_state
map.entry(mxc.to_string()) .notifiers
.or_insert_with(|| Arc::new(Notify::new())) .lock()?
.clone() .entry(mxc.to_string().into())
}; .or_insert_with(|| Arc::new(Notify::new()))
.clone();
if tokio::time::timeout(timeout_duration, notifier.notified()) if tokio::time::timeout(timeout_duration, notifier.notified())
.await .await
@@ -339,16 +342,17 @@ impl Service {
return Ok(Some(meta)); return Ok(Some(meta));
} }
let Some(_pending) = self.db.search_pending_mxc(mxc).await else { let Ok(_pending) = self.db.search_pending_mxc(mxc).await else {
return Ok(None); return Ok(None);
}; };
let notifier = { let notifier = self
let mut map = self.mxc_state.notifiers.lock().unwrap(); .mxc_state
map.entry(mxc.to_string()) .notifiers
.or_insert_with(|| Arc::new(Notify::new())) .lock()?
.clone() .entry(mxc.to_string().into())
}; .or_insert_with(|| Arc::new(Notify::new()))
.clone();
if tokio::time::timeout(timeout_duration, notifier.notified()) if tokio::time::timeout(timeout_duration, notifier.notified())
.await .await

View File

@@ -333,11 +333,11 @@
# The maximum number of media create requests per second allowed from a # The maximum number of media create requests per second allowed from a
# single user. # single user.
# #
#rc_media_create_per_second = 10 #media_rc_create_per_second = 10
# The maximum burst count for media create requests from a single user. # The maximum burst count for media create requests from a single user.
# #
#rc_media_create_burst_count = 50 #media_rc_create_burst_count = 50
# This item is undocumented. Please contribute documentation for it. # This item is undocumented. Please contribute documentation for it.
# #