add MSC2246 support

This commit is contained in:
Donjuanplatinum
2026-03-01 05:15:09 +08:00
parent d2836e9f50
commit 2b81e189cb
8 changed files with 461 additions and 8 deletions

View File

@@ -4,13 +4,13 @@ use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use reqwest::Url;
use ruma::{
Mxc, UserId,
MilliSecondsSinceUnixEpoch, Mxc, UserId,
api::client::{
authenticated_media::{
get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
get_media_preview,
},
media::create_content,
media::{create_content, create_content_async, create_mxc_uri},
},
};
use tuwunel_core::{
@@ -80,6 +80,80 @@ pub(crate) async fn create_content_route(
})
}
/// # `POST /_matrix/media/v1/create`
///
/// Create a new MXC URI without content.
#[tracing::instrument(
name = "media_create_mxc",
level = "debug",
skip_all,
fields(%client),
)]
pub(crate) async fn create_mxc_uri_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_mxc_uri::v1::Request>,
) -> Result<create_mxc_uri::v1::Response> {
let user = body.sender_user();
let mxc = Mxc {
server_name: services.globals.server_name(),
media_id: &utils::random_string(MXC_LENGTH),
};
let unused_expires_at = (std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("Time went backwards")
.as_millis() as u64)
.saturating_add(
services
.server
.config
.media_create_unused_expiration_time
* 1000,
);
services
.media
.create_pending(&mxc, user, unused_expires_at)
.await?;
Ok(create_mxc_uri::v1::Response {
content_uri: mxc.to_string().into(),
unused_expires_at: ruma::UInt::new(unused_expires_at).map(MilliSecondsSinceUnixEpoch),
})
}
/// # `PUT /_matrix/media/v3/upload/{serverName}/{mediaId}`
///
/// Upload content to a MXC URI that was created earlier.
#[tracing::instrument(
name = "media_upload_async",
level = "debug",
skip_all,
fields(%client),
)]
pub(crate) async fn create_content_async_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_content_async::v3::Request>,
) -> Result<create_content_async::v3::Response> {
let user = body.sender_user();
let mxc = Mxc {
server_name: &body.server_name,
media_id: &body.media_id,
};
let filename = body.filename.as_deref();
let content_type = body.content_type.as_deref();
let content_disposition = make_content_disposition(None, content_type, filename);
services
.media
.upload_pending(&mxc, user, Some(&content_disposition), content_type, &body.file)
.await?;
Ok(create_content_async::v3::Response {})
}
/// # `GET /_matrix/client/v1/media/thumbnail/{serverName}/{mediaId}`
///
/// Load media thumbnail from our server or over federation.
@@ -296,7 +370,11 @@ async fn fetch_thumbnail_meta(
timeout_ms: Duration,
dim: &Dim,
) -> Result<FileMeta> {
if let Some(filemeta) = services.media.get_thumbnail(mxc, dim).await? {
if let Some(filemeta) = services
.media
.get_thumbnail_with_timeout(mxc, dim, timeout_ms)
.await?
{
return Ok(filemeta);
}
@@ -316,7 +394,11 @@ async fn fetch_file_meta(
user: &UserId,
timeout_ms: Duration,
) -> Result<FileMeta> {
if let Some(filemeta) = services.media.get(mxc).await? {
if let Some(filemeta) = services
.media
.get_with_timeout(mxc, timeout_ms)
.await?
{
return Ok(filemeta);
}

View File

@@ -145,7 +145,11 @@ pub(crate) async fn get_content_legacy_route(
media_id: &body.media_id,
};
match services.media.get(&mxc).await? {
match services
.media
.get_with_timeout(&mxc, body.timeout_ms)
.await?
{
| Some(FileMeta {
content,
content_type,
@@ -236,7 +240,11 @@ pub(crate) async fn get_content_as_filename_legacy_route(
media_id: &body.media_id,
};
match services.media.get(&mxc).await? {
match services
.media
.get_with_timeout(&mxc, body.timeout_ms)
.await?
{
| Some(FileMeta {
content,
content_type,
@@ -327,7 +335,11 @@ pub(crate) async fn get_content_thumbnail_legacy_route(
};
let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?;
match services.media.get_thumbnail(&mxc, &dim).await? {
match services
.media
.get_thumbnail_with_timeout(&mxc, &dim, body.timeout_ms)
.await?
{
| Some(FileMeta {
content,
content_type,

View File

@@ -158,6 +158,8 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::turn_server_route)
.ruma_route(&client::send_event_to_device_route)
.ruma_route(&client::create_content_route)
.ruma_route(&client::create_mxc_uri_route)
.ruma_route(&client::create_content_async_route)
.ruma_route(&client::get_content_thumbnail_route)
.ruma_route(&client::get_content_route)
.ruma_route(&client::get_content_as_filename_route)

View File

@@ -408,6 +408,33 @@ pub struct Config {
#[serde(default = "default_max_request_size")]
pub max_request_size: usize,
/// Maximum number of concurrently pending (asynchronous) media uploads a
/// user can have.
///
/// default: 5
#[serde(default = "default_max_pending_media_uploads")]
pub max_pending_media_uploads: usize,
/// The time in seconds before an unused pending MXC URI expires and is
/// removed.
///
/// default: 86400 (24 hours)
#[serde(default = "default_media_create_unused_expiration_time")]
pub media_create_unused_expiration_time: u64,
/// The maximum number of media create requests per second allowed from a
/// single user.
///
/// default: 10
#[serde(default = "default_rc_media_create_per_second")]
pub rc_media_create_per_second: u64,
/// The maximum burst count for media create requests from a single user.
///
/// default: 50
#[serde(default = "default_rc_media_create_burst_count")]
pub rc_media_create_burst_count: u64,
/// default: 192
#[serde(default = "default_max_fetch_prev_events")]
pub max_fetch_prev_events: u16,
@@ -3245,6 +3272,10 @@ fn default_dns_timeout() -> u64 { 10 }
fn default_ip_lookup_strategy() -> u8 { 5 }
fn default_max_request_size() -> usize { 24 * 1024 * 1024 }
fn default_max_pending_media_uploads() -> usize { 5 }
fn default_media_create_unused_expiration_time() -> u64 { 86400 }
fn default_rc_media_create_per_second() -> u64 { 10 }
fn default_rc_media_create_burst_count() -> u64 { 50 }
fn default_request_conn_timeout() -> u64 { 10 }

View File

@@ -128,6 +128,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "mediaid_file",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "mediaid_pending",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "mediaid_user",
..descriptor::RANDOM_SMALL

View File

@@ -12,6 +12,7 @@ use super::{preview::UrlPreviewData, thumbnail::Dim};
pub(crate) struct Data {
mediaid_file: Arc<Map>,
mediaid_pending: Arc<Map>,
mediaid_user: Arc<Map>,
url_previews: Arc<Map>,
}
@@ -27,6 +28,7 @@ impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
mediaid_file: db["mediaid_file"].clone(),
mediaid_pending: db["mediaid_pending"].clone(),
mediaid_user: db["mediaid_user"].clone(),
url_previews: db["url_previews"].clone(),
}
@@ -52,6 +54,114 @@ impl Data {
Ok(key.to_vec())
}
/// Insert a pending MXC URI into the database
pub(super) fn insert_pending_mxc(
&self,
mxc: &Mxc<'_>,
user: &UserId,
unused_expires_at: u64,
) {
let key = mxc.to_string();
tracing::info!(
"Inserting pending MXC: key={}, user={}, expires_at={}",
key,
user,
unused_expires_at
);
// 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and
// user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user]
let mut value = Vec::with_capacity(user.as_bytes().len() + 9);
value.extend_from_slice(&unused_expires_at.to_be_bytes());
value.push(0xFF);
value.extend_from_slice(user.as_bytes());
self.mediaid_pending
.insert(key.as_bytes(), &value);
}
/// Count the number of pending MXC URIs for a specific user
pub(super) async fn count_pending_mxc_for_user(&self, user: &UserId) -> (usize, u64) {
let mut count = 0;
let mut earliest_expiration = u64::MAX;
let user_bytes = user.as_bytes();
self.mediaid_pending
.raw_stream()
.ignore_err()
.ready_for_each(|(_key, value)| {
let mut parts = value.splitn(2, |&b| b == 0xFF);
if let Some(expires_at_bytes) = parts.next() {
if let Some(user_id_bytes) = parts.next() {
if user_id_bytes == user_bytes {
count += 1;
let expires_at = u64::from_be_bytes(
expires_at_bytes.try_into().unwrap_or([0u8; 8]),
);
if expires_at < earliest_expiration {
earliest_expiration = expires_at;
}
}
}
}
})
.await;
(count, earliest_expiration)
}
/// Search for a pending MXC URI in the database
pub(super) async fn search_pending_mxc(
&self,
mxc: &Mxc<'_>,
) -> Option<(ruma::OwnedUserId, u64)> {
let key = mxc.to_string();
tracing::info!("Searching for pending MXC: key={}", key);
let value = match self.mediaid_pending.get(key.as_bytes()).await {
| Ok(v) => v,
| Err(e) => {
tracing::info!("pending MXC not found or error for {}: {}", key, e);
return None;
},
};
let mut parts = value.splitn(2, |&b| b == 0xFF);
// value: [expires_at, 0xFF, user]
let expires_at_bytes = parts.next()?;
let expires_at = match expires_at_bytes.try_into() {
| Ok(v) => u64::from_be_bytes(v),
| Err(_) => {
tracing::error!("Failed to parse expires_at for {}", key);
return None;
},
};
let user_id_bytes = parts.next()?;
let user_str = match str_from_bytes(user_id_bytes) {
| Ok(v) => v,
| Err(_) => {
tracing::error!("Failed to parse user_str for {}", key);
return None;
},
};
let user_id = match user_str.try_into() {
| Ok(v) => v,
| Err(e) => {
tracing::error!("Failed to parse user_id for {}: {}", key, e);
return None;
},
};
tracing::info!(
"Found pending MXC for {}: user={}, expires_at={}",
key,
user_id,
expires_at
);
Some((user_id, expires_at))
}
pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) {
self.mediaid_pending
.remove(mxc.to_string().as_bytes());
}
pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) {
debug!("MXC URI: {mxc}");

View File

@@ -5,7 +5,11 @@ mod preview;
mod remote;
mod tests;
mod thumbnail;
use std::{path::PathBuf, sync::Arc, time::SystemTime};
use std::{
path::PathBuf,
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use async_trait::async_trait;
use base64::{Engine as _, engine::general_purpose};
@@ -13,6 +17,7 @@ use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
use tokio::{
fs,
io::{AsyncReadExt, AsyncWriteExt, BufReader},
sync::Notify,
};
use tuwunel_core::{
Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace,
@@ -29,9 +34,18 @@ pub struct FileMeta {
pub content_type: Option<String>,
pub content_disposition: Option<ContentDisposition>,
}
/// For MSC2246
pub(super) struct MXCState {
/// Save the notifier for each pending media upload
pub(super) notifiers: std::sync::Mutex<std::collections::HashMap<String, Arc<Notify>>>,
/// Save the ratelimiter for each user
pub(super) ratelimiter:
std::sync::Mutex<std::collections::HashMap<ruma::OwnedUserId, (Instant, f64)>>,
}
pub struct Service {
url_preview_mutex: MutexMap<String, ()>,
mxc_state: MXCState,
pub(super) db: Data,
services: Arc<crate::services::OnceServices>,
}
@@ -50,6 +64,10 @@ impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
url_preview_mutex: MutexMap::new(),
mxc_state: MXCState {
notifiers: std::sync::Mutex::new(std::collections::HashMap::new()),
ratelimiter: std::sync::Mutex::new(std::collections::HashMap::new()),
},
db: Data::new(args.db),
services: args.services.clone(),
}))
@@ -65,6 +83,119 @@ impl crate::Service for Service {
}
impl Service {
/// Create a pending media upload ID.
pub async fn create_pending(
&self,
mxc: &Mxc<'_>,
user: &UserId,
unused_expires_at: u64,
) -> Result<()> {
let config = &self.services.server.config;
// Rate limiting (rc_media_create)
let rate = config.rc_media_create_per_second as f64;
let burst = config.rc_media_create_burst_count as f64;
// Check rate limiting
if rate > 0.0 && burst > 0.0 {
let now = Instant::now();
let mut ratelimiter = self.mxc_state.ratelimiter.lock().unwrap();
let (last_time, tokens) = ratelimiter
.entry(user.to_owned())
.or_insert_with(|| (now, burst));
let elapsed = now.duration_since(*last_time).as_secs_f64();
let new_tokens = (*tokens + elapsed * rate).min(burst);
if new_tokens >= 1.0 {
*last_time = now;
*tokens = new_tokens - 1.0;
} else {
return Err(tuwunel_core::Error::Request(
ruma::api::client::error::ErrorKind::LimitExceeded { retry_after: None },
"Too many pending media creation requests.".into(),
http::StatusCode::TOO_MANY_REQUESTS,
));
}
}
let max_uploads = config.max_pending_media_uploads;
let (current_uploads, earliest_expiration) =
self.db.count_pending_mxc_for_user(user).await;
// Check if the user has reached the maximum number of pending media uploads
if current_uploads >= max_uploads {
let retry_after = earliest_expiration.saturating_sub(
SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
);
return Err(tuwunel_core::Error::Request(
ruma::api::client::error::ErrorKind::LimitExceeded {
retry_after: Some(ruma::api::client::error::RetryAfter::Delay(
Duration::from_millis(retry_after),
)),
},
"Maximum number of pending media uploads reached.".into(),
http::StatusCode::TOO_MANY_REQUESTS,
));
}
self.db
.insert_pending_mxc(mxc, user, unused_expires_at);
Ok(())
}
/// Uploads content to a pending media ID.
pub async fn upload_pending(
&self,
mxc: &Mxc<'_>,
user: &UserId,
content_disposition: Option<&ContentDisposition>,
content_type: Option<&str>,
file: &[u8],
) -> Result<()> {
let pending = self.db.search_pending_mxc(mxc).await;
let Some((owner_id, expires_at)) = pending else {
return Err!(Request(NotFound("Media not found")));
};
if owner_id != user {
return Err!(Request(Forbidden("You did not create this media ID")));
}
let current_time = SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("Time went backwards")
.as_millis() as u64;
if expires_at < current_time {
return Err!(Request(NotFound("Pending media ID expired")));
}
self.create(mxc, Some(user), content_disposition, content_type, file)
.await?;
self.db.remove_pending_mxc(mxc);
if let Some(notifier) = self
.mxc_state
.notifiers
.lock()
.unwrap()
.get(&mxc.to_string())
.cloned()
{
notifier.notify_waiters();
let mut map = self.mxc_state.notifiers.lock().unwrap();
map.remove(&mxc.to_string());
}
Ok(())
}
/// Uploads a file.
pub async fn create(
&self,
@@ -168,6 +299,71 @@ impl Service {
}
}
/// Download a file and wait up to a timeout_ms if it is pending.
pub async fn get_with_timeout(
&self,
mxc: &Mxc<'_>,
timeout_duration: Duration,
) -> Result<Option<FileMeta>> {
if let Some(meta) = self.get(mxc).await? {
return Ok(Some(meta));
}
let pending = self.db.search_pending_mxc(mxc).await;
if pending.is_none() {
return Ok(None);
}
let notifier = {
let mut map = self.mxc_state.notifiers.lock().unwrap();
map.entry(mxc.to_string())
.or_insert_with(|| Arc::new(Notify::new()))
.clone()
};
if tokio::time::timeout(timeout_duration, notifier.notified())
.await
.is_err()
{
return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
}
self.get(mxc).await
}
/// Download a thumbnail and wait up to a timeout_ms if it is pending.
pub async fn get_thumbnail_with_timeout(
&self,
mxc: &Mxc<'_>,
dim: &Dim,
timeout_duration: Duration,
) -> Result<Option<FileMeta>> {
if let Some(meta) = self.get_thumbnail(mxc, dim).await? {
return Ok(Some(meta));
}
let pending = self.db.search_pending_mxc(mxc).await;
if pending.is_none() {
return Ok(None);
}
let notifier = {
let mut map = self.mxc_state.notifiers.lock().unwrap();
map.entry(mxc.to_string())
.or_insert_with(|| Arc::new(Notify::new()))
.clone()
};
if tokio::time::timeout(timeout_duration, notifier.notified())
.await
.is_err()
{
return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
}
self.get_thumbnail(mxc, dim).await
}
/// Gets all the MXC URIs in our media database
pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> {
let all_keys = self.db.get_all_media_keys().await;

View File

@@ -320,6 +320,22 @@
#
#max_request_size = 24 MiB
# Maximum number of concurrently pending (asynchronous) media uploads a user can have.
#
#max_pending_media_uploads = 5
# The time in seconds before an unused pending MXC URI expires and is removed.
#
#media_create_unused_expiration_time = 86400 (24 hours)
# The maximum number of media create requests per second allowed from a single user.
#
#rc_media_create_per_second = 10
# The maximum burst count for media create requests from a single user.
#
#rc_media_create_burst_count = 50
# This item is undocumented. Please contribute documentation for it.
#
#max_fetch_prev_events = 192