From 2b81e189cb9972d3ce38614fec70e67090bb1a98 Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Sun, 1 Mar 2026 05:15:09 +0800 Subject: [PATCH 1/5] add MSC2246 support --- src/api/client/media.rs | 90 ++++++++++++++- src/api/client/media_legacy.rs | 18 ++- src/api/router.rs | 2 + src/core/config/mod.rs | 31 ++++++ src/database/maps.rs | 4 + src/service/media/data.rs | 110 ++++++++++++++++++ src/service/media/mod.rs | 198 ++++++++++++++++++++++++++++++++- tuwunel-example.toml | 16 +++ 8 files changed, 461 insertions(+), 8 deletions(-) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 9b1cb956..d63df14f 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -4,13 +4,13 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use reqwest::Url; use ruma::{ - Mxc, UserId, + MilliSecondsSinceUnixEpoch, Mxc, UserId, api::client::{ authenticated_media::{ get_content, get_content_as_filename, get_content_thumbnail, get_media_config, get_media_preview, }, - media::create_content, + media::{create_content, create_content_async, create_mxc_uri}, }, }; use tuwunel_core::{ @@ -80,6 +80,80 @@ pub(crate) async fn create_content_route( }) } +/// # `POST /_matrix/media/v1/create` +/// +/// Create a new MXC URI without content. +#[tracing::instrument( + name = "media_create_mxc", + level = "debug", + skip_all, + fields(%client), +)] +pub(crate) async fn create_mxc_uri_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let user = body.sender_user(); + let mxc = Mxc { + server_name: services.globals.server_name(), + media_id: &utils::random_string(MXC_LENGTH), + }; + + let unused_expires_at = (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() as u64) + .saturating_add( + services + .server + .config + .media_create_unused_expiration_time + * 1000, + ); + services + .media + .create_pending(&mxc, user, unused_expires_at) + .await?; + + Ok(create_mxc_uri::v1::Response { + content_uri: mxc.to_string().into(), + unused_expires_at: ruma::UInt::new(unused_expires_at).map(MilliSecondsSinceUnixEpoch), + }) +} + +/// # `PUT /_matrix/media/v3/upload/{serverName}/{mediaId}` +/// +/// Upload content to a MXC URI that was created earlier. +#[tracing::instrument( + name = "media_upload_async", + level = "debug", + skip_all, + fields(%client), +)] +pub(crate) async fn create_content_async_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let user = body.sender_user(); + let mxc = Mxc { + server_name: &body.server_name, + media_id: &body.media_id, + }; + + let filename = body.filename.as_deref(); + let content_type = body.content_type.as_deref(); + let content_disposition = make_content_disposition(None, content_type, filename); + + services + .media + .upload_pending(&mxc, user, Some(&content_disposition), content_type, &body.file) + .await?; + + Ok(create_content_async::v3::Response {}) +} + /// # `GET /_matrix/client/v1/media/thumbnail/{serverName}/{mediaId}` /// /// Load media thumbnail from our server or over federation. @@ -296,7 +370,11 @@ async fn fetch_thumbnail_meta( timeout_ms: Duration, dim: &Dim, ) -> Result { - if let Some(filemeta) = services.media.get_thumbnail(mxc, dim).await? { + if let Some(filemeta) = services + .media + .get_thumbnail_with_timeout(mxc, dim, timeout_ms) + .await? + { return Ok(filemeta); } @@ -316,7 +394,11 @@ async fn fetch_file_meta( user: &UserId, timeout_ms: Duration, ) -> Result { - if let Some(filemeta) = services.media.get(mxc).await? { + if let Some(filemeta) = services + .media + .get_with_timeout(mxc, timeout_ms) + .await? + { return Ok(filemeta); } diff --git a/src/api/client/media_legacy.rs b/src/api/client/media_legacy.rs index 96e8d1d9..4ffe0be7 100644 --- a/src/api/client/media_legacy.rs +++ b/src/api/client/media_legacy.rs @@ -145,7 +145,11 @@ pub(crate) async fn get_content_legacy_route( media_id: &body.media_id, }; - match services.media.get(&mxc).await? { + match services + .media + .get_with_timeout(&mxc, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, @@ -236,7 +240,11 @@ pub(crate) async fn get_content_as_filename_legacy_route( media_id: &body.media_id, }; - match services.media.get(&mxc).await? { + match services + .media + .get_with_timeout(&mxc, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, @@ -327,7 +335,11 @@ pub(crate) async fn get_content_thumbnail_legacy_route( }; let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?; - match services.media.get_thumbnail(&mxc, &dim).await? { + match services + .media + .get_thumbnail_with_timeout(&mxc, &dim, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, diff --git a/src/api/router.rs b/src/api/router.rs index 8b0e4ace..7263f4ea 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -158,6 +158,8 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::turn_server_route) .ruma_route(&client::send_event_to_device_route) .ruma_route(&client::create_content_route) + .ruma_route(&client::create_mxc_uri_route) + .ruma_route(&client::create_content_async_route) .ruma_route(&client::get_content_thumbnail_route) .ruma_route(&client::get_content_route) .ruma_route(&client::get_content_as_filename_route) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index effca79a..2ff20575 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -408,6 +408,33 @@ pub struct Config { #[serde(default = "default_max_request_size")] pub max_request_size: usize, + /// Maximum number of concurrently pending (asynchronous) media uploads a + /// user can have. + /// + /// default: 5 + #[serde(default = "default_max_pending_media_uploads")] + pub max_pending_media_uploads: usize, + + /// The time in seconds before an unused pending MXC URI expires and is + /// removed. + /// + /// default: 86400 (24 hours) + #[serde(default = "default_media_create_unused_expiration_time")] + pub media_create_unused_expiration_time: u64, + + /// The maximum number of media create requests per second allowed from a + /// single user. + /// + /// default: 10 + #[serde(default = "default_rc_media_create_per_second")] + pub rc_media_create_per_second: u64, + + /// The maximum burst count for media create requests from a single user. + /// + /// default: 50 + #[serde(default = "default_rc_media_create_burst_count")] + pub rc_media_create_burst_count: u64, + /// default: 192 #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, @@ -3245,6 +3272,10 @@ fn default_dns_timeout() -> u64 { 10 } fn default_ip_lookup_strategy() -> u8 { 5 } fn default_max_request_size() -> usize { 24 * 1024 * 1024 } +fn default_max_pending_media_uploads() -> usize { 5 } +fn default_media_create_unused_expiration_time() -> u64 { 86400 } +fn default_rc_media_create_per_second() -> u64 { 10 } +fn default_rc_media_create_burst_count() -> u64 { 50 } fn default_request_conn_timeout() -> u64 { 10 } diff --git a/src/database/maps.rs b/src/database/maps.rs index 51b7d097..3176818c 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -128,6 +128,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "mediaid_file", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "mediaid_pending", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "mediaid_user", ..descriptor::RANDOM_SMALL diff --git a/src/service/media/data.rs b/src/service/media/data.rs index aa2e8ba4..a57cfcc2 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -12,6 +12,7 @@ use super::{preview::UrlPreviewData, thumbnail::Dim}; pub(crate) struct Data { mediaid_file: Arc, + mediaid_pending: Arc, mediaid_user: Arc, url_previews: Arc, } @@ -27,6 +28,7 @@ impl Data { pub(super) fn new(db: &Arc) -> Self { Self { mediaid_file: db["mediaid_file"].clone(), + mediaid_pending: db["mediaid_pending"].clone(), mediaid_user: db["mediaid_user"].clone(), url_previews: db["url_previews"].clone(), } @@ -52,6 +54,114 @@ impl Data { Ok(key.to_vec()) } + /// Insert a pending MXC URI into the database + pub(super) fn insert_pending_mxc( + &self, + mxc: &Mxc<'_>, + user: &UserId, + unused_expires_at: u64, + ) { + let key = mxc.to_string(); + tracing::info!( + "Inserting pending MXC: key={}, user={}, expires_at={}", + key, + user, + unused_expires_at + ); + // 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and + // user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user] + let mut value = Vec::with_capacity(user.as_bytes().len() + 9); + value.extend_from_slice(&unused_expires_at.to_be_bytes()); + value.push(0xFF); + value.extend_from_slice(user.as_bytes()); + self.mediaid_pending + .insert(key.as_bytes(), &value); + } + + /// Count the number of pending MXC URIs for a specific user + pub(super) async fn count_pending_mxc_for_user(&self, user: &UserId) -> (usize, u64) { + let mut count = 0; + let mut earliest_expiration = u64::MAX; + let user_bytes = user.as_bytes(); + + self.mediaid_pending + .raw_stream() + .ignore_err() + .ready_for_each(|(_key, value)| { + let mut parts = value.splitn(2, |&b| b == 0xFF); + if let Some(expires_at_bytes) = parts.next() { + if let Some(user_id_bytes) = parts.next() { + if user_id_bytes == user_bytes { + count += 1; + let expires_at = u64::from_be_bytes( + expires_at_bytes.try_into().unwrap_or([0u8; 8]), + ); + if expires_at < earliest_expiration { + earliest_expiration = expires_at; + } + } + } + } + }) + .await; + + (count, earliest_expiration) + } + + /// Search for a pending MXC URI in the database + pub(super) async fn search_pending_mxc( + &self, + mxc: &Mxc<'_>, + ) -> Option<(ruma::OwnedUserId, u64)> { + let key = mxc.to_string(); + tracing::info!("Searching for pending MXC: key={}", key); + let value = match self.mediaid_pending.get(key.as_bytes()).await { + | Ok(v) => v, + | Err(e) => { + tracing::info!("pending MXC not found or error for {}: {}", key, e); + return None; + }, + }; + let mut parts = value.splitn(2, |&b| b == 0xFF); + // value: [expires_at, 0xFF, user] + let expires_at_bytes = parts.next()?; + let expires_at = match expires_at_bytes.try_into() { + | Ok(v) => u64::from_be_bytes(v), + | Err(_) => { + tracing::error!("Failed to parse expires_at for {}", key); + return None; + }, + }; + let user_id_bytes = parts.next()?; + let user_str = match str_from_bytes(user_id_bytes) { + | Ok(v) => v, + | Err(_) => { + tracing::error!("Failed to parse user_str for {}", key); + return None; + }, + }; + let user_id = match user_str.try_into() { + | Ok(v) => v, + | Err(e) => { + tracing::error!("Failed to parse user_id for {}: {}", key, e); + return None; + }, + }; + + tracing::info!( + "Found pending MXC for {}: user={}, expires_at={}", + key, + user_id, + expires_at + ); + Some((user_id, expires_at)) + } + + pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { + self.mediaid_pending + .remove(mxc.to_string().as_bytes()); + } + pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { debug!("MXC URI: {mxc}"); diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 9cd20600..0ce48a3f 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -5,7 +5,11 @@ mod preview; mod remote; mod tests; mod thumbnail; -use std::{path::PathBuf, sync::Arc, time::SystemTime}; +use std::{ + path::PathBuf, + sync::Arc, + time::{Duration, Instant, SystemTime}, +}; use async_trait::async_trait; use base64::{Engine as _, engine::general_purpose}; @@ -13,6 +17,7 @@ use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; use tokio::{ fs, io::{AsyncReadExt, AsyncWriteExt, BufReader}, + sync::Notify, }; use tuwunel_core::{ Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace, @@ -29,9 +34,18 @@ pub struct FileMeta { pub content_type: Option, pub content_disposition: Option, } +/// For MSC2246 +pub(super) struct MXCState { + /// Save the notifier for each pending media upload + pub(super) notifiers: std::sync::Mutex>>, + /// Save the ratelimiter for each user + pub(super) ratelimiter: + std::sync::Mutex>, +} pub struct Service { url_preview_mutex: MutexMap, + mxc_state: MXCState, pub(super) db: Data, services: Arc, } @@ -50,6 +64,10 @@ impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { Ok(Arc::new(Self { url_preview_mutex: MutexMap::new(), + mxc_state: MXCState { + notifiers: std::sync::Mutex::new(std::collections::HashMap::new()), + ratelimiter: std::sync::Mutex::new(std::collections::HashMap::new()), + }, db: Data::new(args.db), services: args.services.clone(), })) @@ -65,6 +83,119 @@ impl crate::Service for Service { } impl Service { + /// Create a pending media upload ID. + pub async fn create_pending( + &self, + mxc: &Mxc<'_>, + user: &UserId, + unused_expires_at: u64, + ) -> Result<()> { + let config = &self.services.server.config; + + // Rate limiting (rc_media_create) + let rate = config.rc_media_create_per_second as f64; + let burst = config.rc_media_create_burst_count as f64; + + // Check rate limiting + if rate > 0.0 && burst > 0.0 { + let now = Instant::now(); + let mut ratelimiter = self.mxc_state.ratelimiter.lock().unwrap(); + + let (last_time, tokens) = ratelimiter + .entry(user.to_owned()) + .or_insert_with(|| (now, burst)); + + let elapsed = now.duration_since(*last_time).as_secs_f64(); + let new_tokens = (*tokens + elapsed * rate).min(burst); + + if new_tokens >= 1.0 { + *last_time = now; + *tokens = new_tokens - 1.0; + } else { + return Err(tuwunel_core::Error::Request( + ruma::api::client::error::ErrorKind::LimitExceeded { retry_after: None }, + "Too many pending media creation requests.".into(), + http::StatusCode::TOO_MANY_REQUESTS, + )); + } + } + + let max_uploads = config.max_pending_media_uploads; + let (current_uploads, earliest_expiration) = + self.db.count_pending_mxc_for_user(user).await; + + // Check if the user has reached the maximum number of pending media uploads + if current_uploads >= max_uploads { + let retry_after = earliest_expiration.saturating_sub( + SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64, + ); + return Err(tuwunel_core::Error::Request( + ruma::api::client::error::ErrorKind::LimitExceeded { + retry_after: Some(ruma::api::client::error::RetryAfter::Delay( + Duration::from_millis(retry_after), + )), + }, + "Maximum number of pending media uploads reached.".into(), + http::StatusCode::TOO_MANY_REQUESTS, + )); + } + + self.db + .insert_pending_mxc(mxc, user, unused_expires_at); + Ok(()) + } + + /// Uploads content to a pending media ID. + pub async fn upload_pending( + &self, + mxc: &Mxc<'_>, + user: &UserId, + content_disposition: Option<&ContentDisposition>, + content_type: Option<&str>, + file: &[u8], + ) -> Result<()> { + let pending = self.db.search_pending_mxc(mxc).await; + let Some((owner_id, expires_at)) = pending else { + return Err!(Request(NotFound("Media not found"))); + }; + + if owner_id != user { + return Err!(Request(Forbidden("You did not create this media ID"))); + } + + let current_time = SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() as u64; + + if expires_at < current_time { + return Err!(Request(NotFound("Pending media ID expired"))); + } + + self.create(mxc, Some(user), content_disposition, content_type, file) + .await?; + + self.db.remove_pending_mxc(mxc); + + if let Some(notifier) = self + .mxc_state + .notifiers + .lock() + .unwrap() + .get(&mxc.to_string()) + .cloned() + { + notifier.notify_waiters(); + let mut map = self.mxc_state.notifiers.lock().unwrap(); + map.remove(&mxc.to_string()); + } + + Ok(()) + } + /// Uploads a file. pub async fn create( &self, @@ -168,6 +299,71 @@ impl Service { } } + /// Download a file and wait up to a timeout_ms if it is pending. + pub async fn get_with_timeout( + &self, + mxc: &Mxc<'_>, + timeout_duration: Duration, + ) -> Result> { + if let Some(meta) = self.get(mxc).await? { + return Ok(Some(meta)); + } + + let pending = self.db.search_pending_mxc(mxc).await; + if pending.is_none() { + return Ok(None); + } + + let notifier = { + let mut map = self.mxc_state.notifiers.lock().unwrap(); + map.entry(mxc.to_string()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone() + }; + + if tokio::time::timeout(timeout_duration, notifier.notified()) + .await + .is_err() + { + return Err!(Request(NotYetUploaded("Media has not been uploaded yet"))); + } + + self.get(mxc).await + } + + /// Download a thumbnail and wait up to a timeout_ms if it is pending. + pub async fn get_thumbnail_with_timeout( + &self, + mxc: &Mxc<'_>, + dim: &Dim, + timeout_duration: Duration, + ) -> Result> { + if let Some(meta) = self.get_thumbnail(mxc, dim).await? { + return Ok(Some(meta)); + } + + let pending = self.db.search_pending_mxc(mxc).await; + if pending.is_none() { + return Ok(None); + } + + let notifier = { + let mut map = self.mxc_state.notifiers.lock().unwrap(); + map.entry(mxc.to_string()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone() + }; + + if tokio::time::timeout(timeout_duration, notifier.notified()) + .await + .is_err() + { + return Err!(Request(NotYetUploaded("Media has not been uploaded yet"))); + } + + self.get_thumbnail(mxc, dim).await + } + /// Gets all the MXC URIs in our media database pub async fn get_all_mxcs(&self) -> Result> { let all_keys = self.db.get_all_media_keys().await; diff --git a/tuwunel-example.toml b/tuwunel-example.toml index 7a2ae8c9..f0b6c7bd 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -320,6 +320,22 @@ # #max_request_size = 24 MiB +# Maximum number of concurrently pending (asynchronous) media uploads a user can have. +# +#max_pending_media_uploads = 5 + +# The time in seconds before an unused pending MXC URI expires and is removed. +# +#media_create_unused_expiration_time = 86400 (24 hours) + +# The maximum number of media create requests per second allowed from a single user. +# +#rc_media_create_per_second = 10 + +# The maximum burst count for media create requests from a single user. +# +#rc_media_create_burst_count = 50 + # This item is undocumented. Please contribute documentation for it. # #max_fetch_prev_events = 192 From ad896bb091c6eb976b0006f6dc685e36ebf4bfb0 Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Sun, 8 Mar 2026 12:23:02 +0800 Subject: [PATCH 2/5] cllipy fix --- src/api/client/media.rs | 22 +++++++++++--------- src/core/config/mod.rs | 8 ++++---- src/service/media/data.rs | 42 +++++++++++++++++++++------------------ src/service/media/mod.rs | 22 ++++++++++---------- 4 files changed, 52 insertions(+), 42 deletions(-) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index d63df14f..53898b00 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -100,17 +100,21 @@ pub(crate) async fn create_mxc_uri_route( media_id: &utils::random_string(MXC_LENGTH), }; - let unused_expires_at = (std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("Time went backwards") - .as_millis() as u64) - .saturating_add( - services + let unused_expires_at = u64::try_from( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(), + )? + .saturating_add( + services .server .config - .media_create_unused_expiration_time - * 1000, - ); + .media_create_unused_expiration_time + // safe because even if it overflows, it will be greater than the current time + // and the unused media will be deleted anyway + .saturating_mul(1000), + ); services .media .create_pending(&mxc, user, unused_expires_at) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 2ff20575..06d162ee 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -427,13 +427,13 @@ pub struct Config { /// /// default: 10 #[serde(default = "default_rc_media_create_per_second")] - pub rc_media_create_per_second: u64, + pub rc_media_create_per_second: u32, /// The maximum burst count for media create requests from a single user. /// /// default: 50 #[serde(default = "default_rc_media_create_burst_count")] - pub rc_media_create_burst_count: u64, + pub rc_media_create_burst_count: u32, /// default: 192 #[serde(default = "default_max_fetch_prev_events")] @@ -3274,8 +3274,8 @@ fn default_ip_lookup_strategy() -> u8 { 5 } fn default_max_request_size() -> usize { 24 * 1024 * 1024 } fn default_max_pending_media_uploads() -> usize { 5 } fn default_media_create_unused_expiration_time() -> u64 { 86400 } -fn default_rc_media_create_per_second() -> u64 { 10 } -fn default_rc_media_create_burst_count() -> u64 { 50 } +fn default_rc_media_create_per_second() -> u32 { 10 } +fn default_rc_media_create_burst_count() -> u32 { 50 } fn default_request_conn_timeout() -> u64 { 10 } diff --git a/src/service/media/data.rs b/src/service/media/data.rs index a57cfcc2..3de84426 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -70,7 +70,12 @@ impl Data { ); // 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and // user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user] - let mut value = Vec::with_capacity(user.as_bytes().len() + 9); + let mut value = Vec::with_capacity( + user.as_bytes() + .len() + .checked_add(9) + .expect("User len too large,capacity overflow!"), + ); value.extend_from_slice(&unused_expires_at.to_be_bytes()); value.push(0xFF); value.extend_from_slice(user.as_bytes()); @@ -80,7 +85,7 @@ impl Data { /// Count the number of pending MXC URIs for a specific user pub(super) async fn count_pending_mxc_for_user(&self, user: &UserId) -> (usize, u64) { - let mut count = 0; + let mut count: usize = 0; let mut earliest_expiration = u64::MAX; let user_bytes = user.as_bytes(); @@ -89,18 +94,19 @@ impl Data { .ignore_err() .ready_for_each(|(_key, value)| { let mut parts = value.splitn(2, |&b| b == 0xFF); - if let Some(expires_at_bytes) = parts.next() { - if let Some(user_id_bytes) = parts.next() { - if user_id_bytes == user_bytes { - count += 1; - let expires_at = u64::from_be_bytes( - expires_at_bytes.try_into().unwrap_or([0u8; 8]), - ); - if expires_at < earliest_expiration { - earliest_expiration = expires_at; - } + match (parts.next(), parts.next()) { + | (Some(expires_at_bytes), Some(user_id_bytes)) + if user_id_bytes == user_bytes => + { + // safe to add 1 even if count = usize::MAX it also > max_uploads + count = count.saturating_add(1_usize); + let expires_at = + u64::from_be_bytes(expires_at_bytes.try_into().unwrap_or([0_u8; 8])); + if expires_at < earliest_expiration { + earliest_expiration = expires_at; } - } + }, + | _ => {}, } }) .await; @@ -133,12 +139,10 @@ impl Data { }, }; let user_id_bytes = parts.next()?; - let user_str = match str_from_bytes(user_id_bytes) { - | Ok(v) => v, - | Err(_) => { - tracing::error!("Failed to parse user_str for {}", key); - return None; - }, + + let Ok(user_str) = str_from_bytes(user_id_bytes) else { + tracing::error!("Failed to parse user_str for {}", key); + return None; }; let user_id = match user_str.try_into() { | Ok(v) => v, diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 0ce48a3f..81cd2987 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -93,8 +93,8 @@ impl Service { let config = &self.services.server.config; // Rate limiting (rc_media_create) - let rate = config.rc_media_create_per_second as f64; - let burst = config.rc_media_create_burst_count as f64; + let rate = f64::from(config.rc_media_create_per_second); + let burst = f64::from(config.rc_media_create_burst_count); // Check rate limiting if rate > 0.0 && burst > 0.0 { @@ -106,7 +106,7 @@ impl Service { .or_insert_with(|| (now, burst)); let elapsed = now.duration_since(*last_time).as_secs_f64(); - let new_tokens = (*tokens + elapsed * rate).min(burst); + let new_tokens = elapsed.mul_add(rate, *tokens).min(burst); if new_tokens >= 1.0 { *last_time = now; @@ -126,12 +126,12 @@ impl Service { // Check if the user has reached the maximum number of pending media uploads if current_uploads >= max_uploads { - let retry_after = earliest_expiration.saturating_sub( + let retry_after = earliest_expiration.saturating_sub(u64::try_from( SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() - .as_millis() as u64, - ); + .as_millis(), + )?); return Err(tuwunel_core::Error::Request( ruma::api::client::error::ErrorKind::LimitExceeded { retry_after: Some(ruma::api::client::error::RetryAfter::Delay( @@ -166,10 +166,12 @@ impl Service { return Err!(Request(Forbidden("You did not create this media ID"))); } - let current_time = SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("Time went backwards") - .as_millis() as u64; + let current_time = u64::try_from( + SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(), + )?; if expires_at < current_time { return Err!(Request(NotFound("Pending media ID expired"))); From c960a9dbc385f0faa2b749bd641593f8e02f4feb Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Sun, 8 Mar 2026 19:52:11 +0800 Subject: [PATCH 3/5] M_NOT_YET_UPLOAD and can not override,and change the result asyncupload to pass --- src/core/error/response.rs | 6 ++++++ src/service/media/mod.rs | 4 ++++ tests/complement/results.jsonl | 14 +++++++------- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/core/error/response.rs b/src/core/error/response.rs index c41d542c..41cfa72d 100644 --- a/src/core/error/response.rs +++ b/src/core/error/response.rs @@ -60,6 +60,9 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode { // 429 | LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS, + // 409 + | CannotOverwriteMedia => StatusCode::CONFLICT, + // 413 | TooLarge => StatusCode::PAYLOAD_TOO_LARGE, @@ -69,6 +72,9 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode { // 404 | NotFound | NotImplemented | FeatureDisabled => StatusCode::NOT_FOUND, + // 504 + | NotYetUploaded => StatusCode::GATEWAY_TIMEOUT, + // 403 | GuestAccessForbidden | ThreepidAuthFailed diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 81cd2987..9a84a29c 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -159,6 +159,10 @@ impl Service { ) -> Result<()> { let pending = self.db.search_pending_mxc(mxc).await; let Some((owner_id, expires_at)) = pending else { + if self.get_metadata(mxc).await.is_some() { + return Err!(Request(CannotOverwriteMedia("Media ID already has content"))); + } + return Err!(Request(NotFound("Media not found"))); }; diff --git a/tests/complement/results.jsonl b/tests/complement/results.jsonl index 810bb021..8b6bb70b 100644 --- a/tests/complement/results.jsonl +++ b/tests/complement/results.jsonl @@ -9,13 +9,13 @@ {"Action":"pass","Test":"TestArchivedRoomsHistory/timeline_is_empty"} {"Action":"skip","Test":"TestArchivedRoomsHistory/timeline_is_empty/incremental_sync"} {"Action":"pass","Test":"TestArchivedRoomsHistory/timeline_is_empty/initial_sync"} -{"Action":"fail","Test":"TestAsyncUpload"} -{"Action":"fail","Test":"TestAsyncUpload/Cannot_upload_to_a_media_ID_that_has_already_been_uploaded_to"} -{"Action":"fail","Test":"TestAsyncUpload/Create_media"} -{"Action":"fail","Test":"TestAsyncUpload/Download_media"} -{"Action":"fail","Test":"TestAsyncUpload/Download_media_over__matrix/client/v1/media/download"} -{"Action":"fail","Test":"TestAsyncUpload/Not_yet_uploaded"} -{"Action":"fail","Test":"TestAsyncUpload/Upload_media"} +{"Action":"pass","Test":"TestAsyncUpload"} +{"Action":"pass","Test":"TestAsyncUpload/Cannot_upload_to_a_media_ID_that_has_already_been_uploaded_to"} +{"Action":"pass","Test":"TestAsyncUpload/Create_media"} +{"Action":"pass","Test":"TestAsyncUpload/Download_media"} +{"Action":"pass","Test":"TestAsyncUpload/Download_media_over__matrix/client/v1/media/download"} +{"Action":"pass","Test":"TestAsyncUpload/Not_yet_uploaded"} +{"Action":"pass","Test":"TestAsyncUpload/Upload_media"} {"Action":"pass","Test":"TestAvatarUrlUpdate"} {"Action":"pass","Test":"TestBannedUserCannotSendJoin"} {"Action":"skip","Test":"TestCanRegisterAdmin"} From dfedef4d193f8241640ff85333a1faed446b2354 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 8 Mar 2026 14:36:49 +0000 Subject: [PATCH 4/5] Cleanup --- src/api/client/media.rs | 23 +++++----- src/core/error/response.rs | 12 +++--- src/service/media/data.rs | 29 +++++-------- src/service/media/mod.rs | 86 +++++++++++++++++--------------------- tuwunel-example.toml | 9 ++-- 5 files changed, 71 insertions(+), 88 deletions(-) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 53898b00..60b37bf1 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -15,7 +15,10 @@ use ruma::{ }; use tuwunel_core::{ Err, Result, err, - utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize}, + utils::{ + self, content_disposition::make_content_disposition, math::ruma_from_usize, + time::now_millis, + }, }; use tuwunel_service::{ Services, @@ -100,20 +103,14 @@ pub(crate) async fn create_mxc_uri_route( media_id: &utils::random_string(MXC_LENGTH), }; - let unused_expires_at = u64::try_from( - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("Time went backwards") - .as_millis(), - )? - .saturating_add( + // safe because even if it overflows, it will be greater than the current time + // and the unused media will be deleted anyway + let unused_expires_at = now_millis().saturating_add( services - .server - .config + .server + .config .media_create_unused_expiration_time - // safe because even if it overflows, it will be greater than the current time - // and the unused media will be deleted anyway - .saturating_mul(1000), + .saturating_mul(1000), ); services .media diff --git a/src/core/error/response.rs b/src/core/error/response.rs index 41cfa72d..cc15038e 100644 --- a/src/core/error/response.rs +++ b/src/core/error/response.rs @@ -57,24 +57,24 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode { use ErrorKind::*; match kind { + // 504 + | NotYetUploaded => StatusCode::GATEWAY_TIMEOUT, + // 429 | LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS, - // 409 - | CannotOverwriteMedia => StatusCode::CONFLICT, - // 413 | TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + // 409 + | CannotOverwriteMedia => StatusCode::CONFLICT, + // 405 | Unrecognized => StatusCode::METHOD_NOT_ALLOWED, // 404 | NotFound | NotImplemented | FeatureDisabled => StatusCode::NOT_FOUND, - // 504 - | NotYetUploaded => StatusCode::GATEWAY_TIMEOUT, - // 403 | GuestAccessForbidden | ThreepidAuthFailed diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 3de84426..0bd4751e 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration}; use futures::{StreamExt, pin_mut}; use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; use tuwunel_core::{ - Err, Result, debug, debug_info, err, + Err, Result, debug, debug_error, debug_info, err, error, utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes}, }; use tuwunel_database::{Database, Interfix, Map, serialize_key}; @@ -61,13 +61,9 @@ impl Data { user: &UserId, unused_expires_at: u64, ) { + debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending MXC"); + let key = mxc.to_string(); - tracing::info!( - "Inserting pending MXC: key={}, user={}, expires_at={}", - key, - user, - unused_expires_at - ); // 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and // user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user] let mut value = Vec::with_capacity( @@ -119,12 +115,13 @@ impl Data { &self, mxc: &Mxc<'_>, ) -> Option<(ruma::OwnedUserId, u64)> { + debug!(?mxc, "Searching for pending MXC"); + let key = mxc.to_string(); - tracing::info!("Searching for pending MXC: key={}", key); let value = match self.mediaid_pending.get(key.as_bytes()).await { | Ok(v) => v, | Err(e) => { - tracing::info!("pending MXC not found or error for {}: {}", key, e); + debug_error!(?key, "pending MXC not found or error: {e}"); return None; }, }; @@ -134,30 +131,26 @@ impl Data { let expires_at = match expires_at_bytes.try_into() { | Ok(v) => u64::from_be_bytes(v), | Err(_) => { - tracing::error!("Failed to parse expires_at for {}", key); + debug_error!(?key, "Failed to parse expires_at for key"); return None; }, }; let user_id_bytes = parts.next()?; let Ok(user_str) = str_from_bytes(user_id_bytes) else { - tracing::error!("Failed to parse user_str for {}", key); + error!(?key, "Failed to parse user_str for key"); return None; }; let user_id = match user_str.try_into() { | Ok(v) => v, | Err(e) => { - tracing::error!("Failed to parse user_id for {}: {}", key, e); + error!(?key, "Failed to parse user_id for key: {e}"); return None; }, }; - tracing::info!( - "Found pending MXC for {}: user={}, expires_at={}", - key, - user_id, - expires_at - ); + debug!(?key, ?user_id, ?expires_at, "Found pending MXC"); + Some((user_id, expires_at)) } diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 9a84a29c..56e32ffe 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -6,22 +6,28 @@ mod remote; mod tests; mod thumbnail; use std::{ + collections::HashMap, path::PathBuf, - sync::Arc, + sync::{Arc, Mutex}, time::{Duration, Instant, SystemTime}, }; use async_trait::async_trait; use base64::{Engine as _, engine::general_purpose}; -use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; +use http::StatusCode; +use ruma::{ + Mxc, OwnedMxcUri, OwnedUserId, UserId, + api::client::error::{ErrorKind, RetryAfter}, + http_headers::ContentDisposition, +}; use tokio::{ fs, io::{AsyncReadExt, AsyncWriteExt, BufReader}, sync::Notify, }; use tuwunel_core::{ - Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace, - utils::{self, MutexMap}, + Err, Error, Result, debug, debug_error, debug_info, debug_warn, err, error, trace, + utils::{self, MutexMap, time::now_millis}, warn, }; @@ -34,20 +40,20 @@ pub struct FileMeta { pub content_type: Option, pub content_disposition: Option, } + /// For MSC2246 -pub(super) struct MXCState { +struct MXCState { /// Save the notifier for each pending media upload - pub(super) notifiers: std::sync::Mutex>>, + notifiers: Mutex>>, /// Save the ratelimiter for each user - pub(super) ratelimiter: - std::sync::Mutex>, + ratelimiter: Mutex>, } pub struct Service { - url_preview_mutex: MutexMap, - mxc_state: MXCState, pub(super) db: Data, services: Arc, + url_preview_mutex: MutexMap, + mxc_state: MXCState, } /// generated MXC ID (`media-id`) length @@ -63,13 +69,13 @@ pub const CORP_CROSS_ORIGIN: &str = "cross-origin"; impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - url_preview_mutex: MutexMap::new(), - mxc_state: MXCState { - notifiers: std::sync::Mutex::new(std::collections::HashMap::new()), - ratelimiter: std::sync::Mutex::new(std::collections::HashMap::new()), - }, db: Data::new(args.db), services: args.services.clone(), + url_preview_mutex: MutexMap::new(), + mxc_state: MXCState { + notifiers: Mutex::new(HashMap::new()), + ratelimiter: Mutex::new(HashMap::new()), + }, })) } @@ -112,10 +118,10 @@ impl Service { *last_time = now; *tokens = new_tokens - 1.0; } else { - return Err(tuwunel_core::Error::Request( - ruma::api::client::error::ErrorKind::LimitExceeded { retry_after: None }, + return Err(Error::Request( + ErrorKind::LimitExceeded { retry_after: None }, "Too many pending media creation requests.".into(), - http::StatusCode::TOO_MANY_REQUESTS, + StatusCode::TOO_MANY_REQUESTS, )); } } @@ -126,25 +132,19 @@ impl Service { // Check if the user has reached the maximum number of pending media uploads if current_uploads >= max_uploads { - let retry_after = earliest_expiration.saturating_sub(u64::try_from( - SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis(), - )?); - return Err(tuwunel_core::Error::Request( - ruma::api::client::error::ErrorKind::LimitExceeded { - retry_after: Some(ruma::api::client::error::RetryAfter::Delay( - Duration::from_millis(retry_after), - )), + let retry_after = earliest_expiration.saturating_sub(now_millis()); + return Err(Error::Request( + ErrorKind::LimitExceeded { + retry_after: Some(RetryAfter::Delay(Duration::from_millis(retry_after))), }, "Maximum number of pending media uploads reached.".into(), - http::StatusCode::TOO_MANY_REQUESTS, + StatusCode::TOO_MANY_REQUESTS, )); } self.db .insert_pending_mxc(mxc, user, unused_expires_at); + Ok(()) } @@ -156,9 +156,8 @@ impl Service { content_disposition: Option<&ContentDisposition>, content_type: Option<&str>, file: &[u8], - ) -> Result<()> { - let pending = self.db.search_pending_mxc(mxc).await; - let Some((owner_id, expires_at)) = pending else { + ) -> Result { + let Some((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else { if self.get_metadata(mxc).await.is_some() { return Err!(Request(CannotOverwriteMedia("Media ID already has content"))); } @@ -170,13 +169,7 @@ impl Service { return Err!(Request(Forbidden("You did not create this media ID"))); } - let current_time = u64::try_from( - SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("Time went backwards") - .as_millis(), - )?; - + let current_time = now_millis(); if expires_at < current_time { return Err!(Request(NotFound("Pending media ID expired"))); } @@ -189,8 +182,7 @@ impl Service { if let Some(notifier) = self .mxc_state .notifiers - .lock() - .unwrap() + .lock()? .get(&mxc.to_string()) .cloned() { @@ -315,10 +307,9 @@ impl Service { return Ok(Some(meta)); } - let pending = self.db.search_pending_mxc(mxc).await; - if pending.is_none() { + let Some(_pending) = self.db.search_pending_mxc(mxc).await else { return Ok(None); - } + }; let notifier = { let mut map = self.mxc_state.notifiers.lock().unwrap(); @@ -348,10 +339,9 @@ impl Service { return Ok(Some(meta)); } - let pending = self.db.search_pending_mxc(mxc).await; - if pending.is_none() { + let Some(_pending) = self.db.search_pending_mxc(mxc).await else { return Ok(None); - } + }; let notifier = { let mut map = self.mxc_state.notifiers.lock().unwrap(); diff --git a/tuwunel-example.toml b/tuwunel-example.toml index f0b6c7bd..36c7635e 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -320,15 +320,18 @@ # #max_request_size = 24 MiB -# Maximum number of concurrently pending (asynchronous) media uploads a user can have. +# Maximum number of concurrently pending (asynchronous) media uploads a +# user can have. # #max_pending_media_uploads = 5 -# The time in seconds before an unused pending MXC URI expires and is removed. +# The time in seconds before an unused pending MXC URI expires and is +# removed. # #media_create_unused_expiration_time = 86400 (24 hours) -# The maximum number of media create requests per second allowed from a single user. +# The maximum number of media create requests per second allowed from a +# single user. # #rc_media_create_per_second = 10 From d15b30de64a61422b5937202353b31daaf61dc34 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 9 Mar 2026 07:14:05 +0000 Subject: [PATCH 5/5] Simplify database queries. --- src/core/config/mod.rs | 12 ++-- src/database/maps.rs | 3 +- src/service/media/data.rs | 123 +++++++++++--------------------------- src/service/media/mod.rs | 50 +++++++++------- tuwunel-example.toml | 4 +- 5 files changed, 71 insertions(+), 121 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 06d162ee..cc61b5d0 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -426,14 +426,14 @@ pub struct Config { /// single user. /// /// default: 10 - #[serde(default = "default_rc_media_create_per_second")] - pub rc_media_create_per_second: u32, + #[serde(default = "default_media_rc_create_per_second")] + pub media_rc_create_per_second: u32, /// The maximum burst count for media create requests from a single user. /// /// default: 50 - #[serde(default = "default_rc_media_create_burst_count")] - pub rc_media_create_burst_count: u32, + #[serde(default = "default_media_rc_create_burst_count")] + pub media_rc_create_burst_count: u32, /// default: 192 #[serde(default = "default_max_fetch_prev_events")] @@ -3274,8 +3274,8 @@ fn default_ip_lookup_strategy() -> u8 { 5 } fn default_max_request_size() -> usize { 24 * 1024 * 1024 } fn default_max_pending_media_uploads() -> usize { 5 } fn default_media_create_unused_expiration_time() -> u64 { 86400 } -fn default_rc_media_create_per_second() -> u32 { 10 } -fn default_rc_media_create_burst_count() -> u32 { 50 } +fn default_media_rc_create_per_second() -> u32 { 10 } +fn default_media_rc_create_burst_count() -> u32 { 50 } fn default_request_conn_timeout() -> u64 { 10 } diff --git a/src/database/maps.rs b/src/database/maps.rs index 3176818c..5ec270bb 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -130,7 +130,8 @@ pub(super) static MAPS: &[Descriptor] = &[ }, Descriptor { name: "mediaid_pending", - ..descriptor::RANDOM_SMALL + ttl: 60 * 60 * 24 * 7, + ..descriptor::RANDOM_SMALL_CACHE }, Descriptor { name: "mediaid_user", diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 0bd4751e..0e26fb69 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,12 +1,16 @@ use std::{sync::Arc, time::Duration}; use futures::{StreamExt, pin_mut}; -use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; +use ruma::{Mxc, OwnedMxcUri, OwnedUserId, UserId, http_headers::ContentDisposition}; use tuwunel_core::{ - Err, Result, debug, debug_error, debug_info, err, error, - utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes}, + Err, Result, debug, debug_info, err, + utils::{ + ReadyExt, str_from_bytes, + stream::{TryExpect, TryIgnore}, + string_from_bytes, + }, }; -use tuwunel_database::{Database, Interfix, Map, serialize_key}; +use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Map, serialize_key}; use super::{preview::UrlPreviewData, thumbnail::Dim}; @@ -61,102 +65,43 @@ impl Data { user: &UserId, unused_expires_at: u64, ) { - debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending MXC"); + let value = (unused_expires_at, user); + debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending"); - let key = mxc.to_string(); - // 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and - // user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user] - let mut value = Vec::with_capacity( - user.as_bytes() - .len() - .checked_add(9) - .expect("User len too large,capacity overflow!"), - ); - value.extend_from_slice(&unused_expires_at.to_be_bytes()); - value.push(0xFF); - value.extend_from_slice(user.as_bytes()); - self.mediaid_pending - .insert(key.as_bytes(), &value); + self.mediaid_pending.put(mxc, value); } + /// Remove a pending MXC URI from the database + pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { self.mediaid_pending.del(mxc); } + /// Count the number of pending MXC URIs for a specific user - pub(super) async fn count_pending_mxc_for_user(&self, user: &UserId) -> (usize, u64) { - let mut count: usize = 0; - let mut earliest_expiration = u64::MAX; - let user_bytes = user.as_bytes(); + pub(super) async fn count_pending_mxc_for_user(&self, user_id: &UserId) -> (usize, u64) { + type KeyVal<'a> = (Ignore, (u64, &'a UserId)); self.mediaid_pending - .raw_stream() - .ignore_err() - .ready_for_each(|(_key, value)| { - let mut parts = value.splitn(2, |&b| b == 0xFF); - match (parts.next(), parts.next()) { - | (Some(expires_at_bytes), Some(user_id_bytes)) - if user_id_bytes == user_bytes => - { - // safe to add 1 even if count = usize::MAX it also > max_uploads - count = count.saturating_add(1_usize); - let expires_at = - u64::from_be_bytes(expires_at_bytes.try_into().unwrap_or([0_u8; 8])); - if expires_at < earliest_expiration { - earliest_expiration = expires_at; - } - }, - | _ => {}, - } - }) - .await; - - (count, earliest_expiration) + .stream() + .expect_ok() + .ready_filter(|(_, (_, pending_user_id)): &KeyVal<'_>| user_id == *pending_user_id) + .ready_fold( + (0_usize, u64::MAX), + |(count, earliest_expiration), (_, (expires_at, _))| { + (count.saturating_add(1), earliest_expiration.min(expires_at)) + }, + ) + .await } /// Search for a pending MXC URI in the database - pub(super) async fn search_pending_mxc( - &self, - mxc: &Mxc<'_>, - ) -> Option<(ruma::OwnedUserId, u64)> { - debug!(?mxc, "Searching for pending MXC"); + pub(super) async fn search_pending_mxc(&self, mxc: &Mxc<'_>) -> Result<(OwnedUserId, u64)> { + type Value<'a> = (u64, OwnedUserId); - let key = mxc.to_string(); - let value = match self.mediaid_pending.get(key.as_bytes()).await { - | Ok(v) => v, - | Err(e) => { - debug_error!(?key, "pending MXC not found or error: {e}"); - return None; - }, - }; - let mut parts = value.splitn(2, |&b| b == 0xFF); - // value: [expires_at, 0xFF, user] - let expires_at_bytes = parts.next()?; - let expires_at = match expires_at_bytes.try_into() { - | Ok(v) => u64::from_be_bytes(v), - | Err(_) => { - debug_error!(?key, "Failed to parse expires_at for key"); - return None; - }, - }; - let user_id_bytes = parts.next()?; - - let Ok(user_str) = str_from_bytes(user_id_bytes) else { - error!(?key, "Failed to parse user_str for key"); - return None; - }; - let user_id = match user_str.try_into() { - | Ok(v) => v, - | Err(e) => { - error!(?key, "Failed to parse user_id for key: {e}"); - return None; - }, - }; - - debug!(?key, ?user_id, ?expires_at, "Found pending MXC"); - - Some((user_id, expires_at)) - } - - pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { self.mediaid_pending - .remove(mxc.to_string().as_bytes()); + .qry(mxc) + .await + .deserialized() + .map(|(expires_at, user_id): Value<'_>| (user_id, expires_at)) + .inspect(|(user_id, expires_at)| debug!(?mxc, ?user_id, ?expires_at, "Found pending")) + .map_err(|e| err!(Request(NotFound("Pending not found or error: {e}")))) } pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 56e32ffe..d0698999 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -44,7 +44,7 @@ pub struct FileMeta { /// For MSC2246 struct MXCState { /// Save the notifier for each pending media upload - notifiers: Mutex>>, + notifiers: Mutex>>, /// Save the ratelimiter for each user ratelimiter: Mutex>, } @@ -90,22 +90,23 @@ impl crate::Service for Service { impl Service { /// Create a pending media upload ID. + #[tracing::instrument(level = "debug", skip(self))] pub async fn create_pending( &self, mxc: &Mxc<'_>, user: &UserId, unused_expires_at: u64, - ) -> Result<()> { + ) -> Result { let config = &self.services.server.config; // Rate limiting (rc_media_create) - let rate = f64::from(config.rc_media_create_per_second); - let burst = f64::from(config.rc_media_create_burst_count); + let rate = f64::from(config.media_rc_create_per_second); + let burst = f64::from(config.media_rc_create_burst_count); // Check rate limiting if rate > 0.0 && burst > 0.0 { let now = Instant::now(); - let mut ratelimiter = self.mxc_state.ratelimiter.lock().unwrap(); + let mut ratelimiter = self.mxc_state.ratelimiter.lock()?; let (last_time, tokens) = ratelimiter .entry(user.to_owned()) @@ -149,6 +150,7 @@ impl Service { } /// Uploads content to a pending media ID. + #[tracing::instrument(level = "debug", skip(self))] pub async fn upload_pending( &self, mxc: &Mxc<'_>, @@ -157,7 +159,7 @@ impl Service { content_type: Option<&str>, file: &[u8], ) -> Result { - let Some((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else { + let Ok((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else { if self.get_metadata(mxc).await.is_some() { return Err!(Request(CannotOverwriteMedia("Media ID already has content"))); } @@ -179,16 +181,16 @@ impl Service { self.db.remove_pending_mxc(mxc); + let mxc_uri: OwnedMxcUri = mxc.to_string().into(); if let Some(notifier) = self .mxc_state .notifiers .lock()? - .get(&mxc.to_string()) + .get(&mxc_uri) .cloned() { notifier.notify_waiters(); - let mut map = self.mxc_state.notifiers.lock().unwrap(); - map.remove(&mxc.to_string()); + self.mxc_state.notifiers.lock()?.remove(&mxc_uri); } Ok(()) @@ -307,16 +309,17 @@ impl Service { return Ok(Some(meta)); } - let Some(_pending) = self.db.search_pending_mxc(mxc).await else { + let Ok(_pending) = self.db.search_pending_mxc(mxc).await else { return Ok(None); }; - let notifier = { - let mut map = self.mxc_state.notifiers.lock().unwrap(); - map.entry(mxc.to_string()) - .or_insert_with(|| Arc::new(Notify::new())) - .clone() - }; + let notifier = self + .mxc_state + .notifiers + .lock()? + .entry(mxc.to_string().into()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone(); if tokio::time::timeout(timeout_duration, notifier.notified()) .await @@ -339,16 +342,17 @@ impl Service { return Ok(Some(meta)); } - let Some(_pending) = self.db.search_pending_mxc(mxc).await else { + let Ok(_pending) = self.db.search_pending_mxc(mxc).await else { return Ok(None); }; - let notifier = { - let mut map = self.mxc_state.notifiers.lock().unwrap(); - map.entry(mxc.to_string()) - .or_insert_with(|| Arc::new(Notify::new())) - .clone() - }; + let notifier = self + .mxc_state + .notifiers + .lock()? + .entry(mxc.to_string().into()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone(); if tokio::time::timeout(timeout_duration, notifier.notified()) .await diff --git a/tuwunel-example.toml b/tuwunel-example.toml index 36c7635e..dc97b8d8 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -333,11 +333,11 @@ # The maximum number of media create requests per second allowed from a # single user. # -#rc_media_create_per_second = 10 +#media_rc_create_per_second = 10 # The maximum burst count for media create requests from a single user. # -#rc_media_create_burst_count = 50 +#media_rc_create_burst_count = 50 # This item is undocumented. Please contribute documentation for it. #