add MSC2246 support
This commit is contained in:
@@ -5,7 +5,11 @@ mod preview;
|
||||
mod remote;
|
||||
mod tests;
|
||||
mod thumbnail;
|
||||
use std::{path::PathBuf, sync::Arc, time::SystemTime};
|
||||
use std::{
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant, SystemTime},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
@@ -13,6 +17,7 @@ use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
|
||||
use tokio::{
|
||||
fs,
|
||||
io::{AsyncReadExt, AsyncWriteExt, BufReader},
|
||||
sync::Notify,
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace,
|
||||
@@ -29,9 +34,18 @@ pub struct FileMeta {
|
||||
pub content_type: Option<String>,
|
||||
pub content_disposition: Option<ContentDisposition>,
|
||||
}
|
||||
/// For MSC2246
|
||||
pub(super) struct MXCState {
|
||||
/// Save the notifier for each pending media upload
|
||||
pub(super) notifiers: std::sync::Mutex<std::collections::HashMap<String, Arc<Notify>>>,
|
||||
/// Save the ratelimiter for each user
|
||||
pub(super) ratelimiter:
|
||||
std::sync::Mutex<std::collections::HashMap<ruma::OwnedUserId, (Instant, f64)>>,
|
||||
}
|
||||
|
||||
pub struct Service {
|
||||
url_preview_mutex: MutexMap<String, ()>,
|
||||
mxc_state: MXCState,
|
||||
pub(super) db: Data,
|
||||
services: Arc<crate::services::OnceServices>,
|
||||
}
|
||||
@@ -50,6 +64,10 @@ impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(Self {
|
||||
url_preview_mutex: MutexMap::new(),
|
||||
mxc_state: MXCState {
|
||||
notifiers: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||
ratelimiter: std::sync::Mutex::new(std::collections::HashMap::new()),
|
||||
},
|
||||
db: Data::new(args.db),
|
||||
services: args.services.clone(),
|
||||
}))
|
||||
@@ -65,6 +83,119 @@ impl crate::Service for Service {
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Create a pending media upload ID.
|
||||
pub async fn create_pending(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: &UserId,
|
||||
unused_expires_at: u64,
|
||||
) -> Result<()> {
|
||||
let config = &self.services.server.config;
|
||||
|
||||
// Rate limiting (rc_media_create)
|
||||
let rate = config.rc_media_create_per_second as f64;
|
||||
let burst = config.rc_media_create_burst_count as f64;
|
||||
|
||||
// Check rate limiting
|
||||
if rate > 0.0 && burst > 0.0 {
|
||||
let now = Instant::now();
|
||||
let mut ratelimiter = self.mxc_state.ratelimiter.lock().unwrap();
|
||||
|
||||
let (last_time, tokens) = ratelimiter
|
||||
.entry(user.to_owned())
|
||||
.or_insert_with(|| (now, burst));
|
||||
|
||||
let elapsed = now.duration_since(*last_time).as_secs_f64();
|
||||
let new_tokens = (*tokens + elapsed * rate).min(burst);
|
||||
|
||||
if new_tokens >= 1.0 {
|
||||
*last_time = now;
|
||||
*tokens = new_tokens - 1.0;
|
||||
} else {
|
||||
return Err(tuwunel_core::Error::Request(
|
||||
ruma::api::client::error::ErrorKind::LimitExceeded { retry_after: None },
|
||||
"Too many pending media creation requests.".into(),
|
||||
http::StatusCode::TOO_MANY_REQUESTS,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let max_uploads = config.max_pending_media_uploads;
|
||||
let (current_uploads, earliest_expiration) =
|
||||
self.db.count_pending_mxc_for_user(user).await;
|
||||
|
||||
// Check if the user has reached the maximum number of pending media uploads
|
||||
if current_uploads >= max_uploads {
|
||||
let retry_after = earliest_expiration.saturating_sub(
|
||||
SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64,
|
||||
);
|
||||
return Err(tuwunel_core::Error::Request(
|
||||
ruma::api::client::error::ErrorKind::LimitExceeded {
|
||||
retry_after: Some(ruma::api::client::error::RetryAfter::Delay(
|
||||
Duration::from_millis(retry_after),
|
||||
)),
|
||||
},
|
||||
"Maximum number of pending media uploads reached.".into(),
|
||||
http::StatusCode::TOO_MANY_REQUESTS,
|
||||
));
|
||||
}
|
||||
|
||||
self.db
|
||||
.insert_pending_mxc(mxc, user, unused_expires_at);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Uploads content to a pending media ID.
|
||||
pub async fn upload_pending(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: &UserId,
|
||||
content_disposition: Option<&ContentDisposition>,
|
||||
content_type: Option<&str>,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let pending = self.db.search_pending_mxc(mxc).await;
|
||||
let Some((owner_id, expires_at)) = pending else {
|
||||
return Err!(Request(NotFound("Media not found")));
|
||||
};
|
||||
|
||||
if owner_id != user {
|
||||
return Err!(Request(Forbidden("You did not create this media ID")));
|
||||
}
|
||||
|
||||
let current_time = SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("Time went backwards")
|
||||
.as_millis() as u64;
|
||||
|
||||
if expires_at < current_time {
|
||||
return Err!(Request(NotFound("Pending media ID expired")));
|
||||
}
|
||||
|
||||
self.create(mxc, Some(user), content_disposition, content_type, file)
|
||||
.await?;
|
||||
|
||||
self.db.remove_pending_mxc(mxc);
|
||||
|
||||
if let Some(notifier) = self
|
||||
.mxc_state
|
||||
.notifiers
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(&mxc.to_string())
|
||||
.cloned()
|
||||
{
|
||||
notifier.notify_waiters();
|
||||
let mut map = self.mxc_state.notifiers.lock().unwrap();
|
||||
map.remove(&mxc.to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self,
|
||||
@@ -168,6 +299,71 @@ impl Service {
|
||||
}
|
||||
}
|
||||
|
||||
/// Download a file and wait up to a timeout_ms if it is pending.
|
||||
pub async fn get_with_timeout(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
if let Some(meta) = self.get(mxc).await? {
|
||||
return Ok(Some(meta));
|
||||
}
|
||||
|
||||
let pending = self.db.search_pending_mxc(mxc).await;
|
||||
if pending.is_none() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let notifier = {
|
||||
let mut map = self.mxc_state.notifiers.lock().unwrap();
|
||||
map.entry(mxc.to_string())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone()
|
||||
};
|
||||
|
||||
if tokio::time::timeout(timeout_duration, notifier.notified())
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
|
||||
}
|
||||
|
||||
self.get(mxc).await
|
||||
}
|
||||
|
||||
/// Download a thumbnail and wait up to a timeout_ms if it is pending.
|
||||
pub async fn get_thumbnail_with_timeout(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
dim: &Dim,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
if let Some(meta) = self.get_thumbnail(mxc, dim).await? {
|
||||
return Ok(Some(meta));
|
||||
}
|
||||
|
||||
let pending = self.db.search_pending_mxc(mxc).await;
|
||||
if pending.is_none() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let notifier = {
|
||||
let mut map = self.mxc_state.notifiers.lock().unwrap();
|
||||
map.entry(mxc.to_string())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone()
|
||||
};
|
||||
|
||||
if tokio::time::timeout(timeout_duration, notifier.notified())
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
|
||||
}
|
||||
|
||||
self.get_thumbnail(mxc, dim).await
|
||||
}
|
||||
|
||||
/// Gets all the MXC URIs in our media database
|
||||
pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> {
|
||||
let all_keys = self.db.get_all_media_keys().await;
|
||||
|
||||
Reference in New Issue
Block a user