From 2b81e189cb9972d3ce38614fec70e67090bb1a98 Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Sun, 1 Mar 2026 05:15:09 +0800 Subject: [PATCH] add MSC2246 support --- src/api/client/media.rs | 90 ++++++++++++++- src/api/client/media_legacy.rs | 18 ++- src/api/router.rs | 2 + src/core/config/mod.rs | 31 ++++++ src/database/maps.rs | 4 + src/service/media/data.rs | 110 ++++++++++++++++++ src/service/media/mod.rs | 198 ++++++++++++++++++++++++++++++++- tuwunel-example.toml | 16 +++ 8 files changed, 461 insertions(+), 8 deletions(-) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 9b1cb956..d63df14f 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -4,13 +4,13 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use reqwest::Url; use ruma::{ - Mxc, UserId, + MilliSecondsSinceUnixEpoch, Mxc, UserId, api::client::{ authenticated_media::{ get_content, get_content_as_filename, get_content_thumbnail, get_media_config, get_media_preview, }, - media::create_content, + media::{create_content, create_content_async, create_mxc_uri}, }, }; use tuwunel_core::{ @@ -80,6 +80,80 @@ pub(crate) async fn create_content_route( }) } +/// # `POST /_matrix/media/v1/create` +/// +/// Create a new MXC URI without content. +#[tracing::instrument( + name = "media_create_mxc", + level = "debug", + skip_all, + fields(%client), +)] +pub(crate) async fn create_mxc_uri_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let user = body.sender_user(); + let mxc = Mxc { + server_name: services.globals.server_name(), + media_id: &utils::random_string(MXC_LENGTH), + }; + + let unused_expires_at = (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() as u64) + .saturating_add( + services + .server + .config + .media_create_unused_expiration_time + * 1000, + ); + services + .media + .create_pending(&mxc, user, unused_expires_at) + .await?; + + Ok(create_mxc_uri::v1::Response { + content_uri: mxc.to_string().into(), + unused_expires_at: ruma::UInt::new(unused_expires_at).map(MilliSecondsSinceUnixEpoch), + }) +} + +/// # `PUT /_matrix/media/v3/upload/{serverName}/{mediaId}` +/// +/// Upload content to a MXC URI that was created earlier. +#[tracing::instrument( + name = "media_upload_async", + level = "debug", + skip_all, + fields(%client), +)] +pub(crate) async fn create_content_async_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let user = body.sender_user(); + let mxc = Mxc { + server_name: &body.server_name, + media_id: &body.media_id, + }; + + let filename = body.filename.as_deref(); + let content_type = body.content_type.as_deref(); + let content_disposition = make_content_disposition(None, content_type, filename); + + services + .media + .upload_pending(&mxc, user, Some(&content_disposition), content_type, &body.file) + .await?; + + Ok(create_content_async::v3::Response {}) +} + /// # `GET /_matrix/client/v1/media/thumbnail/{serverName}/{mediaId}` /// /// Load media thumbnail from our server or over federation. @@ -296,7 +370,11 @@ async fn fetch_thumbnail_meta( timeout_ms: Duration, dim: &Dim, ) -> Result { - if let Some(filemeta) = services.media.get_thumbnail(mxc, dim).await? { + if let Some(filemeta) = services + .media + .get_thumbnail_with_timeout(mxc, dim, timeout_ms) + .await? + { return Ok(filemeta); } @@ -316,7 +394,11 @@ async fn fetch_file_meta( user: &UserId, timeout_ms: Duration, ) -> Result { - if let Some(filemeta) = services.media.get(mxc).await? { + if let Some(filemeta) = services + .media + .get_with_timeout(mxc, timeout_ms) + .await? + { return Ok(filemeta); } diff --git a/src/api/client/media_legacy.rs b/src/api/client/media_legacy.rs index 96e8d1d9..4ffe0be7 100644 --- a/src/api/client/media_legacy.rs +++ b/src/api/client/media_legacy.rs @@ -145,7 +145,11 @@ pub(crate) async fn get_content_legacy_route( media_id: &body.media_id, }; - match services.media.get(&mxc).await? { + match services + .media + .get_with_timeout(&mxc, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, @@ -236,7 +240,11 @@ pub(crate) async fn get_content_as_filename_legacy_route( media_id: &body.media_id, }; - match services.media.get(&mxc).await? { + match services + .media + .get_with_timeout(&mxc, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, @@ -327,7 +335,11 @@ pub(crate) async fn get_content_thumbnail_legacy_route( }; let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?; - match services.media.get_thumbnail(&mxc, &dim).await? { + match services + .media + .get_thumbnail_with_timeout(&mxc, &dim, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, diff --git a/src/api/router.rs b/src/api/router.rs index 8b0e4ace..7263f4ea 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -158,6 +158,8 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::turn_server_route) .ruma_route(&client::send_event_to_device_route) .ruma_route(&client::create_content_route) + .ruma_route(&client::create_mxc_uri_route) + .ruma_route(&client::create_content_async_route) .ruma_route(&client::get_content_thumbnail_route) .ruma_route(&client::get_content_route) .ruma_route(&client::get_content_as_filename_route) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index effca79a..2ff20575 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -408,6 +408,33 @@ pub struct Config { #[serde(default = "default_max_request_size")] pub max_request_size: usize, + /// Maximum number of concurrently pending (asynchronous) media uploads a + /// user can have. + /// + /// default: 5 + #[serde(default = "default_max_pending_media_uploads")] + pub max_pending_media_uploads: usize, + + /// The time in seconds before an unused pending MXC URI expires and is + /// removed. + /// + /// default: 86400 (24 hours) + #[serde(default = "default_media_create_unused_expiration_time")] + pub media_create_unused_expiration_time: u64, + + /// The maximum number of media create requests per second allowed from a + /// single user. + /// + /// default: 10 + #[serde(default = "default_rc_media_create_per_second")] + pub rc_media_create_per_second: u64, + + /// The maximum burst count for media create requests from a single user. + /// + /// default: 50 + #[serde(default = "default_rc_media_create_burst_count")] + pub rc_media_create_burst_count: u64, + /// default: 192 #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, @@ -3245,6 +3272,10 @@ fn default_dns_timeout() -> u64 { 10 } fn default_ip_lookup_strategy() -> u8 { 5 } fn default_max_request_size() -> usize { 24 * 1024 * 1024 } +fn default_max_pending_media_uploads() -> usize { 5 } +fn default_media_create_unused_expiration_time() -> u64 { 86400 } +fn default_rc_media_create_per_second() -> u64 { 10 } +fn default_rc_media_create_burst_count() -> u64 { 50 } fn default_request_conn_timeout() -> u64 { 10 } diff --git a/src/database/maps.rs b/src/database/maps.rs index 51b7d097..3176818c 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -128,6 +128,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "mediaid_file", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "mediaid_pending", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "mediaid_user", ..descriptor::RANDOM_SMALL diff --git a/src/service/media/data.rs b/src/service/media/data.rs index aa2e8ba4..a57cfcc2 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -12,6 +12,7 @@ use super::{preview::UrlPreviewData, thumbnail::Dim}; pub(crate) struct Data { mediaid_file: Arc, + mediaid_pending: Arc, mediaid_user: Arc, url_previews: Arc, } @@ -27,6 +28,7 @@ impl Data { pub(super) fn new(db: &Arc) -> Self { Self { mediaid_file: db["mediaid_file"].clone(), + mediaid_pending: db["mediaid_pending"].clone(), mediaid_user: db["mediaid_user"].clone(), url_previews: db["url_previews"].clone(), } @@ -52,6 +54,114 @@ impl Data { Ok(key.to_vec()) } + /// Insert a pending MXC URI into the database + pub(super) fn insert_pending_mxc( + &self, + mxc: &Mxc<'_>, + user: &UserId, + unused_expires_at: u64, + ) { + let key = mxc.to_string(); + tracing::info!( + "Inserting pending MXC: key={}, user={}, expires_at={}", + key, + user, + unused_expires_at + ); + // 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and + // user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user] + let mut value = Vec::with_capacity(user.as_bytes().len() + 9); + value.extend_from_slice(&unused_expires_at.to_be_bytes()); + value.push(0xFF); + value.extend_from_slice(user.as_bytes()); + self.mediaid_pending + .insert(key.as_bytes(), &value); + } + + /// Count the number of pending MXC URIs for a specific user + pub(super) async fn count_pending_mxc_for_user(&self, user: &UserId) -> (usize, u64) { + let mut count = 0; + let mut earliest_expiration = u64::MAX; + let user_bytes = user.as_bytes(); + + self.mediaid_pending + .raw_stream() + .ignore_err() + .ready_for_each(|(_key, value)| { + let mut parts = value.splitn(2, |&b| b == 0xFF); + if let Some(expires_at_bytes) = parts.next() { + if let Some(user_id_bytes) = parts.next() { + if user_id_bytes == user_bytes { + count += 1; + let expires_at = u64::from_be_bytes( + expires_at_bytes.try_into().unwrap_or([0u8; 8]), + ); + if expires_at < earliest_expiration { + earliest_expiration = expires_at; + } + } + } + } + }) + .await; + + (count, earliest_expiration) + } + + /// Search for a pending MXC URI in the database + pub(super) async fn search_pending_mxc( + &self, + mxc: &Mxc<'_>, + ) -> Option<(ruma::OwnedUserId, u64)> { + let key = mxc.to_string(); + tracing::info!("Searching for pending MXC: key={}", key); + let value = match self.mediaid_pending.get(key.as_bytes()).await { + | Ok(v) => v, + | Err(e) => { + tracing::info!("pending MXC not found or error for {}: {}", key, e); + return None; + }, + }; + let mut parts = value.splitn(2, |&b| b == 0xFF); + // value: [expires_at, 0xFF, user] + let expires_at_bytes = parts.next()?; + let expires_at = match expires_at_bytes.try_into() { + | Ok(v) => u64::from_be_bytes(v), + | Err(_) => { + tracing::error!("Failed to parse expires_at for {}", key); + return None; + }, + }; + let user_id_bytes = parts.next()?; + let user_str = match str_from_bytes(user_id_bytes) { + | Ok(v) => v, + | Err(_) => { + tracing::error!("Failed to parse user_str for {}", key); + return None; + }, + }; + let user_id = match user_str.try_into() { + | Ok(v) => v, + | Err(e) => { + tracing::error!("Failed to parse user_id for {}: {}", key, e); + return None; + }, + }; + + tracing::info!( + "Found pending MXC for {}: user={}, expires_at={}", + key, + user_id, + expires_at + ); + Some((user_id, expires_at)) + } + + pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { + self.mediaid_pending + .remove(mxc.to_string().as_bytes()); + } + pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { debug!("MXC URI: {mxc}"); diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 9cd20600..0ce48a3f 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -5,7 +5,11 @@ mod preview; mod remote; mod tests; mod thumbnail; -use std::{path::PathBuf, sync::Arc, time::SystemTime}; +use std::{ + path::PathBuf, + sync::Arc, + time::{Duration, Instant, SystemTime}, +}; use async_trait::async_trait; use base64::{Engine as _, engine::general_purpose}; @@ -13,6 +17,7 @@ use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; use tokio::{ fs, io::{AsyncReadExt, AsyncWriteExt, BufReader}, + sync::Notify, }; use tuwunel_core::{ Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace, @@ -29,9 +34,18 @@ pub struct FileMeta { pub content_type: Option, pub content_disposition: Option, } +/// For MSC2246 +pub(super) struct MXCState { + /// Save the notifier for each pending media upload + pub(super) notifiers: std::sync::Mutex>>, + /// Save the ratelimiter for each user + pub(super) ratelimiter: + std::sync::Mutex>, +} pub struct Service { url_preview_mutex: MutexMap, + mxc_state: MXCState, pub(super) db: Data, services: Arc, } @@ -50,6 +64,10 @@ impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { Ok(Arc::new(Self { url_preview_mutex: MutexMap::new(), + mxc_state: MXCState { + notifiers: std::sync::Mutex::new(std::collections::HashMap::new()), + ratelimiter: std::sync::Mutex::new(std::collections::HashMap::new()), + }, db: Data::new(args.db), services: args.services.clone(), })) @@ -65,6 +83,119 @@ impl crate::Service for Service { } impl Service { + /// Create a pending media upload ID. + pub async fn create_pending( + &self, + mxc: &Mxc<'_>, + user: &UserId, + unused_expires_at: u64, + ) -> Result<()> { + let config = &self.services.server.config; + + // Rate limiting (rc_media_create) + let rate = config.rc_media_create_per_second as f64; + let burst = config.rc_media_create_burst_count as f64; + + // Check rate limiting + if rate > 0.0 && burst > 0.0 { + let now = Instant::now(); + let mut ratelimiter = self.mxc_state.ratelimiter.lock().unwrap(); + + let (last_time, tokens) = ratelimiter + .entry(user.to_owned()) + .or_insert_with(|| (now, burst)); + + let elapsed = now.duration_since(*last_time).as_secs_f64(); + let new_tokens = (*tokens + elapsed * rate).min(burst); + + if new_tokens >= 1.0 { + *last_time = now; + *tokens = new_tokens - 1.0; + } else { + return Err(tuwunel_core::Error::Request( + ruma::api::client::error::ErrorKind::LimitExceeded { retry_after: None }, + "Too many pending media creation requests.".into(), + http::StatusCode::TOO_MANY_REQUESTS, + )); + } + } + + let max_uploads = config.max_pending_media_uploads; + let (current_uploads, earliest_expiration) = + self.db.count_pending_mxc_for_user(user).await; + + // Check if the user has reached the maximum number of pending media uploads + if current_uploads >= max_uploads { + let retry_after = earliest_expiration.saturating_sub( + SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() as u64, + ); + return Err(tuwunel_core::Error::Request( + ruma::api::client::error::ErrorKind::LimitExceeded { + retry_after: Some(ruma::api::client::error::RetryAfter::Delay( + Duration::from_millis(retry_after), + )), + }, + "Maximum number of pending media uploads reached.".into(), + http::StatusCode::TOO_MANY_REQUESTS, + )); + } + + self.db + .insert_pending_mxc(mxc, user, unused_expires_at); + Ok(()) + } + + /// Uploads content to a pending media ID. + pub async fn upload_pending( + &self, + mxc: &Mxc<'_>, + user: &UserId, + content_disposition: Option<&ContentDisposition>, + content_type: Option<&str>, + file: &[u8], + ) -> Result<()> { + let pending = self.db.search_pending_mxc(mxc).await; + let Some((owner_id, expires_at)) = pending else { + return Err!(Request(NotFound("Media not found"))); + }; + + if owner_id != user { + return Err!(Request(Forbidden("You did not create this media ID"))); + } + + let current_time = SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() as u64; + + if expires_at < current_time { + return Err!(Request(NotFound("Pending media ID expired"))); + } + + self.create(mxc, Some(user), content_disposition, content_type, file) + .await?; + + self.db.remove_pending_mxc(mxc); + + if let Some(notifier) = self + .mxc_state + .notifiers + .lock() + .unwrap() + .get(&mxc.to_string()) + .cloned() + { + notifier.notify_waiters(); + let mut map = self.mxc_state.notifiers.lock().unwrap(); + map.remove(&mxc.to_string()); + } + + Ok(()) + } + /// Uploads a file. pub async fn create( &self, @@ -168,6 +299,71 @@ impl Service { } } + /// Download a file and wait up to a timeout_ms if it is pending. + pub async fn get_with_timeout( + &self, + mxc: &Mxc<'_>, + timeout_duration: Duration, + ) -> Result> { + if let Some(meta) = self.get(mxc).await? { + return Ok(Some(meta)); + } + + let pending = self.db.search_pending_mxc(mxc).await; + if pending.is_none() { + return Ok(None); + } + + let notifier = { + let mut map = self.mxc_state.notifiers.lock().unwrap(); + map.entry(mxc.to_string()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone() + }; + + if tokio::time::timeout(timeout_duration, notifier.notified()) + .await + .is_err() + { + return Err!(Request(NotYetUploaded("Media has not been uploaded yet"))); + } + + self.get(mxc).await + } + + /// Download a thumbnail and wait up to a timeout_ms if it is pending. + pub async fn get_thumbnail_with_timeout( + &self, + mxc: &Mxc<'_>, + dim: &Dim, + timeout_duration: Duration, + ) -> Result> { + if let Some(meta) = self.get_thumbnail(mxc, dim).await? { + return Ok(Some(meta)); + } + + let pending = self.db.search_pending_mxc(mxc).await; + if pending.is_none() { + return Ok(None); + } + + let notifier = { + let mut map = self.mxc_state.notifiers.lock().unwrap(); + map.entry(mxc.to_string()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone() + }; + + if tokio::time::timeout(timeout_duration, notifier.notified()) + .await + .is_err() + { + return Err!(Request(NotYetUploaded("Media has not been uploaded yet"))); + } + + self.get_thumbnail(mxc, dim).await + } + /// Gets all the MXC URIs in our media database pub async fn get_all_mxcs(&self) -> Result> { let all_keys = self.db.get_all_media_keys().await; diff --git a/tuwunel-example.toml b/tuwunel-example.toml index 7a2ae8c9..f0b6c7bd 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -320,6 +320,22 @@ # #max_request_size = 24 MiB +# Maximum number of concurrently pending (asynchronous) media uploads a user can have. +# +#max_pending_media_uploads = 5 + +# The time in seconds before an unused pending MXC URI expires and is removed. +# +#media_create_unused_expiration_time = 86400 (24 hours) + +# The maximum number of media create requests per second allowed from a single user. +# +#rc_media_create_per_second = 10 + +# The maximum burst count for media create requests from a single user. +# +#rc_media_create_burst_count = 50 + # This item is undocumented. Please contribute documentation for it. # #max_fetch_prev_events = 192