Merge branch 'donjuanplatinum/msc2246'

This commit is contained in:
Jason Volk
2026-03-10 02:37:12 +00:00
10 changed files with 430 additions and 24 deletions

View File

@@ -4,18 +4,21 @@ use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use reqwest::Url;
use ruma::{
Mxc, UserId,
MilliSecondsSinceUnixEpoch, Mxc, UserId,
api::client::{
authenticated_media::{
get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
get_media_preview,
},
media::create_content,
media::{create_content, create_content_async, create_mxc_uri},
},
};
use tuwunel_core::{
Err, Result, err,
utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize},
utils::{
self, content_disposition::make_content_disposition, math::ruma_from_usize,
time::now_millis,
},
};
use tuwunel_service::{
Services,
@@ -80,6 +83,78 @@ pub(crate) async fn create_content_route(
})
}
/// # `POST /_matrix/media/v1/create`
///
/// Create a new MXC URI without content.
#[tracing::instrument(
name = "media_create_mxc",
level = "debug",
skip_all,
fields(%client),
)]
pub(crate) async fn create_mxc_uri_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_mxc_uri::v1::Request>,
) -> Result<create_mxc_uri::v1::Response> {
let user = body.sender_user();
let mxc = Mxc {
server_name: services.globals.server_name(),
media_id: &utils::random_string(MXC_LENGTH),
};
// safe because even if it overflows, it will be greater than the current time
// and the unused media will be deleted anyway
let unused_expires_at = now_millis().saturating_add(
services
.server
.config
.media_create_unused_expiration_time
.saturating_mul(1000),
);
services
.media
.create_pending(&mxc, user, unused_expires_at)
.await?;
Ok(create_mxc_uri::v1::Response {
content_uri: mxc.to_string().into(),
unused_expires_at: ruma::UInt::new(unused_expires_at).map(MilliSecondsSinceUnixEpoch),
})
}
/// # `PUT /_matrix/media/v3/upload/{serverName}/{mediaId}`
///
/// Upload content to a MXC URI that was created earlier.
#[tracing::instrument(
name = "media_upload_async",
level = "debug",
skip_all,
fields(%client),
)]
pub(crate) async fn create_content_async_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_content_async::v3::Request>,
) -> Result<create_content_async::v3::Response> {
let user = body.sender_user();
let mxc = Mxc {
server_name: &body.server_name,
media_id: &body.media_id,
};
let filename = body.filename.as_deref();
let content_type = body.content_type.as_deref();
let content_disposition = make_content_disposition(None, content_type, filename);
services
.media
.upload_pending(&mxc, user, Some(&content_disposition), content_type, &body.file)
.await?;
Ok(create_content_async::v3::Response {})
}
/// # `GET /_matrix/client/v1/media/thumbnail/{serverName}/{mediaId}`
///
/// Load media thumbnail from our server or over federation.
@@ -296,7 +371,11 @@ async fn fetch_thumbnail_meta(
timeout_ms: Duration,
dim: &Dim,
) -> Result<FileMeta> {
if let Some(filemeta) = services.media.get_thumbnail(mxc, dim).await? {
if let Some(filemeta) = services
.media
.get_thumbnail_with_timeout(mxc, dim, timeout_ms)
.await?
{
return Ok(filemeta);
}
@@ -316,7 +395,11 @@ async fn fetch_file_meta(
user: &UserId,
timeout_ms: Duration,
) -> Result<FileMeta> {
if let Some(filemeta) = services.media.get(mxc).await? {
if let Some(filemeta) = services
.media
.get_with_timeout(mxc, timeout_ms)
.await?
{
return Ok(filemeta);
}

View File

@@ -145,7 +145,11 @@ pub(crate) async fn get_content_legacy_route(
media_id: &body.media_id,
};
match services.media.get(&mxc).await? {
match services
.media
.get_with_timeout(&mxc, body.timeout_ms)
.await?
{
| Some(FileMeta {
content,
content_type,
@@ -236,7 +240,11 @@ pub(crate) async fn get_content_as_filename_legacy_route(
media_id: &body.media_id,
};
match services.media.get(&mxc).await? {
match services
.media
.get_with_timeout(&mxc, body.timeout_ms)
.await?
{
| Some(FileMeta {
content,
content_type,
@@ -327,7 +335,11 @@ pub(crate) async fn get_content_thumbnail_legacy_route(
};
let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?;
match services.media.get_thumbnail(&mxc, &dim).await? {
match services
.media
.get_thumbnail_with_timeout(&mxc, &dim, body.timeout_ms)
.await?
{
| Some(FileMeta {
content,
content_type,

View File

@@ -158,6 +158,8 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::turn_server_route)
.ruma_route(&client::send_event_to_device_route)
.ruma_route(&client::create_content_route)
.ruma_route(&client::create_mxc_uri_route)
.ruma_route(&client::create_content_async_route)
.ruma_route(&client::get_content_thumbnail_route)
.ruma_route(&client::get_content_route)
.ruma_route(&client::get_content_as_filename_route)

View File

@@ -408,6 +408,33 @@ pub struct Config {
#[serde(default = "default_max_request_size")]
pub max_request_size: usize,
/// Maximum number of concurrently pending (asynchronous) media uploads a
/// user can have.
///
/// default: 5
#[serde(default = "default_max_pending_media_uploads")]
pub max_pending_media_uploads: usize,
/// The time in seconds before an unused pending MXC URI expires and is
/// removed.
///
/// default: 86400 (24 hours)
#[serde(default = "default_media_create_unused_expiration_time")]
pub media_create_unused_expiration_time: u64,
/// The maximum number of media create requests per second allowed from a
/// single user.
///
/// default: 10
#[serde(default = "default_media_rc_create_per_second")]
pub media_rc_create_per_second: u32,
/// The maximum burst count for media create requests from a single user.
///
/// default: 50
#[serde(default = "default_media_rc_create_burst_count")]
pub media_rc_create_burst_count: u32,
/// default: 192
#[serde(default = "default_max_fetch_prev_events")]
pub max_fetch_prev_events: u16,
@@ -3245,6 +3272,10 @@ fn default_dns_timeout() -> u64 { 10 }
fn default_ip_lookup_strategy() -> u8 { 5 }
fn default_max_request_size() -> usize { 24 * 1024 * 1024 }
fn default_max_pending_media_uploads() -> usize { 5 }
fn default_media_create_unused_expiration_time() -> u64 { 86400 }
fn default_media_rc_create_per_second() -> u32 { 10 }
fn default_media_rc_create_burst_count() -> u32 { 50 }
fn default_request_conn_timeout() -> u64 { 10 }

View File

@@ -57,12 +57,18 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode {
use ErrorKind::*;
match kind {
// 504
| NotYetUploaded => StatusCode::GATEWAY_TIMEOUT,
// 429
| LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS,
// 413
| TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
// 409
| CannotOverwriteMedia => StatusCode::CONFLICT,
// 405
| Unrecognized => StatusCode::METHOD_NOT_ALLOWED,

View File

@@ -128,6 +128,11 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "mediaid_file",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "mediaid_pending",
ttl: 60 * 60 * 24 * 7,
..descriptor::RANDOM_SMALL_CACHE
},
Descriptor {
name: "mediaid_user",
..descriptor::RANDOM_SMALL

View File

@@ -1,17 +1,22 @@
use std::{sync::Arc, time::Duration};
use futures::{StreamExt, pin_mut};
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
use ruma::{Mxc, OwnedMxcUri, OwnedUserId, UserId, http_headers::ContentDisposition};
use tuwunel_core::{
Err, Result, debug, debug_info, err,
utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes},
utils::{
ReadyExt, str_from_bytes,
stream::{TryExpect, TryIgnore},
string_from_bytes,
},
};
use tuwunel_database::{Database, Interfix, Map, serialize_key};
use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Map, serialize_key};
use super::{preview::UrlPreviewData, thumbnail::Dim};
pub(crate) struct Data {
mediaid_file: Arc<Map>,
mediaid_pending: Arc<Map>,
mediaid_user: Arc<Map>,
url_previews: Arc<Map>,
}
@@ -27,6 +32,7 @@ impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
mediaid_file: db["mediaid_file"].clone(),
mediaid_pending: db["mediaid_pending"].clone(),
mediaid_user: db["mediaid_user"].clone(),
url_previews: db["url_previews"].clone(),
}
@@ -52,6 +58,52 @@ impl Data {
Ok(key.to_vec())
}
/// Insert a pending MXC URI into the database
pub(super) fn insert_pending_mxc(
&self,
mxc: &Mxc<'_>,
user: &UserId,
unused_expires_at: u64,
) {
let value = (unused_expires_at, user);
debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending");
self.mediaid_pending.put(mxc, value);
}
/// Remove a pending MXC URI from the database
pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) { self.mediaid_pending.del(mxc); }
/// Count the number of pending MXC URIs for a specific user
pub(super) async fn count_pending_mxc_for_user(&self, user_id: &UserId) -> (usize, u64) {
type KeyVal<'a> = (Ignore, (u64, &'a UserId));
self.mediaid_pending
.stream()
.expect_ok()
.ready_filter(|(_, (_, pending_user_id)): &KeyVal<'_>| user_id == *pending_user_id)
.ready_fold(
(0_usize, u64::MAX),
|(count, earliest_expiration), (_, (expires_at, _))| {
(count.saturating_add(1), earliest_expiration.min(expires_at))
},
)
.await
}
/// Search for a pending MXC URI in the database
pub(super) async fn search_pending_mxc(&self, mxc: &Mxc<'_>) -> Result<(OwnedUserId, u64)> {
type Value<'a> = (u64, OwnedUserId);
self.mediaid_pending
.qry(mxc)
.await
.deserialized()
.map(|(expires_at, user_id): Value<'_>| (user_id, expires_at))
.inspect(|(user_id, expires_at)| debug!(?mxc, ?user_id, ?expires_at, "Found pending"))
.map_err(|e| err!(Request(NotFound("Pending not found or error: {e}"))))
}
pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) {
debug!("MXC URI: {mxc}");

View File

@@ -5,18 +5,29 @@ mod preview;
mod remote;
mod tests;
mod thumbnail;
use std::{path::PathBuf, sync::Arc, time::SystemTime};
use std::{
collections::HashMap,
path::PathBuf,
sync::{Arc, Mutex},
time::{Duration, Instant, SystemTime},
};
use async_trait::async_trait;
use base64::{Engine as _, engine::general_purpose};
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
use http::StatusCode;
use ruma::{
Mxc, OwnedMxcUri, OwnedUserId, UserId,
api::client::error::{ErrorKind, RetryAfter},
http_headers::ContentDisposition,
};
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt, BufReader},
sync::Notify,
};
use tuwunel_core::{
Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace,
utils::{self, MutexMap},
Err, Error, Result, debug, debug_error, debug_info, debug_warn, err, error, trace,
utils::{self, MutexMap, time::now_millis},
warn,
};
@@ -30,10 +41,19 @@ pub struct FileMeta {
pub content_disposition: Option<ContentDisposition>,
}
/// For MSC2246
struct MXCState {
/// Save the notifier for each pending media upload
notifiers: Mutex<HashMap<OwnedMxcUri, Arc<Notify>>>,
/// Save the ratelimiter for each user
ratelimiter: Mutex<HashMap<OwnedUserId, (Instant, f64)>>,
}
pub struct Service {
url_preview_mutex: MutexMap<String, ()>,
pub(super) db: Data,
services: Arc<crate::services::OnceServices>,
url_preview_mutex: MutexMap<String, ()>,
mxc_state: MXCState,
}
/// generated MXC ID (`media-id`) length
@@ -49,9 +69,13 @@ pub const CORP_CROSS_ORIGIN: &str = "cross-origin";
impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
url_preview_mutex: MutexMap::new(),
db: Data::new(args.db),
services: args.services.clone(),
url_preview_mutex: MutexMap::new(),
mxc_state: MXCState {
notifiers: Mutex::new(HashMap::new()),
ratelimiter: Mutex::new(HashMap::new()),
},
}))
}
@@ -65,6 +89,113 @@ impl crate::Service for Service {
}
impl Service {
/// Create a pending media upload ID.
#[tracing::instrument(level = "debug", skip(self))]
pub async fn create_pending(
&self,
mxc: &Mxc<'_>,
user: &UserId,
unused_expires_at: u64,
) -> Result {
let config = &self.services.server.config;
// Rate limiting (rc_media_create)
let rate = f64::from(config.media_rc_create_per_second);
let burst = f64::from(config.media_rc_create_burst_count);
// Check rate limiting
if rate > 0.0 && burst > 0.0 {
let now = Instant::now();
let mut ratelimiter = self.mxc_state.ratelimiter.lock()?;
let (last_time, tokens) = ratelimiter
.entry(user.to_owned())
.or_insert_with(|| (now, burst));
let elapsed = now.duration_since(*last_time).as_secs_f64();
let new_tokens = elapsed.mul_add(rate, *tokens).min(burst);
if new_tokens >= 1.0 {
*last_time = now;
*tokens = new_tokens - 1.0;
} else {
return Err(Error::Request(
ErrorKind::LimitExceeded { retry_after: None },
"Too many pending media creation requests.".into(),
StatusCode::TOO_MANY_REQUESTS,
));
}
}
let max_uploads = config.max_pending_media_uploads;
let (current_uploads, earliest_expiration) =
self.db.count_pending_mxc_for_user(user).await;
// Check if the user has reached the maximum number of pending media uploads
if current_uploads >= max_uploads {
let retry_after = earliest_expiration.saturating_sub(now_millis());
return Err(Error::Request(
ErrorKind::LimitExceeded {
retry_after: Some(RetryAfter::Delay(Duration::from_millis(retry_after))),
},
"Maximum number of pending media uploads reached.".into(),
StatusCode::TOO_MANY_REQUESTS,
));
}
self.db
.insert_pending_mxc(mxc, user, unused_expires_at);
Ok(())
}
/// Uploads content to a pending media ID.
#[tracing::instrument(level = "debug", skip(self))]
pub async fn upload_pending(
&self,
mxc: &Mxc<'_>,
user: &UserId,
content_disposition: Option<&ContentDisposition>,
content_type: Option<&str>,
file: &[u8],
) -> Result {
let Ok((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else {
if self.get_metadata(mxc).await.is_some() {
return Err!(Request(CannotOverwriteMedia("Media ID already has content")));
}
return Err!(Request(NotFound("Media not found")));
};
if owner_id != user {
return Err!(Request(Forbidden("You did not create this media ID")));
}
let current_time = now_millis();
if expires_at < current_time {
return Err!(Request(NotFound("Pending media ID expired")));
}
self.create(mxc, Some(user), content_disposition, content_type, file)
.await?;
self.db.remove_pending_mxc(mxc);
let mxc_uri: OwnedMxcUri = mxc.to_string().into();
if let Some(notifier) = self
.mxc_state
.notifiers
.lock()?
.get(&mxc_uri)
.cloned()
{
notifier.notify_waiters();
self.mxc_state.notifiers.lock()?.remove(&mxc_uri);
}
Ok(())
}
/// Uploads a file.
pub async fn create(
&self,
@@ -168,6 +299,71 @@ impl Service {
}
}
/// Download a file and wait up to a timeout_ms if it is pending.
pub async fn get_with_timeout(
&self,
mxc: &Mxc<'_>,
timeout_duration: Duration,
) -> Result<Option<FileMeta>> {
if let Some(meta) = self.get(mxc).await? {
return Ok(Some(meta));
}
let Ok(_pending) = self.db.search_pending_mxc(mxc).await else {
return Ok(None);
};
let notifier = self
.mxc_state
.notifiers
.lock()?
.entry(mxc.to_string().into())
.or_insert_with(|| Arc::new(Notify::new()))
.clone();
if tokio::time::timeout(timeout_duration, notifier.notified())
.await
.is_err()
{
return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
}
self.get(mxc).await
}
/// Download a thumbnail and wait up to a timeout_ms if it is pending.
pub async fn get_thumbnail_with_timeout(
&self,
mxc: &Mxc<'_>,
dim: &Dim,
timeout_duration: Duration,
) -> Result<Option<FileMeta>> {
if let Some(meta) = self.get_thumbnail(mxc, dim).await? {
return Ok(Some(meta));
}
let Ok(_pending) = self.db.search_pending_mxc(mxc).await else {
return Ok(None);
};
let notifier = self
.mxc_state
.notifiers
.lock()?
.entry(mxc.to_string().into())
.or_insert_with(|| Arc::new(Notify::new()))
.clone();
if tokio::time::timeout(timeout_duration, notifier.notified())
.await
.is_err()
{
return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
}
self.get_thumbnail(mxc, dim).await
}
/// Gets all the MXC URIs in our media database
pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> {
let all_keys = self.db.get_all_media_keys().await;

View File

@@ -9,13 +9,13 @@
{"Action":"pass","Test":"TestArchivedRoomsHistory/timeline_is_empty"}
{"Action":"skip","Test":"TestArchivedRoomsHistory/timeline_is_empty/incremental_sync"}
{"Action":"pass","Test":"TestArchivedRoomsHistory/timeline_is_empty/initial_sync"}
{"Action":"fail","Test":"TestAsyncUpload"}
{"Action":"fail","Test":"TestAsyncUpload/Cannot_upload_to_a_media_ID_that_has_already_been_uploaded_to"}
{"Action":"fail","Test":"TestAsyncUpload/Create_media"}
{"Action":"fail","Test":"TestAsyncUpload/Download_media"}
{"Action":"fail","Test":"TestAsyncUpload/Download_media_over__matrix/client/v1/media/download"}
{"Action":"fail","Test":"TestAsyncUpload/Not_yet_uploaded"}
{"Action":"fail","Test":"TestAsyncUpload/Upload_media"}
{"Action":"pass","Test":"TestAsyncUpload"}
{"Action":"pass","Test":"TestAsyncUpload/Cannot_upload_to_a_media_ID_that_has_already_been_uploaded_to"}
{"Action":"pass","Test":"TestAsyncUpload/Create_media"}
{"Action":"pass","Test":"TestAsyncUpload/Download_media"}
{"Action":"pass","Test":"TestAsyncUpload/Download_media_over__matrix/client/v1/media/download"}
{"Action":"pass","Test":"TestAsyncUpload/Not_yet_uploaded"}
{"Action":"pass","Test":"TestAsyncUpload/Upload_media"}
{"Action":"pass","Test":"TestAvatarUrlUpdate"}
{"Action":"pass","Test":"TestBannedUserCannotSendJoin"}
{"Action":"skip","Test":"TestCanRegisterAdmin"}

View File

@@ -320,6 +320,25 @@
#
#max_request_size = 24 MiB
# Maximum number of concurrently pending (asynchronous) media uploads a
# user can have.
#
#max_pending_media_uploads = 5
# The time in seconds before an unused pending MXC URI expires and is
# removed.
#
#media_create_unused_expiration_time = 86400 (24 hours)
# The maximum number of media create requests per second allowed from a
# single user.
#
#media_rc_create_per_second = 10
# The maximum burst count for media create requests from a single user.
#
#media_rc_create_burst_count = 50
# This item is undocumented. Please contribute documentation for it.
#
#max_fetch_prev_events = 192