diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 9b1cb956..60b37bf1 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -4,18 +4,21 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use reqwest::Url; use ruma::{ - Mxc, UserId, + MilliSecondsSinceUnixEpoch, Mxc, UserId, api::client::{ authenticated_media::{ get_content, get_content_as_filename, get_content_thumbnail, get_media_config, get_media_preview, }, - media::create_content, + media::{create_content, create_content_async, create_mxc_uri}, }, }; use tuwunel_core::{ Err, Result, err, - utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize}, + utils::{ + self, content_disposition::make_content_disposition, math::ruma_from_usize, + time::now_millis, + }, }; use tuwunel_service::{ Services, @@ -80,6 +83,78 @@ pub(crate) async fn create_content_route( }) } +/// # `POST /_matrix/media/v1/create` +/// +/// Create a new MXC URI without content. +#[tracing::instrument( + name = "media_create_mxc", + level = "debug", + skip_all, + fields(%client), +)] +pub(crate) async fn create_mxc_uri_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let user = body.sender_user(); + let mxc = Mxc { + server_name: services.globals.server_name(), + media_id: &utils::random_string(MXC_LENGTH), + }; + + // safe because even if it overflows, it will be greater than the current time + // and the unused media will be deleted anyway + let unused_expires_at = now_millis().saturating_add( + services + .server + .config + .media_create_unused_expiration_time + .saturating_mul(1000), + ); + services + .media + .create_pending(&mxc, user, unused_expires_at) + .await?; + + Ok(create_mxc_uri::v1::Response { + content_uri: mxc.to_string().into(), + unused_expires_at: ruma::UInt::new(unused_expires_at).map(MilliSecondsSinceUnixEpoch), + }) +} + +/// # `PUT /_matrix/media/v3/upload/{serverName}/{mediaId}` +/// +/// Upload content to a MXC URI that was created earlier. +#[tracing::instrument( + name = "media_upload_async", + level = "debug", + skip_all, + fields(%client), +)] +pub(crate) async fn create_content_async_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let user = body.sender_user(); + let mxc = Mxc { + server_name: &body.server_name, + media_id: &body.media_id, + }; + + let filename = body.filename.as_deref(); + let content_type = body.content_type.as_deref(); + let content_disposition = make_content_disposition(None, content_type, filename); + + services + .media + .upload_pending(&mxc, user, Some(&content_disposition), content_type, &body.file) + .await?; + + Ok(create_content_async::v3::Response {}) +} + /// # `GET /_matrix/client/v1/media/thumbnail/{serverName}/{mediaId}` /// /// Load media thumbnail from our server or over federation. @@ -296,7 +371,11 @@ async fn fetch_thumbnail_meta( timeout_ms: Duration, dim: &Dim, ) -> Result { - if let Some(filemeta) = services.media.get_thumbnail(mxc, dim).await? { + if let Some(filemeta) = services + .media + .get_thumbnail_with_timeout(mxc, dim, timeout_ms) + .await? + { return Ok(filemeta); } @@ -316,7 +395,11 @@ async fn fetch_file_meta( user: &UserId, timeout_ms: Duration, ) -> Result { - if let Some(filemeta) = services.media.get(mxc).await? { + if let Some(filemeta) = services + .media + .get_with_timeout(mxc, timeout_ms) + .await? + { return Ok(filemeta); } diff --git a/src/api/client/media_legacy.rs b/src/api/client/media_legacy.rs index 96e8d1d9..4ffe0be7 100644 --- a/src/api/client/media_legacy.rs +++ b/src/api/client/media_legacy.rs @@ -145,7 +145,11 @@ pub(crate) async fn get_content_legacy_route( media_id: &body.media_id, }; - match services.media.get(&mxc).await? { + match services + .media + .get_with_timeout(&mxc, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, @@ -236,7 +240,11 @@ pub(crate) async fn get_content_as_filename_legacy_route( media_id: &body.media_id, }; - match services.media.get(&mxc).await? { + match services + .media + .get_with_timeout(&mxc, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, @@ -327,7 +335,11 @@ pub(crate) async fn get_content_thumbnail_legacy_route( }; let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?; - match services.media.get_thumbnail(&mxc, &dim).await? { + match services + .media + .get_thumbnail_with_timeout(&mxc, &dim, body.timeout_ms) + .await? + { | Some(FileMeta { content, content_type, diff --git a/src/api/router.rs b/src/api/router.rs index 8b0e4ace..7263f4ea 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -158,6 +158,8 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::turn_server_route) .ruma_route(&client::send_event_to_device_route) .ruma_route(&client::create_content_route) + .ruma_route(&client::create_mxc_uri_route) + .ruma_route(&client::create_content_async_route) .ruma_route(&client::get_content_thumbnail_route) .ruma_route(&client::get_content_route) .ruma_route(&client::get_content_as_filename_route) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 052bf544..67514f96 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -408,6 +408,33 @@ pub struct Config { #[serde(default = "default_max_request_size")] pub max_request_size: usize, + /// Maximum number of concurrently pending (asynchronous) media uploads a + /// user can have. + /// + /// default: 5 + #[serde(default = "default_max_pending_media_uploads")] + pub max_pending_media_uploads: usize, + + /// The time in seconds before an unused pending MXC URI expires and is + /// removed. + /// + /// default: 86400 (24 hours) + #[serde(default = "default_media_create_unused_expiration_time")] + pub media_create_unused_expiration_time: u64, + + /// The maximum number of media create requests per second allowed from a + /// single user. + /// + /// default: 10 + #[serde(default = "default_media_rc_create_per_second")] + pub media_rc_create_per_second: u32, + + /// The maximum burst count for media create requests from a single user. + /// + /// default: 50 + #[serde(default = "default_media_rc_create_burst_count")] + pub media_rc_create_burst_count: u32, + /// default: 192 #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, @@ -3245,6 +3272,10 @@ fn default_dns_timeout() -> u64 { 10 } fn default_ip_lookup_strategy() -> u8 { 5 } fn default_max_request_size() -> usize { 24 * 1024 * 1024 } +fn default_max_pending_media_uploads() -> usize { 5 } +fn default_media_create_unused_expiration_time() -> u64 { 86400 } +fn default_media_rc_create_per_second() -> u32 { 10 } +fn default_media_rc_create_burst_count() -> u32 { 50 } fn default_request_conn_timeout() -> u64 { 10 } diff --git a/src/core/error/response.rs b/src/core/error/response.rs index c41d542c..cc15038e 100644 --- a/src/core/error/response.rs +++ b/src/core/error/response.rs @@ -57,12 +57,18 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode { use ErrorKind::*; match kind { + // 504 + | NotYetUploaded => StatusCode::GATEWAY_TIMEOUT, + // 429 | LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS, // 413 | TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + // 409 + | CannotOverwriteMedia => StatusCode::CONFLICT, + // 405 | Unrecognized => StatusCode::METHOD_NOT_ALLOWED, diff --git a/src/database/maps.rs b/src/database/maps.rs index 51b7d097..5ec270bb 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -128,6 +128,11 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "mediaid_file", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "mediaid_pending", + ttl: 60 * 60 * 24 * 7, + ..descriptor::RANDOM_SMALL_CACHE + }, Descriptor { name: "mediaid_user", ..descriptor::RANDOM_SMALL diff --git a/src/service/media/data.rs b/src/service/media/data.rs index aa2e8ba4..0e26fb69 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,17 +1,22 @@ use std::{sync::Arc, time::Duration}; use futures::{StreamExt, pin_mut}; -use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; +use ruma::{Mxc, OwnedMxcUri, OwnedUserId, UserId, http_headers::ContentDisposition}; use tuwunel_core::{ Err, Result, debug, debug_info, err, - utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes}, + utils::{ + ReadyExt, str_from_bytes, + stream::{TryExpect, TryIgnore}, + string_from_bytes, + }, }; -use tuwunel_database::{Database, Interfix, Map, serialize_key}; +use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Map, serialize_key}; use super::{preview::UrlPreviewData, thumbnail::Dim}; pub(crate) struct Data { mediaid_file: Arc, + mediaid_pending: Arc, mediaid_user: Arc, url_previews: Arc, } @@ -27,6 +32,7 @@ impl Data { pub(super) fn new(db: &Arc) -> Self { Self { mediaid_file: db["mediaid_file"].clone(), + mediaid_pending: db["mediaid_pending"].clone(), mediaid_user: db["mediaid_user"].clone(), url_previews: db["url_previews"].clone(), } @@ -52,6 +58,52 @@ impl Data { Ok(key.to_vec()) } + /// Insert a pending MXC URI into the database + pub(super) fn insert_pending_mxc( + &self, + mxc: &Mxc<'_>, + user: &UserId, + unused_expires_at: u64, + ) { + let value = (unused_expires_at, user); + debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending"); + + self.mediaid_pending.put(mxc, value); + } + + /// Remove a pending MXC URI from the database + pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { self.mediaid_pending.del(mxc); } + + /// Count the number of pending MXC URIs for a specific user + pub(super) async fn count_pending_mxc_for_user(&self, user_id: &UserId) -> (usize, u64) { + type KeyVal<'a> = (Ignore, (u64, &'a UserId)); + + self.mediaid_pending + .stream() + .expect_ok() + .ready_filter(|(_, (_, pending_user_id)): &KeyVal<'_>| user_id == *pending_user_id) + .ready_fold( + (0_usize, u64::MAX), + |(count, earliest_expiration), (_, (expires_at, _))| { + (count.saturating_add(1), earliest_expiration.min(expires_at)) + }, + ) + .await + } + + /// Search for a pending MXC URI in the database + pub(super) async fn search_pending_mxc(&self, mxc: &Mxc<'_>) -> Result<(OwnedUserId, u64)> { + type Value<'a> = (u64, OwnedUserId); + + self.mediaid_pending + .qry(mxc) + .await + .deserialized() + .map(|(expires_at, user_id): Value<'_>| (user_id, expires_at)) + .inspect(|(user_id, expires_at)| debug!(?mxc, ?user_id, ?expires_at, "Found pending")) + .map_err(|e| err!(Request(NotFound("Pending not found or error: {e}")))) + } + pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) { debug!("MXC URI: {mxc}"); diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 9cd20600..d0698999 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -5,18 +5,29 @@ mod preview; mod remote; mod tests; mod thumbnail; -use std::{path::PathBuf, sync::Arc, time::SystemTime}; +use std::{ + collections::HashMap, + path::PathBuf, + sync::{Arc, Mutex}, + time::{Duration, Instant, SystemTime}, +}; use async_trait::async_trait; use base64::{Engine as _, engine::general_purpose}; -use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; +use http::StatusCode; +use ruma::{ + Mxc, OwnedMxcUri, OwnedUserId, UserId, + api::client::error::{ErrorKind, RetryAfter}, + http_headers::ContentDisposition, +}; use tokio::{ fs, io::{AsyncReadExt, AsyncWriteExt, BufReader}, + sync::Notify, }; use tuwunel_core::{ - Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace, - utils::{self, MutexMap}, + Err, Error, Result, debug, debug_error, debug_info, debug_warn, err, error, trace, + utils::{self, MutexMap, time::now_millis}, warn, }; @@ -30,10 +41,19 @@ pub struct FileMeta { pub content_disposition: Option, } +/// For MSC2246 +struct MXCState { + /// Save the notifier for each pending media upload + notifiers: Mutex>>, + /// Save the ratelimiter for each user + ratelimiter: Mutex>, +} + pub struct Service { - url_preview_mutex: MutexMap, pub(super) db: Data, services: Arc, + url_preview_mutex: MutexMap, + mxc_state: MXCState, } /// generated MXC ID (`media-id`) length @@ -49,9 +69,13 @@ pub const CORP_CROSS_ORIGIN: &str = "cross-origin"; impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { Ok(Arc::new(Self { - url_preview_mutex: MutexMap::new(), db: Data::new(args.db), services: args.services.clone(), + url_preview_mutex: MutexMap::new(), + mxc_state: MXCState { + notifiers: Mutex::new(HashMap::new()), + ratelimiter: Mutex::new(HashMap::new()), + }, })) } @@ -65,6 +89,113 @@ impl crate::Service for Service { } impl Service { + /// Create a pending media upload ID. + #[tracing::instrument(level = "debug", skip(self))] + pub async fn create_pending( + &self, + mxc: &Mxc<'_>, + user: &UserId, + unused_expires_at: u64, + ) -> Result { + let config = &self.services.server.config; + + // Rate limiting (rc_media_create) + let rate = f64::from(config.media_rc_create_per_second); + let burst = f64::from(config.media_rc_create_burst_count); + + // Check rate limiting + if rate > 0.0 && burst > 0.0 { + let now = Instant::now(); + let mut ratelimiter = self.mxc_state.ratelimiter.lock()?; + + let (last_time, tokens) = ratelimiter + .entry(user.to_owned()) + .or_insert_with(|| (now, burst)); + + let elapsed = now.duration_since(*last_time).as_secs_f64(); + let new_tokens = elapsed.mul_add(rate, *tokens).min(burst); + + if new_tokens >= 1.0 { + *last_time = now; + *tokens = new_tokens - 1.0; + } else { + return Err(Error::Request( + ErrorKind::LimitExceeded { retry_after: None }, + "Too many pending media creation requests.".into(), + StatusCode::TOO_MANY_REQUESTS, + )); + } + } + + let max_uploads = config.max_pending_media_uploads; + let (current_uploads, earliest_expiration) = + self.db.count_pending_mxc_for_user(user).await; + + // Check if the user has reached the maximum number of pending media uploads + if current_uploads >= max_uploads { + let retry_after = earliest_expiration.saturating_sub(now_millis()); + return Err(Error::Request( + ErrorKind::LimitExceeded { + retry_after: Some(RetryAfter::Delay(Duration::from_millis(retry_after))), + }, + "Maximum number of pending media uploads reached.".into(), + StatusCode::TOO_MANY_REQUESTS, + )); + } + + self.db + .insert_pending_mxc(mxc, user, unused_expires_at); + + Ok(()) + } + + /// Uploads content to a pending media ID. + #[tracing::instrument(level = "debug", skip(self))] + pub async fn upload_pending( + &self, + mxc: &Mxc<'_>, + user: &UserId, + content_disposition: Option<&ContentDisposition>, + content_type: Option<&str>, + file: &[u8], + ) -> Result { + let Ok((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else { + if self.get_metadata(mxc).await.is_some() { + return Err!(Request(CannotOverwriteMedia("Media ID already has content"))); + } + + return Err!(Request(NotFound("Media not found"))); + }; + + if owner_id != user { + return Err!(Request(Forbidden("You did not create this media ID"))); + } + + let current_time = now_millis(); + if expires_at < current_time { + return Err!(Request(NotFound("Pending media ID expired"))); + } + + self.create(mxc, Some(user), content_disposition, content_type, file) + .await?; + + self.db.remove_pending_mxc(mxc); + + let mxc_uri: OwnedMxcUri = mxc.to_string().into(); + if let Some(notifier) = self + .mxc_state + .notifiers + .lock()? + .get(&mxc_uri) + .cloned() + { + notifier.notify_waiters(); + self.mxc_state.notifiers.lock()?.remove(&mxc_uri); + } + + Ok(()) + } + /// Uploads a file. pub async fn create( &self, @@ -168,6 +299,71 @@ impl Service { } } + /// Download a file and wait up to a timeout_ms if it is pending. + pub async fn get_with_timeout( + &self, + mxc: &Mxc<'_>, + timeout_duration: Duration, + ) -> Result> { + if let Some(meta) = self.get(mxc).await? { + return Ok(Some(meta)); + } + + let Ok(_pending) = self.db.search_pending_mxc(mxc).await else { + return Ok(None); + }; + + let notifier = self + .mxc_state + .notifiers + .lock()? + .entry(mxc.to_string().into()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone(); + + if tokio::time::timeout(timeout_duration, notifier.notified()) + .await + .is_err() + { + return Err!(Request(NotYetUploaded("Media has not been uploaded yet"))); + } + + self.get(mxc).await + } + + /// Download a thumbnail and wait up to a timeout_ms if it is pending. + pub async fn get_thumbnail_with_timeout( + &self, + mxc: &Mxc<'_>, + dim: &Dim, + timeout_duration: Duration, + ) -> Result> { + if let Some(meta) = self.get_thumbnail(mxc, dim).await? { + return Ok(Some(meta)); + } + + let Ok(_pending) = self.db.search_pending_mxc(mxc).await else { + return Ok(None); + }; + + let notifier = self + .mxc_state + .notifiers + .lock()? + .entry(mxc.to_string().into()) + .or_insert_with(|| Arc::new(Notify::new())) + .clone(); + + if tokio::time::timeout(timeout_duration, notifier.notified()) + .await + .is_err() + { + return Err!(Request(NotYetUploaded("Media has not been uploaded yet"))); + } + + self.get_thumbnail(mxc, dim).await + } + /// Gets all the MXC URIs in our media database pub async fn get_all_mxcs(&self) -> Result> { let all_keys = self.db.get_all_media_keys().await; diff --git a/tests/complement/results.jsonl b/tests/complement/results.jsonl index 810bb021..8b6bb70b 100644 --- a/tests/complement/results.jsonl +++ b/tests/complement/results.jsonl @@ -9,13 +9,13 @@ {"Action":"pass","Test":"TestArchivedRoomsHistory/timeline_is_empty"} {"Action":"skip","Test":"TestArchivedRoomsHistory/timeline_is_empty/incremental_sync"} {"Action":"pass","Test":"TestArchivedRoomsHistory/timeline_is_empty/initial_sync"} -{"Action":"fail","Test":"TestAsyncUpload"} -{"Action":"fail","Test":"TestAsyncUpload/Cannot_upload_to_a_media_ID_that_has_already_been_uploaded_to"} -{"Action":"fail","Test":"TestAsyncUpload/Create_media"} -{"Action":"fail","Test":"TestAsyncUpload/Download_media"} -{"Action":"fail","Test":"TestAsyncUpload/Download_media_over__matrix/client/v1/media/download"} -{"Action":"fail","Test":"TestAsyncUpload/Not_yet_uploaded"} -{"Action":"fail","Test":"TestAsyncUpload/Upload_media"} +{"Action":"pass","Test":"TestAsyncUpload"} +{"Action":"pass","Test":"TestAsyncUpload/Cannot_upload_to_a_media_ID_that_has_already_been_uploaded_to"} +{"Action":"pass","Test":"TestAsyncUpload/Create_media"} +{"Action":"pass","Test":"TestAsyncUpload/Download_media"} +{"Action":"pass","Test":"TestAsyncUpload/Download_media_over__matrix/client/v1/media/download"} +{"Action":"pass","Test":"TestAsyncUpload/Not_yet_uploaded"} +{"Action":"pass","Test":"TestAsyncUpload/Upload_media"} {"Action":"pass","Test":"TestAvatarUrlUpdate"} {"Action":"pass","Test":"TestBannedUserCannotSendJoin"} {"Action":"skip","Test":"TestCanRegisterAdmin"} diff --git a/tuwunel-example.toml b/tuwunel-example.toml index 7a2ae8c9..dc97b8d8 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -320,6 +320,25 @@ # #max_request_size = 24 MiB +# Maximum number of concurrently pending (asynchronous) media uploads a +# user can have. +# +#max_pending_media_uploads = 5 + +# The time in seconds before an unused pending MXC URI expires and is +# removed. +# +#media_create_unused_expiration_time = 86400 (24 hours) + +# The maximum number of media create requests per second allowed from a +# single user. +# +#media_rc_create_per_second = 10 + +# The maximum burst count for media create requests from a single user. +# +#media_rc_create_burst_count = 50 + # This item is undocumented. Please contribute documentation for it. # #max_fetch_prev_events = 192