This commit is contained in:
Jason Volk
2026-03-08 14:36:49 +00:00
parent c960a9dbc3
commit dfedef4d19
5 changed files with 71 additions and 88 deletions

View File

@@ -15,7 +15,10 @@ use ruma::{
}; };
use tuwunel_core::{ use tuwunel_core::{
Err, Result, err, Err, Result, err,
utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize}, utils::{
self, content_disposition::make_content_disposition, math::ruma_from_usize,
time::now_millis,
},
}; };
use tuwunel_service::{ use tuwunel_service::{
Services, Services,
@@ -100,20 +103,14 @@ pub(crate) async fn create_mxc_uri_route(
media_id: &utils::random_string(MXC_LENGTH), media_id: &utils::random_string(MXC_LENGTH),
}; };
let unused_expires_at = u64::try_from( // safe because even if it overflows, it will be greater than the current time
std::time::SystemTime::now() // and the unused media will be deleted anyway
.duration_since(std::time::UNIX_EPOCH) let unused_expires_at = now_millis().saturating_add(
.expect("Time went backwards")
.as_millis(),
)?
.saturating_add(
services services
.server .server
.config .config
.media_create_unused_expiration_time .media_create_unused_expiration_time
// safe because even if it overflows, it will be greater than the current time .saturating_mul(1000),
// and the unused media will be deleted anyway
.saturating_mul(1000),
); );
services services
.media .media

View File

@@ -57,24 +57,24 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode {
use ErrorKind::*; use ErrorKind::*;
match kind { match kind {
// 504
| NotYetUploaded => StatusCode::GATEWAY_TIMEOUT,
// 429 // 429
| LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS, | LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS,
// 409
| CannotOverwriteMedia => StatusCode::CONFLICT,
// 413 // 413
| TooLarge => StatusCode::PAYLOAD_TOO_LARGE, | TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
// 409
| CannotOverwriteMedia => StatusCode::CONFLICT,
// 405 // 405
| Unrecognized => StatusCode::METHOD_NOT_ALLOWED, | Unrecognized => StatusCode::METHOD_NOT_ALLOWED,
// 404 // 404
| NotFound | NotImplemented | FeatureDisabled => StatusCode::NOT_FOUND, | NotFound | NotImplemented | FeatureDisabled => StatusCode::NOT_FOUND,
// 504
| NotYetUploaded => StatusCode::GATEWAY_TIMEOUT,
// 403 // 403
| GuestAccessForbidden | GuestAccessForbidden
| ThreepidAuthFailed | ThreepidAuthFailed

View File

@@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
use futures::{StreamExt, pin_mut}; use futures::{StreamExt, pin_mut};
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
use tuwunel_core::{ use tuwunel_core::{
Err, Result, debug, debug_info, err, Err, Result, debug, debug_error, debug_info, err, error,
utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes}, utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes},
}; };
use tuwunel_database::{Database, Interfix, Map, serialize_key}; use tuwunel_database::{Database, Interfix, Map, serialize_key};
@@ -61,13 +61,9 @@ impl Data {
user: &UserId, user: &UserId,
unused_expires_at: u64, unused_expires_at: u64,
) { ) {
debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending MXC");
let key = mxc.to_string(); let key = mxc.to_string();
tracing::info!(
"Inserting pending MXC: key={}, user={}, expires_at={}",
key,
user,
unused_expires_at
);
// 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and // 8 bytes for unused_expires_at (u64), 1 byte for 0xFF, and
// user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user] // user.as_bytes().len() for user value: [unused_expires_at, 0xFF, user]
let mut value = Vec::with_capacity( let mut value = Vec::with_capacity(
@@ -119,12 +115,13 @@ impl Data {
&self, &self,
mxc: &Mxc<'_>, mxc: &Mxc<'_>,
) -> Option<(ruma::OwnedUserId, u64)> { ) -> Option<(ruma::OwnedUserId, u64)> {
debug!(?mxc, "Searching for pending MXC");
let key = mxc.to_string(); let key = mxc.to_string();
tracing::info!("Searching for pending MXC: key={}", key);
let value = match self.mediaid_pending.get(key.as_bytes()).await { let value = match self.mediaid_pending.get(key.as_bytes()).await {
| Ok(v) => v, | Ok(v) => v,
| Err(e) => { | Err(e) => {
tracing::info!("pending MXC not found or error for {}: {}", key, e); debug_error!(?key, "pending MXC not found or error: {e}");
return None; return None;
}, },
}; };
@@ -134,30 +131,26 @@ impl Data {
let expires_at = match expires_at_bytes.try_into() { let expires_at = match expires_at_bytes.try_into() {
| Ok(v) => u64::from_be_bytes(v), | Ok(v) => u64::from_be_bytes(v),
| Err(_) => { | Err(_) => {
tracing::error!("Failed to parse expires_at for {}", key); debug_error!(?key, "Failed to parse expires_at for key");
return None; return None;
}, },
}; };
let user_id_bytes = parts.next()?; let user_id_bytes = parts.next()?;
let Ok(user_str) = str_from_bytes(user_id_bytes) else { let Ok(user_str) = str_from_bytes(user_id_bytes) else {
tracing::error!("Failed to parse user_str for {}", key); error!(?key, "Failed to parse user_str for key");
return None; return None;
}; };
let user_id = match user_str.try_into() { let user_id = match user_str.try_into() {
| Ok(v) => v, | Ok(v) => v,
| Err(e) => { | Err(e) => {
tracing::error!("Failed to parse user_id for {}: {}", key, e); error!(?key, "Failed to parse user_id for key: {e}");
return None; return None;
}, },
}; };
tracing::info!( debug!(?key, ?user_id, ?expires_at, "Found pending MXC");
"Found pending MXC for {}: user={}, expires_at={}",
key,
user_id,
expires_at
);
Some((user_id, expires_at)) Some((user_id, expires_at))
} }

View File

@@ -6,22 +6,28 @@ mod remote;
mod tests; mod tests;
mod thumbnail; mod thumbnail;
use std::{ use std::{
collections::HashMap,
path::PathBuf, path::PathBuf,
sync::Arc, sync::{Arc, Mutex},
time::{Duration, Instant, SystemTime}, time::{Duration, Instant, SystemTime},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use base64::{Engine as _, engine::general_purpose}; use base64::{Engine as _, engine::general_purpose};
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition}; use http::StatusCode;
use ruma::{
Mxc, OwnedMxcUri, OwnedUserId, UserId,
api::client::error::{ErrorKind, RetryAfter},
http_headers::ContentDisposition,
};
use tokio::{ use tokio::{
fs, fs,
io::{AsyncReadExt, AsyncWriteExt, BufReader}, io::{AsyncReadExt, AsyncWriteExt, BufReader},
sync::Notify, sync::Notify,
}; };
use tuwunel_core::{ use tuwunel_core::{
Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace, Err, Error, Result, debug, debug_error, debug_info, debug_warn, err, error, trace,
utils::{self, MutexMap}, utils::{self, MutexMap, time::now_millis},
warn, warn,
}; };
@@ -34,20 +40,20 @@ pub struct FileMeta {
pub content_type: Option<String>, pub content_type: Option<String>,
pub content_disposition: Option<ContentDisposition>, pub content_disposition: Option<ContentDisposition>,
} }
/// For MSC2246 /// For MSC2246
pub(super) struct MXCState { struct MXCState {
/// Save the notifier for each pending media upload /// Save the notifier for each pending media upload
pub(super) notifiers: std::sync::Mutex<std::collections::HashMap<String, Arc<Notify>>>, notifiers: Mutex<HashMap<String, Arc<Notify>>>,
/// Save the ratelimiter for each user /// Save the ratelimiter for each user
pub(super) ratelimiter: ratelimiter: Mutex<HashMap<OwnedUserId, (Instant, f64)>>,
std::sync::Mutex<std::collections::HashMap<ruma::OwnedUserId, (Instant, f64)>>,
} }
pub struct Service { pub struct Service {
url_preview_mutex: MutexMap<String, ()>,
mxc_state: MXCState,
pub(super) db: Data, pub(super) db: Data,
services: Arc<crate::services::OnceServices>, services: Arc<crate::services::OnceServices>,
url_preview_mutex: MutexMap<String, ()>,
mxc_state: MXCState,
} }
/// generated MXC ID (`media-id`) length /// generated MXC ID (`media-id`) length
@@ -63,13 +69,13 @@ pub const CORP_CROSS_ORIGIN: &str = "cross-origin";
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
url_preview_mutex: MutexMap::new(),
mxc_state: MXCState {
notifiers: std::sync::Mutex::new(std::collections::HashMap::new()),
ratelimiter: std::sync::Mutex::new(std::collections::HashMap::new()),
},
db: Data::new(args.db), db: Data::new(args.db),
services: args.services.clone(), services: args.services.clone(),
url_preview_mutex: MutexMap::new(),
mxc_state: MXCState {
notifiers: Mutex::new(HashMap::new()),
ratelimiter: Mutex::new(HashMap::new()),
},
})) }))
} }
@@ -112,10 +118,10 @@ impl Service {
*last_time = now; *last_time = now;
*tokens = new_tokens - 1.0; *tokens = new_tokens - 1.0;
} else { } else {
return Err(tuwunel_core::Error::Request( return Err(Error::Request(
ruma::api::client::error::ErrorKind::LimitExceeded { retry_after: None }, ErrorKind::LimitExceeded { retry_after: None },
"Too many pending media creation requests.".into(), "Too many pending media creation requests.".into(),
http::StatusCode::TOO_MANY_REQUESTS, StatusCode::TOO_MANY_REQUESTS,
)); ));
} }
} }
@@ -126,25 +132,19 @@ impl Service {
// Check if the user has reached the maximum number of pending media uploads // Check if the user has reached the maximum number of pending media uploads
if current_uploads >= max_uploads { if current_uploads >= max_uploads {
let retry_after = earliest_expiration.saturating_sub(u64::try_from( let retry_after = earliest_expiration.saturating_sub(now_millis());
SystemTime::now() return Err(Error::Request(
.duration_since(std::time::UNIX_EPOCH) ErrorKind::LimitExceeded {
.unwrap() retry_after: Some(RetryAfter::Delay(Duration::from_millis(retry_after))),
.as_millis(),
)?);
return Err(tuwunel_core::Error::Request(
ruma::api::client::error::ErrorKind::LimitExceeded {
retry_after: Some(ruma::api::client::error::RetryAfter::Delay(
Duration::from_millis(retry_after),
)),
}, },
"Maximum number of pending media uploads reached.".into(), "Maximum number of pending media uploads reached.".into(),
http::StatusCode::TOO_MANY_REQUESTS, StatusCode::TOO_MANY_REQUESTS,
)); ));
} }
self.db self.db
.insert_pending_mxc(mxc, user, unused_expires_at); .insert_pending_mxc(mxc, user, unused_expires_at);
Ok(()) Ok(())
} }
@@ -156,9 +156,8 @@ impl Service {
content_disposition: Option<&ContentDisposition>, content_disposition: Option<&ContentDisposition>,
content_type: Option<&str>, content_type: Option<&str>,
file: &[u8], file: &[u8],
) -> Result<()> { ) -> Result {
let pending = self.db.search_pending_mxc(mxc).await; let Some((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else {
let Some((owner_id, expires_at)) = pending else {
if self.get_metadata(mxc).await.is_some() { if self.get_metadata(mxc).await.is_some() {
return Err!(Request(CannotOverwriteMedia("Media ID already has content"))); return Err!(Request(CannotOverwriteMedia("Media ID already has content")));
} }
@@ -170,13 +169,7 @@ impl Service {
return Err!(Request(Forbidden("You did not create this media ID"))); return Err!(Request(Forbidden("You did not create this media ID")));
} }
let current_time = u64::try_from( let current_time = now_millis();
SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("Time went backwards")
.as_millis(),
)?;
if expires_at < current_time { if expires_at < current_time {
return Err!(Request(NotFound("Pending media ID expired"))); return Err!(Request(NotFound("Pending media ID expired")));
} }
@@ -189,8 +182,7 @@ impl Service {
if let Some(notifier) = self if let Some(notifier) = self
.mxc_state .mxc_state
.notifiers .notifiers
.lock() .lock()?
.unwrap()
.get(&mxc.to_string()) .get(&mxc.to_string())
.cloned() .cloned()
{ {
@@ -315,10 +307,9 @@ impl Service {
return Ok(Some(meta)); return Ok(Some(meta));
} }
let pending = self.db.search_pending_mxc(mxc).await; let Some(_pending) = self.db.search_pending_mxc(mxc).await else {
if pending.is_none() {
return Ok(None); return Ok(None);
} };
let notifier = { let notifier = {
let mut map = self.mxc_state.notifiers.lock().unwrap(); let mut map = self.mxc_state.notifiers.lock().unwrap();
@@ -348,10 +339,9 @@ impl Service {
return Ok(Some(meta)); return Ok(Some(meta));
} }
let pending = self.db.search_pending_mxc(mxc).await; let Some(_pending) = self.db.search_pending_mxc(mxc).await else {
if pending.is_none() {
return Ok(None); return Ok(None);
} };
let notifier = { let notifier = {
let mut map = self.mxc_state.notifiers.lock().unwrap(); let mut map = self.mxc_state.notifiers.lock().unwrap();

View File

@@ -320,15 +320,18 @@
# #
#max_request_size = 24 MiB #max_request_size = 24 MiB
# Maximum number of concurrently pending (asynchronous) media uploads a user can have. # Maximum number of concurrently pending (asynchronous) media uploads a
# user can have.
# #
#max_pending_media_uploads = 5 #max_pending_media_uploads = 5
# The time in seconds before an unused pending MXC URI expires and is removed. # The time in seconds before an unused pending MXC URI expires and is
# removed.
# #
#media_create_unused_expiration_time = 86400 (24 hours) #media_create_unused_expiration_time = 86400 (24 hours)
# The maximum number of media create requests per second allowed from a single user. # The maximum number of media create requests per second allowed from a
# single user.
# #
#rc_media_create_per_second = 10 #rc_media_create_per_second = 10