diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 06d162ee..cc61b5d0 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -426,14 +426,14 @@ pub struct Config { /// single user. /// /// default: 10 - #[serde(default = "default_rc_media_create_per_second")] - pub rc_media_create_per_second: u32, + #[serde(default = "default_media_rc_create_per_second")] + pub media_rc_create_per_second: u32, /// The maximum burst count for media create requests from a single user. /// /// default: 50 - #[serde(default = "default_rc_media_create_burst_count")] - pub rc_media_create_burst_count: u32, + #[serde(default = "default_media_rc_create_burst_count")] + pub media_rc_create_burst_count: u32, /// default: 192 #[serde(default = "default_max_fetch_prev_events")] @@ -3274,8 +3274,8 @@ fn default_ip_lookup_strategy() -> u8 { 5 } fn default_max_request_size() -> usize { 24 * 1024 * 1024 } fn default_max_pending_media_uploads() -> usize { 5 } fn default_media_create_unused_expiration_time() -> u64 { 86400 } -fn default_rc_media_create_per_second() -> u32 { 10 } -fn default_rc_media_create_burst_count() -> u32 { 50 } +fn default_media_rc_create_per_second() -> u32 { 10 } +fn default_media_rc_create_burst_count() -> u32 { 50 } fn default_request_conn_timeout() -> u64 { 10 } diff --git a/src/database/maps.rs b/src/database/maps.rs index 3176818c..5ec270bb 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -130,7 +130,8 @@ pub(super) static MAPS: &[Descriptor] = &[ }, Descriptor { name: "mediaid_pending", - ..descriptor::RANDOM_SMALL + ttl: 60 * 60 * 24 * 7, + ..descriptor::RANDOM_SMALL_CACHE }, Descriptor { name: "mediaid_user", diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 0bd4751e..0e26fb69 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,12 +1,16 @@ use std::{sync::Arc, time::Duration}; use futures::{StreamExt, pin_mut}; -use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; +use ruma::{Mxc, OwnedMxcUri, OwnedUserId, UserId, http_headers::ContentDisposition}; use tuwunel_core::{ - Err, Result, debug, debug_error, debug_info, err, error, - utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes}, + Err, Result, debug, debug_info, err, + utils::{ + ReadyExt, str_from_bytes, + stream::{TryExpect, TryIgnore}, + string_from_bytes, + }, }; -use tuwunel_database::{Database, Interfix, Map, serialize_key}; +use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Map, serialize_key}; use super::{preview::UrlPreviewData, thumbnail::Dim}; @@ -61,102 +65,43 @@ impl Data { user: &UserId, unused_expires_at: u64, ) { - debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending MXC"); + let value = (unused_expires_at, user); + debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending"); - let key = mxc.to_string(); - // 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and - // user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user] - let mut value = Vec::with_capacity( - user.as_bytes() - .len() - .checked_add(9) - .expect("User len too large,capacity overflow!"), - ); - value.extend_from_slice(&unused_expires_at.to_be_bytes()); - value.push(0xFF); - value.extend_from_slice(user.as_bytes()); - self.mediaid_pending - .insert(key.as_bytes(), &value); + self.mediaid_pending.put(mxc, value); } + /// Remove a pending MXC URI from the database + pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { self.mediaid_pending.del(mxc); } + /// Count the number of pending MXC URIs for a specific user - pub(super) async fn count_pending_mxc_for_user(&self, user: &UserId) -> (usize, u64) { - let mut count: usize = 0; - let mut earliest_expiration = u64::MAX; - let user_bytes = user.as_bytes(); + pub(super) async fn count_pending_mxc_for_user(&self, user_id: &UserId) -> (usize, u64) { + type KeyVal<'a> = (Ignore, (u64, &'a UserId)); self.mediaid_pending - .raw_stream() - .ignore_err() - .ready_for_each(|(_key, value)| { - let mut parts = value.splitn(2, |&b| b == 0xFF); - match (parts.next(), parts.next()) { - | (Some(expires_at_bytes), Some(user_id_bytes)) - if user_id_bytes == user_bytes => - { - // safe to add 1 even if count = usize::MAX it also > max_uploads - count = count.saturating_add(1_usize); - let expires_at = - u64::from_be_bytes(expires_at_bytes.try_into().unwrap_or([0_u8; 8])); - if expires_at < earliest_expiration { - earliest_expiration = expires_at; - } - }, - | _ => {}, - } - }) - .await; - - (count, earliest_expiration) + .stream() + .expect_ok() + .ready_filter(|(_, (_, pending_user_id)): &KeyVal<'_>| user_id == *pending_user_id) + .ready_fold( + (0_usize, u64::MAX), + |(count, earliest_expiration), (_, (expires_at, _))| { + (count.saturating_add(1), earliest_expiration.min(expires_at)) + }, + ) + .await } /// Search for a pending MXC URI in the database - pub(super) async fn search_pending_mxc( - &self, - mxc: &Mxc<'_>, - ) -> Option<(ruma::OwnedUserId, u64)> { - debug!(?mxc, "Searching for pending MXC"); + pub(super) async fn search_pending_mxc(&self, mxc: &Mxc<'_>) -> Result<(OwnedUserId, u64)> { + type Value<'a> = (u64, OwnedUserId); - let key = mxc.to_string(); - let value = match self.mediaid_pending.get(key.as_bytes()).await { - | Ok(v) => v, - | Err(e) => { - debug_error!(?key, "pending MXC not found or error: {e}"); - return None; - }, - }; - let mut parts = value.splitn(2, |&b| b == 0xFF); - // value: [expires_at, 0xFF, user] - let expires_at_bytes = parts.next()?; - let expires_at = match expires_at_bytes.try_into() { - | Ok(v) => u64::from_be_bytes(v), - | Err(_) => { - debug_error!(?key, "Failed to parse expires_at for key"); - return None; - }, - }; - let user_id_bytes = parts.next()?; - - let Ok(user_str) = str_from_bytes(user_id_bytes) else { - error!(?key, "Failed to parse user_str for key"); - return None; - }; - let user_id = match user_str.try_into() { - | Ok(v) => v, - | Err(e) => { - error!(?key, "Failed to parse user_id for key: {e}"); - return None; - }, - }; - - debug!(?key, ?user_id, ?expires_at, "Found pending MXC"); - - Some((user_id, expires_at)) - } - - pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { self.mediaid_pending - .remove(mxc.to_string().as_bytes()); + .qry(mxc) + .await + .deserialized() + .map(|(expires_at, user_id): Value<'_>| (user_id, expires_at)) + .inspect(|(user_id, expires_at)| debug!(?mxc, ?user_id, ?expires_at, "Found pending")) + .map_err(|e| err!(Request(NotFound("Pending not found or error: {e}")))) } pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 56e32ffe..d0698999 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -44,7 +44,7 @@ pub struct FileMeta { /// For MSC2246 struct MXCState { /// Save the notifier for each pending media upload - notifiers: Mutex>>, + notifiers: Mutex>>, /// Save the ratelimiter for each user ratelimiter: Mutex>, } @@ -90,22 +90,23 @@ impl crate::Service for Service { impl Service { /// Create a pending media upload ID. + #[tracing::instrument(level = "debug", skip(self))] pub async fn create_pending( &self, mxc: &Mxc<'_>, user: &UserId, unused_expires_at: u64, - ) -> Result<()> { + ) -> Result { let config = &self.services.server.config; // Rate limiting (rc_media_create) - let rate = f64::from(config.rc_media_create_per_second); - let burst = f64::from(config.rc_media_create_burst_count); + let rate = f64::from(config.media_rc_create_per_second); + let burst = f64::from(config.media_rc_create_burst_count); // Check rate limiting if rate > 0.0 && burst > 0.0 { let now = Instant::now(); - let mut ratelimiter = self.mxc_state.ratelimiter.lock().unwrap(); + let mut ratelimiter = self.mxc_state.ratelimiter.lock()?; let (last_time, tokens) = ratelimiter .entry(user.to_owned()) @@ -149,6 +150,7 @@ impl Service { } /// Uploads content to a pending media ID. + #[tracing::instrument(level = "debug", skip(self))] pub async fn upload_pending( &self, mxc: &Mxc<'_>, @@ -157,7 +159,7 @@ impl Service { content_type: Option<&str>, file: &[u8], ) -> Result { - let Some((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else { + let Ok((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else { if self.get_metadata(mxc).await.is_some() { return Err!(Request(CannotOverwriteMedia("Media ID already has content"))); } @@ -179,16 +181,16 @@ impl Service { self.db.remove_pending_mxc(mxc); + let mxc_uri: OwnedMxcUri = mxc.to_string().into(); if let Some(notifier) = self .mxc_state .notifiers .lock()? - .get(&mxc.to_string()) + .get(&mxc_uri) .cloned() { notifier.notify_waiters(); - let mut map = self.mxc_state.notifiers.lock().unwrap(); - map.remove(&mxc.to_string()); + self.mxc_state.notifiers.lock()?.remove(&mxc_uri); } Ok(()) @@ -307,16 +309,17 @@ impl Service { return Ok(Some(meta)); } - let Some(_pending) = self.db.search_pending_mxc(mxc).await else { + let Ok(_pending) = self.db.search_pending_mxc(mxc).await else { return Ok(None); }; - let notifier = { - let mut map = self.mxc_state.notifiers.lock().unwrap(); - map.entry(mxc.to_string()) - .or_insert_with(|| Arc::new(Notify::new())) - .clone() - }; + let notifier = self + .mxc_state + .notifiers + .lock()? + .entry(mxc.to_string().into()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone(); if tokio::time::timeout(timeout_duration, notifier.notified()) .await @@ -339,16 +342,17 @@ impl Service { return Ok(Some(meta)); } - let Some(_pending) = self.db.search_pending_mxc(mxc).await else { + let Ok(_pending) = self.db.search_pending_mxc(mxc).await else { return Ok(None); }; - let notifier = { - let mut map = self.mxc_state.notifiers.lock().unwrap(); - map.entry(mxc.to_string()) - .or_insert_with(|| Arc::new(Notify::new())) - .clone() - }; + let notifier = self + .mxc_state + .notifiers + .lock()? + .entry(mxc.to_string().into()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone(); if tokio::time::timeout(timeout_duration, notifier.notified()) .await diff --git a/tuwunel-example.toml b/tuwunel-example.toml index 36c7635e..dc97b8d8 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -333,11 +333,11 @@ # The maximum number of media create requests per second allowed from a # single user. # -#rc_media_create_per_second = 10 +#media_rc_create_per_second = 10 # The maximum burst count for media create requests from a single user. # -#rc_media_create_burst_count = 50 +#media_rc_create_burst_count = 50 # This item is undocumented. Please contribute documentation for it. #