chain_width to 50

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-04-22 04:42:26 +00:00
parent 9b658d86b2
commit 76509830e6
190 changed files with 3469 additions and 930 deletions

View File

@@ -70,8 +70,14 @@ pub async fn update(
.put(roomuserdataid, Json(data));
let key = (room_id, user_id, &event_type);
let prev = self.db.roomusertype_roomuserdataid.qry(&key).await;
self.db.roomusertype_roomuserdataid.put(key, roomuserdataid);
let prev = self
.db
.roomusertype_roomuserdataid
.qry(&key)
.await;
self.db
.roomusertype_roomuserdataid
.put(key, roomuserdataid);
// Remove old entry
if let Ok(prev) = prev {
@@ -119,7 +125,11 @@ pub async fn get_raw(
self.db
.roomusertype_roomuserdataid
.qry(&key)
.and_then(|roomuserdataid| self.db.roomuserdataid_accountdata.get(&roomuserdataid))
.and_then(|roomuserdataid| {
self.db
.roomuserdataid_accountdata
.get(&roomuserdataid)
})
.await
}

View File

@@ -106,8 +106,10 @@ impl Console {
| ReadlineEvent::Line(string) => self.clone().handle(string).await,
| ReadlineEvent::Interrupted => continue,
| ReadlineEvent::Eof => break,
| ReadlineEvent::Quit =>
self.server.shutdown().unwrap_or_else(error::default_log),
| ReadlineEvent::Quit => self
.server
.shutdown()
.unwrap_or_else(error::default_log),
},
| Err(error) => match error {
| ReadlineError::Closed => break,
@@ -135,7 +137,11 @@ impl Console {
let (abort, abort_reg) = AbortHandle::new_pair();
let future = Abortable::new(future, abort_reg);
_ = self.input_abort.lock().expect("locked").insert(abort);
_ = self
.input_abort
.lock()
.expect("locked")
.insert(abort);
defer! {{
_ = self.input_abort.lock().expect("locked").take();
}}
@@ -158,7 +164,11 @@ impl Console {
let (abort, abort_reg) = AbortHandle::new_pair();
let future = Abortable::new(future, abort_reg);
_ = self.command_abort.lock().expect("locked").insert(abort);
_ = self
.command_abort
.lock()
.expect("locked")
.insert(abort);
defer! {{
_ = self.command_abort.lock().expect("locked").take();
}}

View File

@@ -8,7 +8,12 @@ pub(super) const SIGNAL: &str = "SIGUSR2";
#[implement(super::Service)]
pub(super) async fn console_auto_start(&self) {
#[cfg(feature = "console")]
if self.services.server.config.admin_console_automatic {
if self
.services
.server
.config
.admin_console_automatic
{
// Allow more of the startup sequence to execute before spawning
tokio::task::yield_now().await;
self.console.start().await;
@@ -32,7 +37,12 @@ pub(super) async fn startup_execute(&self) -> Result {
let smoketest = self.services.server.config.test.contains("smoke");
// When true, errors are ignored and startup continues.
let errors = !smoketest && self.services.server.config.admin_execute_errors_ignore;
let errors = !smoketest
&& self
.services
.server
.config
.admin_execute_errors_ignore;
//TODO: remove this after run-states are broadcast
sleep(Duration::from_millis(500)).await;
@@ -65,10 +75,19 @@ pub(super) async fn startup_execute(&self) -> Result {
#[implement(super::Service)]
pub(super) async fn signal_execute(&self) -> Result {
// List of commands to execute
let commands = self.services.server.config.admin_signal_execute.clone();
let commands = self
.services
.server
.config
.admin_signal_execute
.clone();
// When true, errors are ignored and execution continues.
let ignore_errors = self.services.server.config.admin_execute_errors_ignore;
let ignore_errors = self
.services
.server
.config
.admin_execute_errors_ignore;
for (i, command) in commands.iter().enumerate() {
if let Err(e) = self.execute_command(i, command.clone()).await {

View File

@@ -30,7 +30,12 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result {
let state_lock = self.services.state.mutex.lock(&room_id).await;
if self.services.state_cache.is_joined(user_id, &room_id).await {
if self
.services
.state_cache
.is_joined(user_id, &room_id)
.await
{
return Err!(debug_warn!("User is already joined in the admin room"));
}
if self
@@ -106,7 +111,9 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result {
room_power_levels
.users
.insert(server_user.into(), 69420.into());
room_power_levels.users.insert(user_id.into(), 100.into());
room_power_levels
.users
.insert(user_id.into(), 100.into());
self.services
.timeline
@@ -119,9 +126,17 @@ pub async fn make_user_admin(&self, user_id: &UserId) -> Result {
.await?;
// Set room tag
let room_tag = self.services.server.config.admin_room_tag.as_str();
let room_tag = self
.services
.server
.config
.admin_room_tag
.as_str();
if !room_tag.is_empty() {
if let Err(e) = self.set_room_tag(&room_id, user_id, room_tag).await {
if let Err(e) = self
.set_room_tag(&room_id, user_id, room_tag)
.await
{
error!(?room_id, ?user_id, ?room_tag, "Failed to set tag for admin grant: {e}");
}
}

View File

@@ -260,7 +260,12 @@ impl Service {
return Ok(());
};
let Ok(pdu) = self.services.timeline.get_pdu(&in_reply_to.event_id).await else {
let Ok(pdu) = self
.services
.timeline
.get_pdu(&in_reply_to.event_id)
.await
else {
error!(
event_id = ?in_reply_to.event_id,
"Missing admin command in_reply_to event"
@@ -327,7 +332,10 @@ impl Service {
pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool {
// Server-side command-escape with public echo
let is_escape = body.starts_with('\\');
let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin");
let is_public_escape = is_escape
&& body
.trim_start_matches('\\')
.starts_with("!admin");
// Admin command with public echo (in admin room)
let server_user = &self.services.globals.server_user;
@@ -361,7 +369,12 @@ impl Service {
// This will evaluate to false if the emergency password is set up so that
// the administrator can execute commands as the server user
let emergency_password_set = self.services.server.config.emergency_password.is_some();
let emergency_password_set = self
.services
.server
.config
.emergency_password
.is_some();
let from_server = pdu.sender == *server_user && !emergency_password_set;
if from_server && self.is_admin_room(&pdu.room_id).await {
return false;
@@ -382,7 +395,11 @@ impl Service {
/// Sets the self-reference to crate::Services which will provide context to
/// the admin commands.
pub(super) fn set_services(&self, services: Option<&Arc<crate::Services>>) {
let receiver = &mut *self.services.services.write().expect("locked for writing");
let receiver = &mut *self
.services
.services
.write()
.expect("locked for writing");
let weak = services.map(Arc::downgrade);
*receiver = weak;
}

View File

@@ -94,7 +94,9 @@ impl Service {
.ok_or_else(|| err!("Appservice not found"))?;
// remove the appservice from the database
self.db.id_appserviceregistrations.del(appservice_id);
self.db
.id_appserviceregistrations
.del(appservice_id);
// deletes all active requests for the appservice if there are any so we stop
// sending to the URL

View File

@@ -41,9 +41,13 @@ impl crate::Service for Service {
return Ok(());
}
self.set_emergency_access().await.inspect_err(|e| {
error!("Could not set the configured emergency password for the server user: {e}");
})
self.set_emergency_access()
.await
.inspect_err(|e| {
error!(
"Could not set the configured emergency password for the server user: {e}"
);
})
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
@@ -69,7 +73,9 @@ impl Service {
.update(
None,
server_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent { global: ruleset },
})
@@ -86,7 +92,10 @@ impl Service {
Ok(())
} else {
// logs out any users still in the server service account and removes sessions
self.services.users.deactivate_account(server_user).await
self.services
.users
.deactivate_account(server_user)
.await
}
}
}

View File

@@ -74,10 +74,15 @@ where
return Err!(Request(Forbidden(debug_warn!("Federation with {dest} is not allowed."))));
}
let actual = self.services.resolver.get_actual_dest(dest).await?;
let actual = self
.services
.resolver
.get_actual_dest(dest)
.await?;
let request = into_http_request::<T>(&actual, request)?;
let request = self.prepare(dest, request)?;
self.perform::<T>(dest, &actual, request, client).await
self.perform::<T>(dest, &actual, request, client)
.await
}
#[implement(super::Service)]

View File

@@ -140,19 +140,31 @@ impl Service {
pub fn notification_push_path(&self) -> &String { &self.server.config.notification_push_path }
pub fn url_preview_domain_contains_allowlist(&self) -> &Vec<String> {
&self.server.config.url_preview_domain_contains_allowlist
&self
.server
.config
.url_preview_domain_contains_allowlist
}
pub fn url_preview_domain_explicit_allowlist(&self) -> &Vec<String> {
&self.server.config.url_preview_domain_explicit_allowlist
&self
.server
.config
.url_preview_domain_explicit_allowlist
}
pub fn url_preview_domain_explicit_denylist(&self) -> &Vec<String> {
&self.server.config.url_preview_domain_explicit_denylist
&self
.server
.config
.url_preview_domain_explicit_denylist
}
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> {
&self.server.config.url_preview_url_contains_allowlist
&self
.server
.config
.url_preview_url_contains_allowlist
}
pub fn url_preview_max_spider_size(&self) -> usize {

View File

@@ -56,7 +56,9 @@ pub fn create_backup(
let count = self.services.globals.next_count()?;
let key = (user_id, &version);
self.db.backupid_algorithm.put(key, Json(backup_metadata));
self.db
.backupid_algorithm
.put(key, Json(backup_metadata));
self.db.backupid_etag.put(key, count);
@@ -88,7 +90,13 @@ pub async fn update_backup<'a>(
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<&'a str> {
let key = (user_id, version);
if self.db.backupid_algorithm.qry(&key).await.is_err() {
if self
.db
.backupid_algorithm
.qry(&key)
.await
.is_err()
{
return Err!(Request(NotFound("Tried to update nonexistent backup.")));
}
@@ -140,7 +148,11 @@ pub async fn get_latest_backup(
#[implement(Service)]
pub async fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Raw<BackupAlgorithm>> {
let key = (user_id, version);
self.db.backupid_algorithm.qry(&key).await.deserialized()
self.db
.backupid_algorithm
.qry(&key)
.await
.deserialized()
}
#[implement(Service)]
@@ -153,7 +165,13 @@ pub async fn add_key(
key_data: &Raw<KeyBackupData>,
) -> Result<()> {
let key = (user_id, version);
if self.db.backupid_algorithm.qry(&key).await.is_err() {
if self
.db
.backupid_algorithm
.qry(&key)
.await
.is_err()
{
return Err!(Request(NotFound("Tried to update nonexistent backup.")));
}
@@ -251,7 +269,11 @@ pub async fn get_session(
) -> Result<Raw<KeyBackupData>> {
let key = (user_id, version, room_id, session_id);
self.db.backupkeyid_backup.qry(&key).await.deserialized()
self.db
.backupkeyid_backup
.qry(&key)
.await
.deserialized()
}
#[implement(Service)]

View File

@@ -159,7 +159,11 @@ impl Service {
/// Downloads a file.
pub async fn get(&self, mxc: &Mxc<'_>) -> Result<Option<FileMeta>> {
match self.db.search_file_metadata(mxc, &Dim::default()).await {
match self
.db
.search_file_metadata(mxc, &Dim::default())
.await
{
| Ok(Metadata { content_disposition, content_type, key }) => {
let mut content = Vec::with_capacity(8192);
let path = self.get_media_file(&key);

View File

@@ -82,7 +82,10 @@ async fn request_url_preview(&self, url: &Url) -> Result<UrlPreviewData> {
}
}
let Some(content_type) = response.headers().get(reqwest::header::CONTENT_TYPE) else {
let Some(content_type) = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
else {
return Err!(Request(Unknown("Unknown or invalid Content-Type header")));
};
@@ -108,14 +111,21 @@ pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
use ruma::Mxc;
use tuwunel_core::utils::random_string;
let image = self.services.client.url_preview.get(url).send().await?;
let image = self
.services
.client
.url_preview
.get(url)
.send()
.await?;
let image = image.bytes().await?;
let mxc = Mxc {
server_name: self.services.globals.server_name(),
media_id: &random_string(super::MXC_LENGTH),
};
self.create(&mxc, None, None, None, &image).await?;
self.create(&mxc, None, None, None, &image)
.await?;
let cursor = std::io::Cursor::new(&image);
let (width, height) = match ImageReader::new(cursor).with_guessed_format() {
@@ -152,13 +162,20 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
let mut bytes: Vec<u8> = Vec::new();
while let Some(chunk) = response.chunk().await? {
bytes.extend_from_slice(&chunk);
if bytes.len() > self.services.globals.url_preview_max_spider_size() {
if bytes.len()
> self
.services
.globals
.url_preview_max_spider_size()
{
debug!(
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not \
processing the rest of the response body and assuming our necessary data is in \
this range.",
url,
self.services.globals.url_preview_max_spider_size()
self.services
.globals
.url_preview_max_spider_size()
);
break;
}
@@ -177,7 +194,10 @@ async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
/* use OpenGraph title/description, but fall back to HTML if not available */
data.title = props.get("title").cloned().or(html.title);
data.description = props.get("description").cloned().or(html.description);
data.description = props
.get("description")
.cloned()
.or(html.description);
Ok(data)
}
@@ -214,8 +234,14 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
.services
.globals
.url_preview_domain_explicit_allowlist();
let denylist_domain_explicit = self.services.globals.url_preview_domain_explicit_denylist();
let allowlist_url_contains = self.services.globals.url_preview_url_contains_allowlist();
let denylist_domain_explicit = self
.services
.globals
.url_preview_domain_explicit_denylist();
let allowlist_url_contains = self
.services
.globals
.url_preview_url_contains_allowlist();
if allowlist_domain_contains.contains(&"*".to_owned())
|| allowlist_domain_explicit.contains(&"*".to_owned())
@@ -262,7 +288,11 @@ pub fn url_preview_allowed(&self, url: &Url) -> bool {
}
// check root domain if available and if user has root domain checks
if self.services.globals.url_preview_check_root_domain() {
if self
.services
.globals
.url_preview_check_root_domain()
{
debug!("Checking root domain");
match host.split_once('.') {
| None => return false,

View File

@@ -87,11 +87,14 @@ async fn fetch_thumbnail_authenticated(
timeout_ms,
};
let Response { content, .. } = self.federation_request(mxc, user, server, request).await?;
let Response { content, .. } = self
.federation_request(mxc, user, server, request)
.await?;
match content {
| FileOrLocation::File(content) =>
self.handle_thumbnail_file(mxc, user, dim, content).await,
self.handle_thumbnail_file(mxc, user, dim, content)
.await,
| FileOrLocation::Location(location) => self.handle_location(mxc, user, &location).await,
}
}
@@ -111,7 +114,9 @@ async fn fetch_content_authenticated(
timeout_ms,
};
let Response { content, .. } = self.federation_request(mxc, user, server, request).await?;
let Response { content, .. } = self
.federation_request(mxc, user, server, request)
.await?;
match content {
| FileOrLocation::File(content) => self.handle_content_file(mxc, user, content).await,
@@ -145,11 +150,14 @@ async fn fetch_thumbnail_unauthenticated(
let Response {
file, content_type, content_disposition, ..
} = self.federation_request(mxc, user, server, request).await?;
} = self
.federation_request(mxc, user, server, request)
.await?;
let content = Content { file, content_type, content_disposition };
self.handle_thumbnail_file(mxc, user, dim, content).await
self.handle_thumbnail_file(mxc, user, dim, content)
.await
}
#[allow(deprecated)]
@@ -173,7 +181,9 @@ async fn fetch_content_unauthenticated(
let Response {
file, content_type, content_disposition, ..
} = self.federation_request(mxc, user, server, request).await?;
} = self
.federation_request(mxc, user, server, request)
.await?;
let content = Content { file, content_type, content_disposition };
@@ -245,11 +255,13 @@ async fn handle_location(
user: Option<&UserId>,
location: &str,
) -> Result<FileMeta> {
self.location_request(location).await.map_err(|error| {
err!(Request(NotFound(
debug_warn!(%mxc, ?user, ?location, ?error, "Fetching media from location failed")
)))
})
self.location_request(location)
.await
.map_err(|error| {
err!(Request(NotFound(
debug_warn!(%mxc, ?user, ?location, ?error, "Fetching media from location failed")
)))
})
}
#[implement(super::Service)]

View File

@@ -67,8 +67,14 @@ impl super::Service {
match self.db.search_file_metadata(mxc, &dim).await {
| Ok(metadata) => self.get_thumbnail_saved(metadata).await,
| _ => match self.db.search_file_metadata(mxc, &Dim::default()).await {
| Ok(metadata) => self.get_thumbnail_generate(mxc, &dim, metadata).await,
| _ => match self
.db
.search_file_metadata(mxc, &Dim::default())
.await
{
| Ok(metadata) =>
self.get_thumbnail_generate(mxc, &dim, metadata)
.await,
| _ => Ok(None),
},
}

View File

@@ -55,7 +55,10 @@ pub(crate) async fn migrations(services: &Services) -> Result<()> {
async fn fresh(services: &Services) -> Result<()> {
let db = &services.db;
services.globals.db.bump_database_version(DATABASE_VERSION);
services
.globals
.db
.bump_database_version(DATABASE_VERSION);
db["global"].insert(b"feat_sha256_media", []);
db["global"].insert(b"fix_bad_double_separator_in_state_cache", []);
@@ -64,7 +67,9 @@ async fn fresh(services: &Services) -> Result<()> {
db["global"].insert(b"fix_readreceiptid_readreceipt_duplicates", []);
// Create the admin room and server user on first run
crate::admin::create_admin_room(services).boxed().await?;
crate::admin::create_admin_room(services)
.boxed()
.await?;
warn!("Created new RocksDB database with version {DATABASE_VERSION}");
@@ -93,7 +98,11 @@ async fn migrate(services: &Services) -> Result<()> {
db_lt_13(services).await?;
}
if db["global"].get(b"feat_sha256_media").await.is_not_found() {
if db["global"]
.get(b"feat_sha256_media")
.await
.is_not_found()
{
media::migrations::migrate_sha256_media(services).await?;
} else if config.media_startup_check {
media::migrations::checkup_sha256_media(services).await?;
@@ -241,7 +250,9 @@ async fn db_lt_12(services: &Services) -> Result<()> {
let content_rule_transformation =
[".m.rules.contains_user_name", ".m.rule.contains_user_name"];
let rule = rules_list.content.get(content_rule_transformation[0]);
let rule = rules_list
.content
.get(content_rule_transformation[0]);
if rule.is_some() {
let mut rule = rule.unwrap().clone();
content_rule_transformation[1].clone_into(&mut rule.rule_id);
@@ -267,7 +278,9 @@ async fn db_lt_12(services: &Services) -> Result<()> {
if let Some(rule) = rule {
let mut rule = rule.clone();
transformation[1].clone_into(&mut rule.rule_id);
rules_list.underride.shift_remove(transformation[0]);
rules_list
.underride
.shift_remove(transformation[0]);
rules_list.underride.insert(rule);
}
}
@@ -278,7 +291,9 @@ async fn db_lt_12(services: &Services) -> Result<()> {
.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
@@ -323,7 +338,9 @@ async fn db_lt_13(services: &Services) -> Result<()> {
.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(account_data).expect("to json value always works"),
)
.await?;
@@ -435,12 +452,18 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
for user_id in &joined_members {
debug_info!("User is joined, marking as joined");
services.rooms.state_cache.mark_as_joined(user_id, room_id);
services
.rooms
.state_cache
.mark_as_joined(user_id, room_id);
}
for user_id in &non_joined_members {
debug_info!("User is left or banned, marking as left");
services.rooms.state_cache.mark_as_left(user_id, room_id);
services
.rooms
.state_cache
.mark_as_left(user_id, room_id);
}
}

View File

@@ -84,8 +84,13 @@ impl Data {
let now = utils::millis_since_unix_epoch();
let last_last_active_ts = match last_presence {
| Err(_) => 0,
| Ok((_, ref presence)) =>
now.saturating_sub(presence.content.last_active_ago.unwrap_or_default().into()),
| Ok((_, ref presence)) => now.saturating_sub(
presence
.content
.last_active_ago
.unwrap_or_default()
.into(),
),
};
let last_active_ts = match last_active_ago {
@@ -118,7 +123,8 @@ impl Data {
let count = self.services.globals.next_count()?;
let key = presenceid_key(count, user_id);
self.presenceid_presence.raw_put(key, Json(presence));
self.presenceid_presence
.raw_put(key, Json(presence));
self.userid_presenceid.raw_put(user_id, count);
if let Ok((last_count, _)) = last_presence {

View File

@@ -110,8 +110,11 @@ impl Service {
let last_last_active_ago = match last_presence {
| Err(_) => 0_u64,
| Ok((_, ref presence)) =>
presence.content.last_active_ago.unwrap_or_default().into(),
| Ok((_, ref presence)) => presence
.content
.last_active_ago
.unwrap_or_default()
.into(),
};
if !state_changed && last_last_active_ago < REFRESH_TIMEOUT {
@@ -151,8 +154,16 @@ impl Service {
&& user_id != self.services.globals.server_user
{
let timeout = match presence_state {
| PresenceState::Online => self.services.server.config.presence_idle_timeout_s,
| _ => self.services.server.config.presence_offline_timeout_s,
| PresenceState::Online =>
self.services
.server
.config
.presence_idle_timeout_s,
| _ =>
self.services
.server
.config
.presence_offline_timeout_s,
};
self.timer_channel

View File

@@ -129,10 +129,13 @@ impl Service {
let pushkey = data.pusher.ids.pushkey.as_str();
let key = (sender, pushkey);
self.db.senderkey_pusher.put(key, Json(pusher));
self.db.pushkey_deviceid.insert(pushkey, sender_device);
self.db
.pushkey_deviceid
.insert(pushkey, sender_device);
},
| set_pusher::v3::PusherAction::Delete(ids) => {
self.delete_pusher(sender, ids.pushkey.as_str()).await;
self.delete_pusher(sender, ids.pushkey.as_str())
.await;
},
}
@@ -152,7 +155,11 @@ impl Service {
}
pub async fn get_pusher_device(&self, pushkey: &str) -> Result<OwnedDeviceId> {
self.db.pushkey_deviceid.get(pushkey).await.deserialized()
self.db
.pushkey_deviceid
.get(pushkey)
.await
.deserialized()
}
pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Pusher> {
@@ -217,7 +224,12 @@ impl Service {
}
}
let response = self.services.client.pusher.execute(reqwest_request).await;
let response = self
.services
.client
.pusher
.execute(reqwest_request)
.await;
match response {
| Ok(mut response) => {
@@ -319,7 +331,8 @@ impl Service {
}
if notify == Some(true) {
self.send_notice(unread, pusher, tweaks, pdu).await?;
self.send_notice(unread, pusher, tweaks, pdu)
.await?;
}
// Else the event triggered no actions
@@ -460,8 +473,12 @@ impl Service {
event.state_key.as_deref() == Some(event.sender.as_str());
}
notifi.sender_display_name =
self.services.users.displayname(&event.sender).await.ok();
notifi.sender_display_name = self
.services
.users
.displayname(&event.sender)
.await
.ok();
notifi.room_name = self
.services

View File

@@ -77,10 +77,12 @@ impl super::Service {
self.services.server.check_running()?;
match self.request_well_known(dest.as_str()).await? {
| Some(delegated) =>
self.actual_dest_3(&mut host, cache, delegated).await?,
self.actual_dest_3(&mut host, cache, delegated)
.await?,
| _ => match self.query_srv_record(dest.as_str()).await? {
| Some(overrider) =>
self.actual_dest_4(&host, cache, overrider).await?,
self.actual_dest_4(&host, cache, overrider)
.await?,
| _ => self.actual_dest_5(dest, cache).await?,
},
}
@@ -97,7 +99,8 @@ impl super::Service {
let (host, port) = host.split_at(pos);
FedDest::Named(
host.to_owned(),
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
port.try_into()
.unwrap_or_else(|_| FedDest::default_port()),
)
} else {
FedDest::Named(host, FedDest::default_port())
@@ -124,7 +127,8 @@ impl super::Service {
Ok(FedDest::Named(
host.to_owned(),
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
port.try_into()
.unwrap_or_else(|_| FedDest::default_port()),
))
}
@@ -145,7 +149,8 @@ impl super::Service {
trace!("Delegated hostname has no port in this branch");
match self.query_srv_record(&delegated).await? {
| Some(overrider) =>
self.actual_dest_3_3(cache, delegated, overrider).await,
self.actual_dest_3_3(cache, delegated, overrider)
.await,
| _ => self.actual_dest_3_4(cache, delegated).await,
}
},
@@ -170,7 +175,8 @@ impl super::Service {
Ok(FedDest::Named(
host.to_owned(),
port.try_into().unwrap_or_else(|_| FedDest::default_port()),
port.try_into()
.unwrap_or_else(|_| FedDest::default_port()),
))
}
@@ -287,17 +293,23 @@ impl super::Service {
self.services.server.check_running()?;
debug!("querying IP for {untername:?} ({hostname:?}:{port})");
match self.resolver.resolver.lookup_ip(hostname.to_owned()).await {
match self
.resolver
.resolver
.lookup_ip(hostname.to_owned())
.await
{
| Err(e) => Self::handle_resolve_error(&e, hostname),
| Ok(override_ip) => {
self.cache.set_override(untername, &CachedOverride {
ips: override_ip.into_iter().take(MAX_IPS).collect(),
port,
expire: CachedOverride::default_expire(),
overriding: (hostname != untername)
.then_some(hostname.into())
.inspect(|_| debug_info!("{untername:?} overriden by {hostname:?}")),
});
self.cache
.set_override(untername, &CachedOverride {
ips: override_ip.into_iter().take(MAX_IPS).collect(),
port,
expire: CachedOverride::default_expire(),
overriding: (hostname != untername)
.then_some(hostname.into())
.inspect(|_| debug_info!("{untername:?} overriden by {hostname:?}")),
});
Ok(())
},
@@ -319,7 +331,11 @@ impl super::Service {
| Ok(result) => {
return Ok(result.iter().next().map(|result| {
FedDest::Named(
result.target().to_string().trim_end_matches('.').to_owned(),
result
.target()
.to_string()
.trim_end_matches('.')
.to_owned(),
format!(":{}", result.port())
.as_str()
.try_into()

View File

@@ -127,7 +127,10 @@ async fn hooked_resolve(
.boxed()
.await,
| _ => resolve_to_reqwest(server, resolver, name).boxed().await,
| _ =>
resolve_to_reqwest(server, resolver, name)
.boxed()
.await,
}
}
@@ -139,8 +142,13 @@ async fn resolve_to_reqwest(
use std::{io, io::ErrorKind::Interrupted};
let handle_shutdown = || Box::new(io::Error::new(Interrupted, "Server shutting down"));
let handle_results =
|results: LookupIp| Box::new(results.into_iter().map(|ip| SocketAddr::new(ip, 0)));
let handle_results = |results: LookupIp| {
Box::new(
results
.into_iter()
.map(|ip| SocketAddr::new(ip, 0)),
)
};
tokio::select! {
results = resolver.lookup_ip(name.as_str()) => Ok(handle_results(results?)),

View File

@@ -87,7 +87,9 @@ impl Service {
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
self.db.aliasid_alias.insert(&aliasid, alias.as_bytes());
self.db
.aliasid_alias
.insert(&aliasid, alias.as_bytes());
Ok(())
}
@@ -171,7 +173,11 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")]
pub async fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<OwnedRoomId> {
self.db.alias_roomid.get(alias.alias()).await.deserialized()
self.db
.alias_roomid
.get(alias.alias())
.await
.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
@@ -248,7 +254,11 @@ impl Service {
}
async fn who_created_alias(&self, alias: &RoomAliasId) -> Result<OwnedUserId> {
self.db.alias_userid.get(alias.alias()).await.deserialized()
self.db
.alias_userid
.get(alias.alias())
.await
.deserialized()
}
async fn resolve_appservice_alias(

View File

@@ -136,7 +136,10 @@ async fn get_auth_chain_outer(
return Ok(Vec::new());
}
if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
if let Ok(cached) = self
.get_cached_eventid_authchain(&chunk_key)
.await
{
return Ok(cached.to_vec());
}
@@ -144,11 +147,16 @@ async fn get_auth_chain_outer(
.into_iter()
.try_stream()
.broad_and_then(|(shortid, event_id)| async move {
if let Ok(cached) = self.get_cached_eventid_authchain(&[shortid]).await {
if let Ok(cached) = self
.get_cached_eventid_authchain(&[shortid])
.await
{
return Ok(cached.to_vec());
}
let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
let auth_chain = self
.get_auth_chain_inner(room_id, event_id)
.await?;
self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice());
debug!(
?event_id,
@@ -254,4 +262,10 @@ pub fn get_cache_usage(&self) -> (usize, usize) {
}
#[implement(Service)]
pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().expect("locked").clear(); }
pub fn clear_cache(&self) {
self.db
.auth_chain_cache
.lock()
.expect("locked")
.clear();
}

View File

@@ -25,8 +25,12 @@ pub async fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Res
return Ok(());
}
if acl_event_content.deny.contains(&String::from("*"))
&& acl_event_content.allow.contains(&String::from("*"))
if acl_event_content
.deny
.contains(&String::from("*"))
&& acl_event_content
.allow
.contains(&String::from("*"))
{
warn!(%room_id, "Ignoring broken ACL event (allow key and deny key both contain wildcard \"*\"");
return Ok(());

View File

@@ -66,7 +66,11 @@ pub async fn handle_incoming_pdu<'a>(
let meta_exists = self.services.metadata.exists(room_id).map(Ok);
// 1.2 Check if the room is disabled
let is_disabled = self.services.metadata.is_disabled(room_id).map(Ok);
let is_disabled = self
.services
.metadata
.is_disabled(room_id)
.map(Ok);
// 1.3.1 Check room ACL on origin field/server
let origin_acl_check = self.acl_check(origin, room_id);

View File

@@ -100,7 +100,11 @@ impl Service {
}
async fn event_fetch(&self, event_id: OwnedEventId) -> Option<PduEvent> {
self.services.timeline.get_pdu(&event_id).await.ok()
self.services
.timeline
.get_pdu(&event_id)
.await
.ok()
}
}

View File

@@ -92,7 +92,11 @@ pub async fn resolve_state(
let new_room_state: CompressedState = self
.services
.state_compressor
.compress_state_events(state_events.iter().map(|(ssk, eid)| (ssk, (*eid).borrow())))
.compress_state_events(
state_events
.iter()
.map(|(ssk, eid)| (ssk, (*eid).borrow())),
)
.collect()
.await;

View File

@@ -54,7 +54,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
debug!("Resolving state at event");
let mut state_at_incoming_event = if incoming_pdu.prev_events.len() == 1 {
self.state_at_incoming_degree_one(&incoming_pdu).await?
self.state_at_incoming_degree_one(&incoming_pdu)
.await?
} else {
self.state_at_incoming_resolved(&incoming_pdu, room_id, &room_version_id)
.await?
@@ -74,10 +75,19 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
// 11. Check the auth of the event passes based on the state of the event
let state_fetch_state = &state_at_incoming_event;
let state_fetch = |k: StateEventType, s: StateKey| async move {
let shortstatekey = self.services.short.get_shortstatekey(&k, &s).await.ok()?;
let shortstatekey = self
.services
.short
.get_shortstatekey(&k, &s)
.await
.ok()?;
let event_id = state_fetch_state.get(&shortstatekey)?;
self.services.timeline.get_pdu(event_id).await.ok()
self.services
.timeline
.get_pdu(event_id)
.await
.ok()
};
let auth_check = state_res::event_auth::auth_check(

View File

@@ -76,7 +76,9 @@ pub async fn witness_retain(&self, senders: Witness, ctx: &Context<'_>) -> Witne
);
let include_redundant = cfg!(feature = "element_hacks")
|| ctx.options.is_some_and(Options::include_redundant_members);
|| ctx
.options
.is_some_and(Options::include_redundant_members);
let witness = self
.witness(ctx, senders.iter().map(AsRef::as_ref))

View File

@@ -48,5 +48,7 @@ pub async fn get_pdu_outlier(&self, event_id: &EventId) -> Result<PduEvent> {
#[implement(Service)]
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) {
self.db.eventid_outlierpdu.raw_put(event_id, Json(pdu));
self.db
.eventid_outlierpdu
.raw_put(event_id, Json(pdu));
}

View File

@@ -52,7 +52,8 @@ impl Data {
const BUFSIZE: usize = size_of::<u64>() * 2;
let key: &[u64] = &[to, from];
self.tofrom_relation.aput_raw::<BUFSIZE, _, _>(key, []);
self.tofrom_relation
.aput_raw::<BUFSIZE, _, _>(key, []);
}
pub(super) fn get_relations<'a>(
@@ -65,11 +66,21 @@ impl Data {
) -> impl Stream<Item = PdusIterItem> + Send + '_ {
let mut current = ArrayVec::<u8, 16>::new();
current.extend(target.to_be_bytes());
current.extend(from.saturating_inc(dir).into_unsigned().to_be_bytes());
current.extend(
from.saturating_inc(dir)
.into_unsigned()
.to_be_bytes(),
);
let current = current.as_slice();
match dir {
| Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(),
| Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(),
| Direction::Forward => self
.tofrom_relation
.raw_keys_from(current)
.boxed(),
| Direction::Backward => self
.tofrom_relation
.rev_raw_keys_from(current)
.boxed(),
}
.ignore_err()
.ready_take_while(move |key| key.starts_with(&target.to_be_bytes()))
@@ -78,7 +89,12 @@ impl Data {
.wide_filter_map(move |shorteventid| async move {
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
let mut pdu = self
.services
.timeline
.get_pdu_from_id(&pdu_id)
.await
.ok()?;
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
@@ -109,6 +125,9 @@ impl Data {
}
pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
self.softfailedeventids.get(event_id).await.is_ok()
self.softfailedeventids
.get(event_id)
.await
.is_ok()
}
}

View File

@@ -119,7 +119,9 @@ impl Service {
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool {
self.db.is_event_referenced(room_id, event_id).await
self.db
.is_event_referenced(room_id, event_id)
.await
}
#[inline]

View File

@@ -58,7 +58,8 @@ impl Data {
let count = self.services.globals.next_count().unwrap();
let latest_id = (room_id, count, user_id);
self.readreceiptid_readreceipt.put(latest_id, Json(event));
self.readreceiptid_readreceipt
.put(latest_id, Json(event));
}
pub(super) fn readreceipts_since<'a>(
@@ -91,7 +92,8 @@ impl Data {
let next_count = self.services.globals.next_count().unwrap();
self.roomuserid_privateread.put(key, pdu_count);
self.roomuserid_lastprivatereadupdate.put(key, next_count);
self.roomuserid_lastprivatereadupdate
.put(key, next_count);
}
pub(super) async fn private_read_get_count(
@@ -100,7 +102,10 @@ impl Data {
user_id: &UserId,
) -> Result<u64> {
let key = (room_id, user_id);
self.roomuserid_privateread.qry(&key).await.deserialized()
self.roomuserid_privateread
.qry(&key)
.await
.deserialized()
}
pub(super) async fn last_privateread_update(

View File

@@ -54,7 +54,9 @@ impl Service {
room_id: &RoomId,
event: &ReceiptEvent,
) {
self.db.readreceipt_update(user_id, room_id, event).await;
self.db
.readreceipt_update(user_id, room_id, event)
.await;
self.services
.sending
.flush_room(room_id)
@@ -68,18 +70,30 @@ impl Service {
room_id: &RoomId,
user_id: &UserId,
) -> Result<Raw<AnySyncEphemeralRoomEvent>> {
let pdu_count = self.private_read_get_count(room_id, user_id).map_err(|e| {
err!(Database(warn!("No private read receipt was set in {room_id}: {e}")))
});
let shortroomid = self.services.short.get_shortroomid(room_id).map_err(|e| {
err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}")))
});
let pdu_count = self
.private_read_get_count(room_id, user_id)
.map_err(|e| {
err!(Database(warn!("No private read receipt was set in {room_id}: {e}")))
});
let shortroomid = self
.services
.short
.get_shortroomid(room_id)
.map_err(|e| {
err!(Database(warn!(
"Short room ID does not exist in database for {room_id}: {e}"
)))
});
let (pdu_count, shortroomid) = try_join!(pdu_count, shortroomid)?;
let shorteventid = PduCount::Normal(pdu_count);
let pdu_id: RawPduId = PduId { shortroomid, shorteventid }.into();
let pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await?;
let pdu = self
.services
.timeline
.get_pdu_from_id(&pdu_id)
.await?;
let event_id: OwnedEventId = pdu.event_id;
let user_id: OwnedUserId = user_id.to_owned();
@@ -129,13 +143,17 @@ impl Service {
room_id: &RoomId,
user_id: &UserId,
) -> Result<u64> {
self.db.private_read_get_count(room_id, user_id).await
self.db
.private_read_get_count(room_id, user_id)
.await
}
/// Returns the PDU count of the last typing update in this room.
#[inline]
pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.last_privateread_update(user_id, room_id).await
self.db
.last_privateread_update(user_id, room_id)
.await
}
}

View File

@@ -139,9 +139,15 @@ pub async fn search_pdu_ids(
&self,
query: &RoomQuery<'_>,
) -> Result<impl Stream<Item = RawPduId> + Send + '_ + use<'_>> {
let shortroomid = self.services.short.get_shortroomid(query.room_id).await?;
let shortroomid = self
.services
.short
.get_shortroomid(query.room_id)
.await?;
let pdu_ids = self.search_pdu_ids_query_room(query, shortroomid).await;
let pdu_ids = self
.search_pdu_ids_query_room(query, shortroomid)
.await;
let iters = pdu_ids.into_iter().map(IntoIterator::into_iter);

View File

@@ -112,7 +112,10 @@ pub async fn get_or_create_shortstatekey(
) -> ShortStateKey {
const BUFSIZE: usize = size_of::<ShortStateKey>();
if let Ok(shortstatekey) = self.get_shortstatekey(event_type, state_key).await {
if let Ok(shortstatekey) = self
.get_shortstatekey(event_type, state_key)
.await
{
return shortstatekey;
}
@@ -235,7 +238,11 @@ pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (ShortSta
#[implement(Service)]
pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result<ShortRoomId> {
self.db.roomid_shortroomid.get(room_id).await.deserialized()
self.db
.roomid_shortroomid
.get(room_id)
.await
.deserialized()
}
#[implement(Service)]

View File

@@ -92,14 +92,23 @@ impl crate::Service for Service {
}
async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
let roomid_spacehierarchy_cache = self.roomid_spacehierarchy_cache.lock().await.len();
let roomid_spacehierarchy_cache = self
.roomid_spacehierarchy_cache
.lock()
.await
.len();
writeln!(out, "roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}")?;
Ok(())
}
async fn clear_cache(&self) { self.roomid_spacehierarchy_cache.lock().await.clear(); }
async fn clear_cache(&self) {
self.roomid_spacehierarchy_cache
.lock()
.await
.clear();
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
@@ -121,7 +130,11 @@ pub async fn get_summary_and_children_local(
| None => (), // cache miss
| Some(None) => return Ok(None),
| Some(Some(cached)) => {
let allowed_rooms = cached.summary.allowed_room_ids.iter().map(AsRef::as_ref);
let allowed_rooms = cached
.summary
.allowed_room_ids
.iter()
.map(AsRef::as_ref);
let is_accessible_child = self.is_accessible_child(
current_room,
@@ -154,10 +167,13 @@ pub async fn get_summary_and_children_local(
return Ok(None);
};
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.to_owned(),
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
);
self.roomid_spacehierarchy_cache
.lock()
.await
.insert(
current_room.to_owned(),
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
);
Ok(Some(SummaryAccessibility::Accessible(summary)))
}
@@ -196,10 +212,13 @@ async fn get_summary_and_children_federation(
};
let summary = response.room;
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.to_owned(),
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
);
self.roomid_spacehierarchy_cache
.lock()
.await
.insert(
current_room.to_owned(),
Some(CachedSpaceHierarchySummary { summary: summary.clone() }),
);
response
.children
@@ -304,7 +323,11 @@ async fn get_room_summary(
children_state: Vec<Raw<HierarchySpaceChildEvent>>,
identifier: &Identifier<'_>,
) -> Result<SpaceHierarchyParentSummary, Error> {
let join_rule = self.services.state_accessor.get_join_rules(room_id).await;
let join_rule = self
.services
.state_accessor
.get_join_rules(room_id)
.await;
let is_accessible_child = self
.is_accessible_child(
@@ -319,15 +342,33 @@ async fn get_room_summary(
return Err!(Request(Forbidden("User is not allowed to see the room")));
}
let name = self.services.state_accessor.get_name(room_id).ok();
let name = self
.services
.state_accessor
.get_name(room_id)
.ok();
let topic = self.services.state_accessor.get_room_topic(room_id).ok();
let topic = self
.services
.state_accessor
.get_room_topic(room_id)
.ok();
let room_type = self.services.state_accessor.get_room_type(room_id).ok();
let room_type = self
.services
.state_accessor
.get_room_type(room_id)
.ok();
let world_readable = self.services.state_accessor.is_world_readable(room_id);
let world_readable = self
.services
.state_accessor
.is_world_readable(room_id);
let guest_can_join = self.services.state_accessor.guest_can_join(room_id);
let guest_can_join = self
.services
.state_accessor
.guest_can_join(room_id);
let num_joined_members = self
.services
@@ -392,7 +433,10 @@ async fn get_room_summary(
room_version,
room_id: room_id.to_owned(),
num_joined_members: num_joined_members.try_into().unwrap_or_default(),
allowed_room_ids: join_rule.allowed_rooms().map(Into::into).collect(),
allowed_room_ids: join_rule
.allowed_rooms()
.map(Into::into)
.collect(),
join_rule: join_rule.clone().into(),
};
@@ -425,9 +469,15 @@ where
}
if let Identifier::UserId(user_id) = identifier {
let is_joined = self.services.state_cache.is_joined(user_id, current_room);
let is_joined = self
.services
.state_cache
.is_joined(user_id, current_room);
let is_invited = self.services.state_cache.is_invited(user_id, current_room);
let is_invited = self
.services
.state_cache
.is_invited(user_id, current_room);
pin_mut!(is_joined, is_invited);
if is_joined.or(is_invited).await {
@@ -444,9 +494,15 @@ where
.stream()
.any(async |room| match identifier {
| Identifier::UserId(user) =>
self.services.state_cache.is_joined(user, room).await,
self.services
.state_cache
.is_joined(user, room)
.await,
| Identifier::ServerName(server) =>
self.services.state_cache.server_in_room(server, room).await,
self.services
.state_cache
.server_in_room(server, room)
.await,
})
.await,

View File

@@ -120,7 +120,11 @@ impl Service {
match pdu.kind {
| TimelineEventType::RoomMember => {
let Some(user_id) = pdu.state_key.as_ref().map(UserId::parse).flat_ok()
let Some(user_id) = pdu
.state_key
.as_ref()
.map(UserId::parse)
.flat_ok()
else {
continue;
};
@@ -154,7 +158,10 @@ impl Service {
}
}
self.services.state_cache.update_joined_count(room_id).await;
self.services
.state_cache
.update_joined_count(room_id)
.await;
self.set_room_state(room_id, shortstatehash, state_lock);
@@ -218,13 +225,15 @@ impl Service {
} else {
(state_ids_compressed, Arc::new(CompressedState::new()))
};
self.services.state_compressor.save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
1_000_000, // high number because no state will be based on this one
states_parents,
)?;
self.services
.state_compressor
.save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
1_000_000, // high number because no state will be based on this one
states_parents,
)?;
}
self.db
@@ -248,7 +257,9 @@ impl Service {
.get_or_create_shorteventid(&new_pdu.event_id)
.await;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id).await;
let previous_shortstatehash = self
.get_room_shortstatehash(&new_pdu.room_id)
.await;
if let Ok(p) = previous_shortstatehash {
self.db
@@ -303,13 +314,15 @@ impl Service {
statediffremoved.insert(*replaces);
}
self.services.state_compressor.save_state_from_diff(
shortstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
2,
states_parents,
)?;
self.services
.state_compressor
.save_state_from_diff(
shortstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
2,
states_parents,
)?;
Ok(shortstatehash)
},

View File

@@ -49,7 +49,11 @@ pub fn room_state_full_pdus<'a>(
self.services
.state
.get_room_shortstatehash(room_id)
.map_ok(|shortstatehash| self.state_full_pdus(shortstatehash).map(Ok).boxed())
.map_ok(|shortstatehash| {
self.state_full_pdus(shortstatehash)
.map(Ok)
.boxed()
})
.map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}")))
.try_flatten_stream()
}

View File

@@ -29,7 +29,8 @@ use crate::rooms::{
#[implement(super::Service)]
#[inline]
pub async fn user_was_joined(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id).await == MembershipState::Join
self.user_membership(shortstatehash, user_id)
.await == MembershipState::Join
}
/// The user was an invited or joined room member at this state (potentially
@@ -37,7 +38,9 @@ pub async fn user_was_joined(&self, shortstatehash: ShortStateHash, user_id: &Us
#[implement(super::Service)]
#[inline]
pub async fn user_was_invited(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool {
let s = self.user_membership(shortstatehash, user_id).await;
let s = self
.user_membership(shortstatehash, user_id)
.await;
s == MembershipState::Join || s == MembershipState::Invite
}
@@ -259,7 +262,9 @@ pub fn state_keys_with_shortids<'a>(
.zip(shorteventids)
.ready_filter_map(|(res, id)| res.map(|res| (res, id)).ok())
.ready_filter_map(move |((event_type_, state_key), event_id)| {
event_type_.eq(event_type).then_some((state_key, event_id))
event_type_
.eq(event_type)
.then_some((state_key, event_id))
})
}
@@ -338,7 +343,11 @@ pub fn state_full_pdus(
.multi_get_eventid_from_short(short_ids)
.ready_filter_map(Result::ok)
.broad_filter_map(move |event_id: OwnedEventId| async move {
self.services.timeline.get_pdu(&event_id).await.ok()
self.services
.timeline
.get_pdu(&event_id)
.await
.ok()
})
}
@@ -406,7 +415,12 @@ async fn load_full_state(&self, shortstatehash: ShortStateHash) -> Result<Arc<Co
.state_compressor
.load_shortstatehash_info(shortstatehash)
.map_err(|e| err!(Database("Missing state IDs: {e}")))
.map_ok(|vec| vec.last().expect("at least one layer").full_state.clone())
.map_ok(|vec| {
vec.last()
.expect("at least one layer")
.full_state
.clone()
})
.await
}

View File

@@ -98,7 +98,11 @@ pub async fn user_can_see_event(
return true;
};
let currently_member = self.services.state_cache.is_joined(user_id, room_id).await;
let currently_member = self
.services
.state_cache
.is_joined(user_id, room_id)
.await;
let history_visibility = self
.state_get_content(shortstatehash, &StateEventType::RoomHistoryVisibility, "")
@@ -110,11 +114,13 @@ pub async fn user_can_see_event(
match history_visibility {
| HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
self.user_was_invited(shortstatehash, user_id).await
self.user_was_invited(shortstatehash, user_id)
.await
},
| HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
self.user_was_joined(shortstatehash, user_id).await
self.user_was_joined(shortstatehash, user_id)
.await
},
| HistoryVisibility::WorldReadable => true,
| HistoryVisibility::Shared | _ => currently_member,
@@ -126,7 +132,12 @@ pub async fn user_can_see_event(
#[implement(super::Service)]
#[tracing::instrument(skip_all, level = "trace")]
pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool {
if self.services.state_cache.is_joined(user_id, room_id).await {
if self
.services
.state_cache
.is_joined(user_id, room_id)
.await
{
return true;
}
@@ -139,7 +150,10 @@ pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId
match history_visibility {
| HistoryVisibility::Invited =>
self.services.state_cache.is_invited(user_id, room_id).await,
self.services
.state_cache
.is_invited(user_id, room_id)
.await,
| HistoryVisibility::WorldReadable => true,
| _ => false,
}

View File

@@ -248,7 +248,9 @@ impl Service {
.update(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
GlobalAccountDataEventType::Direct
.to_string()
.into(),
&serde_json::to_value(&direct_event)
.expect("to json always works"),
)
@@ -262,7 +264,12 @@ impl Service {
},
| MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver
if self.services.users.user_is_ignored(sender, user_id).await {
if self
.services
.users
.user_is_ignored(sender, user_id)
.await
{
return Ok(());
}
@@ -346,14 +353,22 @@ impl Service {
self.db.userroomid_joined.insert(&userroom_id, []);
self.db.roomuserid_joined.insert(&roomuser_id, []);
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db
.userroomid_invitestate
.remove(&userroom_id);
self.db
.roomuserid_invitecount
.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db
.userroomid_knockedstate
.remove(&userroom_id);
self.db
.roomuserid_knockedcount
.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
@@ -382,11 +397,19 @@ impl Service {
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db
.userroomid_invitestate
.remove(&userroom_id);
self.db
.roomuserid_invitecount
.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db
.userroomid_knockedstate
.remove(&userroom_id);
self.db
.roomuserid_knockedcount
.remove(&roomuser_id);
self.db.roomid_inviteviaservers.remove(room_id);
}
@@ -417,8 +440,12 @@ impl Service {
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_invitestate.remove(&userroom_id);
self.db.roomuserid_invitecount.remove(&roomuser_id);
self.db
.userroomid_invitestate
.remove(&userroom_id);
self.db
.roomuserid_invitecount
.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
@@ -523,7 +550,11 @@ impl Service {
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self), level = "trace")]
pub async fn room_joined_count(&self, room_id: &RoomId) -> Result<u64> {
self.db.roomid_joinedcount.get(room_id).await.deserialized()
self.db
.roomid_joinedcount
.get(room_id)
.await
.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
@@ -623,7 +654,11 @@ impl Service {
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db.roomuserid_leftcount.qry(&key).await.deserialized()
self.db
.roomuserid_leftcount
.qry(&key)
.await
.deserialized()
}
/// Returns an iterator over all rooms this user joined.
@@ -750,7 +785,11 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")]
pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.roomuseroncejoinedids.qry(&key).await.is_ok()
self.db
.roomuseroncejoinedids
.qry(&key)
.await
.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
@@ -762,19 +801,31 @@ impl Service {
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_knocked<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_knockedstate.qry(&key).await.is_ok()
self.db
.userroomid_knockedstate
.qry(&key)
.await
.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_invitestate.qry(&key).await.is_ok()
self.db
.userroomid_invitestate
.qry(&key)
.await
.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_leftstate.qry(&key).await.is_ok()
self.db
.userroomid_leftstate
.qry(&key)
.await
.is_ok()
}
#[tracing::instrument(skip(self), level = "trace")]
@@ -856,7 +907,10 @@ impl Service {
}
pub fn get_appservice_in_room_cache_usage(&self) -> (usize, usize) {
let cache = self.appservice_in_room_cache.read().expect("locked");
let cache = self
.appservice_in_room_cache
.read()
.expect("locked");
(cache.len(), cache.capacity())
}
@@ -899,8 +953,12 @@ impl Service {
.unwrap_or(0),
);
self.db.roomid_joinedcount.raw_put(room_id, joinedcount);
self.db.roomid_invitedcount.raw_put(room_id, invitedcount);
self.db
.roomid_joinedcount
.raw_put(room_id, joinedcount);
self.db
.roomid_invitedcount
.raw_put(room_id, invitedcount);
self.db
.roomuserid_knockedcount
.raw_put(room_id, knockedcount);
@@ -968,11 +1026,16 @@ impl Service {
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
self.db.userroomid_knockedstate.remove(&userroom_id);
self.db.roomuserid_knockedcount.remove(&roomuser_id);
self.db
.userroomid_knockedstate
.remove(&userroom_id);
self.db
.roomuserid_knockedcount
.remove(&roomuser_id);
if let Some(servers) = invite_via.filter(is_not_empty!()) {
self.add_servers_invite_via(room_id, servers).await;
self.add_servers_invite_via(room_id, servers)
.await;
}
}

View File

@@ -87,22 +87,26 @@ impl crate::Service for Service {
async fn memory_usage(&self, out: &mut (dyn Write + Send)) -> Result {
let (cache_len, ents) = {
let cache = self.stateinfo_cache.lock().expect("locked");
let ents = cache.iter().map(at!(1)).flat_map(|vec| vec.iter()).fold(
HashMap::new(),
|mut ents, ssi| {
let ents = cache
.iter()
.map(at!(1))
.flat_map(|vec| vec.iter())
.fold(HashMap::new(), |mut ents, ssi| {
for cs in &[&ssi.added, &ssi.removed, &ssi.full_state] {
ents.insert(Arc::as_ptr(cs), compressed_state_size(cs));
}
ents
},
);
});
(cache.len(), ents)
};
let ents_len = ents.len();
let bytes = ents.values().copied().fold(0_usize, usize::saturating_add);
let bytes = ents
.values()
.copied()
.fold(0_usize, usize::saturating_add);
let bytes = bytes::pretty(bytes);
writeln!(out, "stateinfo_cache: {cache_len} {ents_len} ({bytes})")?;
@@ -110,7 +114,12 @@ impl crate::Service for Service {
Ok(())
}
async fn clear_cache(&self) { self.stateinfo_cache.lock().expect("locked").clear(); }
async fn clear_cache(&self) {
self.stateinfo_cache
.lock()
.expect("locked")
.clear();
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
@@ -123,11 +132,17 @@ impl Service {
&self,
shortstatehash: ShortStateHash,
) -> Result<ShortStateInfoVec> {
if let Some(r) = self.stateinfo_cache.lock()?.get_mut(&shortstatehash) {
if let Some(r) = self
.stateinfo_cache
.lock()?
.get_mut(&shortstatehash)
{
return Ok(r.clone());
}
let stack = self.new_shortstatehash_info(shortstatehash).await?;
let stack = self
.new_shortstatehash_info(shortstatehash)
.await?;
self.cache_shortstatehash_info(shortstatehash, stack.clone())
.await?;
@@ -151,7 +166,9 @@ impl Service {
shortstatehash: ShortStateHash,
stack: ShortStateInfoVec,
) -> Result {
self.stateinfo_cache.lock()?.insert(shortstatehash, stack);
self.stateinfo_cache
.lock()?
.insert(shortstatehash, stack);
Ok(())
}
@@ -262,7 +279,9 @@ impl Service {
if parent_states.len() > 3 {
// Number of layers
// To many layers, we have to go deeper
let parent = parent_states.pop().expect("parent must have a state");
let parent = parent_states
.pop()
.expect("parent must have a state");
let mut parent_new = (*parent.added).clone();
let mut parent_removed = (*parent.removed).clone();
@@ -311,7 +330,9 @@ impl Service {
// 1. We add the current diff on top of the parent layer.
// 2. We replace a layer above
let parent = parent_states.pop().expect("parent must have a state");
let parent = parent_states
.pop()
.expect("parent must have a state");
let parent_added_len = parent.added.len();
let parent_removed_len = parent.removed.len();
let parent_diff = checked!(parent_added_len + parent_removed_len)?;
@@ -373,8 +394,11 @@ impl Service {
.await
.ok();
let state_hash =
utils::calculate_hash(new_state_ids_compressed.iter().map(|bytes| &bytes[..]));
let state_hash = utils::calculate_hash(
new_state_ids_compressed
.iter()
.map(|bytes| &bytes[..]),
);
let (new_shortstatehash, already_existed) = self
.services
@@ -390,7 +414,9 @@ impl Service {
}
let states_parents = if let Some(p) = previous_shortstatehash {
self.load_shortstatehash_info(p).await.unwrap_or_default()
self.load_shortstatehash_info(p)
.await
.unwrap_or_default()
} else {
ShortStateInfoVec::new()
};

View File

@@ -141,7 +141,11 @@ impl Service {
shorteventid: PduCount,
_inc: &'a IncludeThreads,
) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?;
let shortroomid: ShortRoomId = self
.services
.short
.get_shortroomid(room_id)
.await?;
let current: RawPduId = PduId {
shortroomid,
@@ -157,7 +161,12 @@ impl Service {
.map(RawPduId::from)
.ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes())
.wide_filter_map(move |pdu_id| async move {
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
let mut pdu = self
.services
.timeline
.get_pdu_from_id(&pdu_id)
.await
.ok()?;
let pdu_id: PduId = pdu_id.into();
if pdu.sender != user_id {
@@ -187,6 +196,10 @@ impl Service {
}
pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result<Vec<OwnedUserId>> {
self.db.threadid_userids.get(root_id).await.deserialized()
self.db
.threadid_userids
.get(root_id)
.await
.deserialized()
}
}

View File

@@ -159,7 +159,9 @@ impl Data {
let non_outlier = self.non_outlier_pdu_exists(event_id).boxed();
let outlier = self.outlier_pdu_exists(event_id).boxed();
select_ok([non_outlier, outlier]).await.map(at!(0))
select_ok([non_outlier, outlier])
.await
.map(at!(0))
}
/// Returns the pdu.
@@ -187,8 +189,10 @@ impl Data {
debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal");
self.pduid_pdu.raw_put(pdu_id, Json(json));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id);
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes());
self.eventid_pduid
.insert(pdu.event_id.as_bytes(), pdu_id);
self.eventid_outlierpdu
.remove(pdu.event_id.as_bytes());
}
pub(super) fn prepend_backfill_pdu(

View File

@@ -184,7 +184,9 @@ impl Service {
sender_user: Option<&UserId>,
room_id: &RoomId,
) -> Result<PduCount> {
self.db.last_timeline_count(sender_user, room_id).await
self.db
.last_timeline_count(sender_user, room_id)
.await
}
/// Returns the `count` of this pdu's id.
@@ -363,7 +365,9 @@ impl Service {
let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into();
// Insert pdu
self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await;
self.db
.append_pdu(&pdu_id, pdu, &pdu_json, count2)
.await;
drop(insert_lock);
@@ -395,7 +399,12 @@ impl Service {
if let Some(state_key) = &pdu.state_key {
let target_user_id = UserId::parse(state_key)?;
if self.services.users.is_active_local(target_user_id).await {
if self
.services
.users
.is_active_local(target_user_id)
.await
{
push_target.insert(target_user_id.to_owned());
}
}
@@ -462,7 +471,11 @@ impl Service {
| TimelineEventType::RoomRedaction => {
use RoomVersionId::*;
let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?;
let room_version_id = self
.services
.state
.get_room_version(&pdu.room_id)
.await?;
match room_version_id {
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts {
@@ -472,7 +485,8 @@ impl Service {
.user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)
.await?
{
self.redact_pdu(redact_id, pdu, shortroomid).await?;
self.redact_pdu(redact_id, pdu, shortroomid)
.await?;
}
}
},
@@ -485,7 +499,8 @@ impl Service {
.user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)
.await?
{
self.redact_pdu(redact_id, pdu, shortroomid).await?;
self.redact_pdu(redact_id, pdu, shortroomid)
.await?;
}
}
},
@@ -508,8 +523,12 @@ impl Service {
let content: RoomMemberEventContent = pdu.get_content()?;
let stripped_state = match content.membership {
| MembershipState::Invite | MembershipState::Knock =>
self.services.state.summary_stripped(pdu).await.into(),
| MembershipState::Invite | MembershipState::Knock => self
.services
.state
.summary_stripped(pdu)
.await
.into(),
| _ => None,
};
@@ -533,9 +552,16 @@ impl Service {
| TimelineEventType::RoomMessage => {
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
self.services.search.index_pdu(shortroomid, &pdu_id, &body);
self.services
.search
.index_pdu(shortroomid, &pdu_id, &body);
if self.services.admin.is_admin_command(pdu, &body).await {
if self
.services
.admin
.is_admin_command(pdu, &body)
.await
{
self.services
.admin
.command(body, Some((*pdu.event_id).into()))?;
@@ -546,7 +572,10 @@ impl Service {
}
if let Ok(content) = pdu.get_content::<ExtractRelatesToEventId>() {
if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await {
if let Ok(related_pducount) = self
.get_pdu_count(&content.relates_to.event_id)
.await
{
self.services
.pdu_metadata
.add_relation(count2, related_pducount);
@@ -834,14 +863,26 @@ impl Service {
.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)
.await?;
if self.services.admin.is_admin_room(&pdu.room_id).await {
self.check_pdu_for_admin_room(&pdu, sender).boxed().await?;
if self
.services
.admin
.is_admin_room(&pdu.room_id)
.await
{
self.check_pdu_for_admin_room(&pdu, sender)
.boxed()
.await?;
}
// If redaction event is not authorized, do not append it to the timeline
if pdu.kind == TimelineEventType::RoomRedaction {
use RoomVersionId::*;
match self.services.state.get_room_version(&pdu.room_id).await? {
match self
.services
.state
.get_room_version(&pdu.room_id)
.await?
{
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts {
if !self
@@ -885,7 +926,10 @@ impl Service {
.join_authorized_via_users_server
.as_ref()
.is_some_and(|authorising_user| {
!self.services.globals.user_is_local(authorising_user)
!self
.services
.globals
.user_is_local(authorising_user)
}) {
return Err!(Request(InvalidParam(
"Authorising user does not belong to this homeserver"
@@ -999,7 +1043,8 @@ impl Service {
user_id: &'a UserId,
room_id: &'a RoomId,
) -> impl Stream<Item = PdusIterItem> + Send + 'a {
self.pdus(Some(user_id), room_id, None).ignore_err()
self.pdus(Some(user_id), room_id, None)
.ignore_err()
}
/// Reverse iteration starting at from.
@@ -1052,7 +1097,11 @@ impl Service {
}
}
let room_version_id = self.services.state.get_room_version(&pdu.room_id).await?;
let room_version_id = self
.services
.state
.get_room_version(&pdu.room_id)
.await?;
pdu.redact(&room_version_id, reason)?;
@@ -1098,15 +1147,18 @@ impl Service {
.await
.unwrap_or_default();
let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| {
if level > &power_levels.users_default
&& !self.services.globals.user_is_local(user_id)
{
Some(user_id.server_name())
} else {
None
}
});
let room_mods = power_levels
.users
.iter()
.filter_map(|(user_id, level)| {
if level > &power_levels.users_default
&& !self.services.globals.user_is_local(user_id)
{
Some(user_id.server_name())
} else {
None
}
});
let canonical_room_alias_server = once(
self.services
@@ -1158,7 +1210,11 @@ impl Service {
match response {
| Ok(response) => {
for pdu in response.pdus {
if let Err(e) = self.backfill_pdu(backfill_server, pdu).boxed().await {
if let Err(e) = self
.backfill_pdu(backfill_server, pdu)
.boxed()
.await
{
debug_warn!("Failed to add backfilled pdu in room {room_id}: {e}");
}
}
@@ -1176,8 +1232,11 @@ impl Service {
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) -> Result<()> {
let (room_id, event_id, value) =
self.services.event_handler.parse_incoming_pdu(&pdu).await?;
let (room_id, event_id, value) = self
.services
.event_handler
.parse_incoming_pdu(&pdu)
.await?;
// Lock so we cannot backfill the same pdu twice at the same time
let mutex_lock = self
@@ -1203,11 +1262,20 @@ impl Service {
let pdu = self.get_pdu(&event_id).await?;
let shortroomid = self.services.short.get_shortroomid(&room_id).await?;
let shortroomid = self
.services
.short
.get_shortroomid(&room_id)
.await?;
let insert_lock = self.mutex_insert.lock(&room_id).await;
let count: i64 = self.services.globals.next_count().unwrap().try_into()?;
let count: i64 = self
.services
.globals
.next_count()
.unwrap()
.try_into()?;
let pdu_id: RawPduId = PduId {
shortroomid,
@@ -1216,14 +1284,17 @@ impl Service {
.into();
// Insert pdu
self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value);
self.db
.prepend_backfill_pdu(&pdu_id, &event_id, &value);
drop(insert_lock);
if pdu.kind == TimelineEventType::RoomMessage {
let content: ExtractBody = pdu.get_content()?;
if let Some(body) = content.body {
self.services.search.index_pdu(shortroomid, &pdu_id, &body);
self.services
.search
.index_pdu(shortroomid, &pdu_id, &body);
}
}
drop(mutex_lock);

View File

@@ -71,13 +71,18 @@ impl Service {
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
if self
.typing_update_sender
.send(room_id.to_owned())
.is_err()
{
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
if self.services.globals.user_is_local(user_id) {
self.federation_send(room_id, user_id, true).await?;
self.federation_send(room_id, user_id, true)
.await?;
}
Ok(())
@@ -99,13 +104,18 @@ impl Service {
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
if self
.typing_update_sender
.send(room_id.to_owned())
.is_err()
{
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
if self.services.globals.user_is_local(user_id) {
self.federation_send(room_id, user_id, false).await?;
self.federation_send(room_id, user_id, false)
.await?;
}
Ok(())
@@ -152,7 +162,11 @@ impl Service {
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
if self
.typing_update_sender
.send(room_id.to_owned())
.is_err()
{
trace!("receiver found what it was looking for and is no longer interested");
}
@@ -233,7 +247,10 @@ impl Service {
let mut buf = EduBuf::new();
serde_json::to_writer(&mut buf, &edu).expect("Serialized Edu::Typing");
self.services.sending.send_edu_room(room_id, buf).await?;
self.services
.sending
.send_edu_room(room_id, buf)
.await?;
Ok(())
}

View File

@@ -48,8 +48,12 @@ impl crate::Service for Service {
#[implement(Service)]
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
let userroom_id = (user_id, room_id);
self.db.userroomid_highlightcount.put(userroom_id, 0_u64);
self.db.userroomid_notificationcount.put(userroom_id, 0_u64);
self.db
.userroomid_highlightcount
.put(userroom_id, 0_u64);
self.db
.userroomid_notificationcount
.put(userroom_id, 0_u64);
let roomuser_id = (room_id, user_id);
let count = self.services.globals.next_count().unwrap();
@@ -118,7 +122,11 @@ pub async fn get_token_shortstatehash(
room_id: &RoomId,
token: u64,
) -> Result<ShortStateHash> {
let shortroomid = self.services.short.get_shortroomid(room_id).await?;
let shortroomid = self
.services
.short
.get_shortroomid(room_id)
.await?;
let key: &[u64] = &[shortroomid, token];
self.db

View File

@@ -54,14 +54,22 @@ where
.parse()
.unwrap(),
);
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
*http_request.uri_mut() = parts
.try_into()
.expect("our manipulation is always valid");
let reqwest_request = reqwest::Request::try_from(http_request)?;
let mut response = client.execute(reqwest_request).await.map_err(|e| {
warn!("Could not send request to appservice \"{}\" at {dest}: {e:?}", registration.id);
e
})?;
let mut response = client
.execute(reqwest_request)
.await
.map_err(|e| {
warn!(
"Could not send request to appservice \"{}\" at {dest}: {e:?}",
registration.id
);
e
})?;
// reqwest::Response -> http::Response conversion
let status = response.status();

View File

@@ -170,7 +170,8 @@ impl Data {
}
pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) {
self.servername_educount.raw_put(server_name, last_count);
self.servername_educount
.raw_put(server_name, last_count);
}
pub async fn get_latest_educount(&self, server_name: &ServerName) -> u64 {
@@ -187,7 +188,9 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let server = parts
.next()
.expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
@@ -207,7 +210,9 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
} else if key.starts_with(b"$") {
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
let user = parts.next().expect("splitn always returns one element");
let user = parts
.next()
.expect("splitn always returns one element");
let user_string = utils::str_from_bytes(user)
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
let user_id = UserId::parse(user_string)
@@ -235,7 +240,9 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
} else {
let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element");
let server = parts
.next()
.expect("splitn always returns one element");
let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;

View File

@@ -24,7 +24,10 @@ pub(super) fn get_prefix(&self) -> Vec<u8> {
},
| Self::Appservice(server) => {
let sigil = b"+";
let len = sigil.len().saturating_add(server.len()).saturating_add(1);
let len = sigil
.len()
.saturating_add(server.len())
.saturating_add(1);
let mut p = Vec::with_capacity(len);
p.extend_from_slice(sigil);

View File

@@ -100,7 +100,9 @@ impl crate::Service for Service {
pusher: args.depend::<pusher::Service>("pusher"),
federation: args.depend::<federation::Service>("federation"),
},
channels: (0..num_senders).map(|_| loole::unbounded()).collect(),
channels: (0..num_senders)
.map(|_| loole::unbounded())
.collect(),
}))
}
@@ -160,7 +162,10 @@ impl Service {
self.dispatch(Msg {
dest,
event,
queue_id: keys.into_iter().next().expect("request queue key"),
queue_id: keys
.into_iter()
.next()
.expect("request queue key"),
})
}
@@ -173,7 +178,10 @@ impl Service {
self.dispatch(Msg {
dest,
event,
queue_id: keys.into_iter().next().expect("request queue key"),
queue_id: keys
.into_iter()
.next()
.expect("request queue key"),
})
}
@@ -201,7 +209,9 @@ impl Service {
.await;
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o)));
let keys = self
.db
.queue_requests(requests.iter().map(|(o, e)| (e, o)));
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
self.dispatch(Msg { dest, event, queue_id })?;
@@ -219,7 +229,10 @@ impl Service {
self.dispatch(Msg {
dest,
event,
queue_id: keys.into_iter().next().expect("request queue key"),
queue_id: keys
.into_iter()
.next()
.expect("request queue key"),
})
}
@@ -250,7 +263,9 @@ impl Service {
.await;
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o)));
let keys = self
.db
.queue_requests(requests.iter().map(|(o, e)| (e, o)));
for ((dest, event), queue_id) in requests.into_iter().zip(keys) {
self.dispatch(Msg { dest, event, queue_id })?;
@@ -299,7 +314,10 @@ impl Service {
where
T: OutgoingRequest + Debug + Send,
{
self.services.federation.execute(dest, request).await
self.services
.federation
.execute(dest, request)
.await
}
/// Like send_federation_request() but with a very large timeout

View File

@@ -85,7 +85,8 @@ impl Service {
.boxed()
.await;
self.work_loop(id, &mut futures, &mut statuses).await;
self.work_loop(id, &mut futures, &mut statuses)
.await;
if !futures.is_empty() {
self.finish_responses(&mut futures).boxed().await;
@@ -136,7 +137,9 @@ impl Service {
statuses: &mut CurTransactionStatus,
) {
match response {
| Ok(dest) => self.handle_response_ok(&dest, futures, statuses).await,
| Ok(dest) =>
self.handle_response_ok(&dest, futures, statuses)
.await,
| Err((dest, e)) => Self::handle_response_err(dest, statuses, &e),
}
}
@@ -177,7 +180,10 @@ impl Service {
if !new_events.is_empty() {
self.db.mark_as_active(new_events.iter());
let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect();
let new_events_vec = new_events
.into_iter()
.map(|(_, event)| event)
.collect();
futures.push(self.send_events(dest.clone(), new_events_vec));
} else {
statuses.remove(dest);
@@ -322,7 +328,8 @@ impl Service {
let select_edus = select_edus.into_iter().map(SendingEvent::Edu);
events.extend(select_edus);
self.db.set_latest_educount(server_name, last_count);
self.db
.set_latest_educount(server_name, last_count);
}
}
@@ -415,7 +422,10 @@ impl Service {
events_len: &AtomicUsize,
) -> EduVec {
let mut events = EduVec::new();
let server_rooms = self.services.state_cache.server_rooms(server_name);
let server_rooms = self
.services
.state_cache
.server_rooms(server_name);
pin_mut!(server_rooms);
let mut device_list_changes = HashSet::<OwnedUserId>::new();
@@ -562,7 +572,10 @@ impl Service {
event_ids: vec![event_id.clone()],
};
if read.insert(user_id.to_owned(), receipt_data).is_none() {
if read
.insert(user_id.to_owned(), receipt_data)
.is_none()
{
*num = num.saturating_add(1);
if *num >= SELECT_RECEIPT_LIMIT {
break;
@@ -621,7 +634,10 @@ impl Service {
let update = PresenceUpdate {
user_id: user_id.into(),
presence: presence_event.content.presence,
currently_active: presence_event.content.currently_active.unwrap_or(false),
currently_active: presence_event
.content
.currently_active
.unwrap_or(false),
status_msg: presence_event.content.status_msg,
last_active_ago: presence_event
.content
@@ -653,11 +669,15 @@ impl Service {
fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingFuture<'_> {
debug_assert!(!events.is_empty(), "sending empty transaction");
match dest {
| Destination::Federation(server) =>
self.send_events_dest_federation(server, events).boxed(),
| Destination::Appservice(id) => self.send_events_dest_appservice(id, events).boxed(),
| Destination::Push(user_id, pushkey) =>
self.send_events_dest_push(user_id, pushkey, events).boxed(),
| Destination::Federation(server) => self
.send_events_dest_federation(server, events)
.boxed(),
| Destination::Appservice(id) => self
.send_events_dest_appservice(id, events)
.boxed(),
| Destination::Push(user_id, pushkey) => self
.send_events_dest_push(user_id, pushkey, events)
.boxed(),
}
}
@@ -674,7 +694,12 @@ impl Service {
id: String,
events: Vec<SendingEvent>,
) -> SendingResult {
let Some(appservice) = self.services.appservice.get_registration(&id).await else {
let Some(appservice) = self
.services
.appservice
.get_registration(&id)
.await
else {
return Err((
Destination::Appservice(id.clone()),
err!(Database(warn!(?id, "Missing appservice registration"))),
@@ -696,7 +721,12 @@ impl Service {
for event in &events {
match event {
| SendingEvent::Pdu(pdu_id) => {
if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await {
if let Ok(pdu) = self
.services
.timeline
.get_pdu_from_id(pdu_id)
.await
{
pdu_jsons.push(pdu.into_room_event());
}
},
@@ -752,7 +782,12 @@ impl Service {
pushkey: String,
events: Vec<SendingEvent>,
) -> SendingResult {
let Ok(pusher) = self.services.pusher.get_pusher(&user_id, &pushkey).await else {
let Ok(pusher) = self
.services
.pusher
.get_pusher(&user_id, &pushkey)
.await
else {
return Err((
Destination::Push(user_id.clone(), pushkey.clone()),
err!(Database(error!(?user_id, ?pushkey, "Missing pusher"))),
@@ -768,7 +803,12 @@ impl Service {
for event in &events {
match event {
| SendingEvent::Pdu(pdu_id) => {
if let Ok(pdu) = self.services.timeline.get_pdu_from_id(pdu_id).await {
if let Ok(pdu) = self
.services
.timeline
.get_pdu_from_id(pdu_id)
.await
{
pdus.push(pdu);
}
},
@@ -826,7 +866,12 @@ impl Service {
| _ => None,
})
.stream()
.wide_filter_map(|pdu_id| self.services.timeline.get_pdu_json_from_id(pdu_id).ok())
.wide_filter_map(|pdu_id| {
self.services
.timeline
.get_pdu_json_from_id(pdu_id)
.ok()
})
.wide_then(|pdu| self.convert_to_outgoing_federation_event(pdu))
.collect()
.await;
@@ -898,7 +943,12 @@ impl Service {
.get("room_id")
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
match self.services.state.get_room_version(room_id).await {
match self
.services
.state
.get_room_version(room_id)
.await
{
| Ok(room_version_id) => match room_version_id {
| RoomVersionId::V1 | RoomVersionId::V2 => {},
| _ => _ = pdu_json.remove("event_id"),

View File

@@ -35,7 +35,10 @@ where
.filter_map(FlatOk::flat_ok)
.flat_map(IntoIterator::into_iter)
.for_each(|(server, sigs)| {
batch.entry(server).or_default().extend(sigs.into_keys());
batch
.entry(server)
.or_default()
.extend(sigs.into_keys());
});
let batch = batch
@@ -51,8 +54,16 @@ where
S: Iterator<Item = (&'a ServerName, K)> + Send + Clone,
K: Iterator<Item = &'a ServerSigningKeyId> + Send + Clone,
{
let notary_only = self.services.server.config.only_query_trusted_key_servers;
let notary_first_always = self.services.server.config.query_trusted_key_servers_first;
let notary_only = self
.services
.server
.config
.only_query_trusted_key_servers;
let notary_first_always = self
.services
.server
.config
.query_trusted_key_servers_first;
let notary_first_on_join = self
.services
.server
@@ -60,7 +71,10 @@ where
.query_trusted_key_servers_first_on_join;
let requested_servers = batch.clone().count();
let requested_keys = batch.clone().flat_map(|(_, key_ids)| key_ids).count();
let requested_keys = batch
.clone()
.flat_map(|(_, key_ids)| key_ids)
.count();
debug!("acquire {requested_keys} keys from {requested_servers}");
@@ -214,7 +228,8 @@ where
| Err(e) => error!("Failed to contact notary {notary:?}: {e}"),
| Ok(results) =>
for server_keys in results {
self.acquire_notary_result(&mut missing, server_keys).await;
self.acquire_notary_result(&mut missing, server_keys)
.await;
},
}
}
@@ -236,5 +251,8 @@ async fn acquire_notary_result(&self, missing: &mut Batch, server_keys: ServerSi
}
fn keys_count(batch: &Batch) -> usize {
batch.iter().flat_map(|(_, key_ids)| key_ids.iter()).count()
batch
.iter()
.flat_map(|(_, key_ids)| key_ids.iter())
.count()
}

View File

@@ -66,27 +66,44 @@ pub async fn get_verify_key(
origin: &ServerName,
key_id: &ServerSigningKeyId,
) -> Result<VerifyKey> {
let notary_first = self.services.server.config.query_trusted_key_servers_first;
let notary_only = self.services.server.config.only_query_trusted_key_servers;
let notary_first = self
.services
.server
.config
.query_trusted_key_servers_first;
let notary_only = self
.services
.server
.config
.only_query_trusted_key_servers;
if let Some(result) = self.verify_keys_for(origin).await.remove(key_id) {
return Ok(result);
}
if notary_first {
if let Ok(result) = self.get_verify_key_from_notaries(origin, key_id).await {
if let Ok(result) = self
.get_verify_key_from_notaries(origin, key_id)
.await
{
return Ok(result);
}
}
if !notary_only {
if let Ok(result) = self.get_verify_key_from_origin(origin, key_id).await {
if let Ok(result) = self
.get_verify_key_from_origin(origin, key_id)
.await
{
return Ok(result);
}
}
if !notary_first {
if let Ok(result) = self.get_verify_key_from_notaries(origin, key_id).await {
if let Ok(result) = self
.get_verify_key_from_notaries(origin, key_id)
.await
{
return Ok(result);
}
}

View File

@@ -107,8 +107,11 @@ async fn add_signing_keys(&self, new_keys: ServerSigningKeys) {
});
keys.verify_keys.extend(new_keys.verify_keys);
keys.old_verify_keys.extend(new_keys.old_verify_keys);
self.db.server_signingkeys.raw_put(origin, Json(&keys));
keys.old_verify_keys
.extend(new_keys.old_verify_keys);
self.db
.server_signingkeys
.raw_put(origin, Json(&keys));
}
#[implement(Service)]
@@ -177,7 +180,11 @@ pub async fn verify_keys_for(&self, origin: &ServerName) -> VerifyKeys {
#[implement(Service)]
pub async fn signing_keys_for(&self, origin: &ServerName) -> Result<ServerSigningKeys> {
self.db.server_signingkeys.get(origin).await.deserialized()
self.db
.server_signingkeys
.get(origin)
.await
.deserialized()
}
#[implement(Service)]

View File

@@ -42,7 +42,12 @@ where
while let Some(batch) = server_keys
.keys()
.rev()
.take(self.services.server.config.trusted_server_batch_size)
.take(
self.services
.server
.config
.trusted_server_batch_size,
)
.next_back()
.cloned()
{

View File

@@ -11,7 +11,10 @@ pub async fn validate_and_add_event_id(
room_version: &RoomVersionId,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?;
if let Err(e) = self.verify_event(&value, Some(room_version)).await {
if let Err(e) = self
.verify_event(&value, Some(room_version))
.await
{
return Err!(BadServerResponse(debug_error!(
"Event {event_id} failed verification: {e:?}"
)));
@@ -29,13 +32,19 @@ pub async fn validate_and_add_event_id_no_fetch(
room_version: &RoomVersionId,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let (event_id, mut value) = gen_event_id_canonical_json(pdu, room_version)?;
if !self.required_keys_exist(&value, room_version).await {
if !self
.required_keys_exist(&value, room_version)
.await
{
return Err!(BadServerResponse(debug_warn!(
"Event {event_id} cannot be verified: missing keys."
)));
}
if let Err(e) = self.verify_event(&value, Some(room_version)).await {
if let Err(e) = self
.verify_event(&value, Some(room_version))
.await
{
return Err!(BadServerResponse(debug_error!(
"Event {event_id} failed verification: {e:?}"
)));

View File

@@ -115,7 +115,8 @@ impl Services {
pub async fn start(self: &Arc<Self>) -> Result<Arc<Self>> {
debug_info!("Starting services...");
self.admin.set_services(Some(Arc::clone(self)).as_ref());
self.admin
.set_services(Some(Arc::clone(self)).as_ref());
super::migrations::migrations(self).await?;
self.manager
.lock()
@@ -188,7 +189,12 @@ impl Services {
fn interrupt(&self) {
debug!("Interrupting services...");
for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() {
for (name, (service, ..)) in self
.service
.read()
.expect("locked for reading")
.iter()
{
if let Some(service) = service.upgrade() {
trace!("Interrupting {name}");
service.interrupt();

View File

@@ -107,15 +107,24 @@ impl Service {
}
pub fn forget_snake_sync_connection(&self, key: &SnakeConnectionsKey) {
self.snake_connections.lock().expect("locked").remove(key);
self.snake_connections
.lock()
.expect("locked")
.remove(key);
}
pub fn remembered(&self, key: &DbConnectionsKey) -> bool {
self.connections.lock().expect("locked").contains_key(key)
self.connections
.lock()
.expect("locked")
.contains_key(key)
}
pub fn forget_sync_request_connection(&self, key: &DbConnectionsKey) {
self.connections.lock().expect("locked").remove(key);
self.connections
.lock()
.expect("locked")
.remove(key);
}
pub fn update_snake_sync_request_with_cache(

View File

@@ -17,11 +17,27 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
// Return when *any* user changed their key
// TODO: only send for user they share a room with
futures.push(self.db.todeviceid_events.watch_prefix(&userdeviceid_prefix));
futures.push(
self.db
.todeviceid_events
.watch_prefix(&userdeviceid_prefix),
);
futures.push(self.db.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.db.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.db.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(
self.db
.userroomid_joined
.watch_prefix(&userid_prefix),
);
futures.push(
self.db
.userroomid_invitestate
.watch_prefix(&userid_prefix),
);
futures.push(
self.db
.userroomid_leftstate
.watch_prefix(&userid_prefix),
);
futures.push(
self.db
.userroomid_notificationcount
@@ -47,7 +63,11 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
roomid_prefix.push(0xFF);
// Key changes
futures.push(self.db.keychangeid_userid.watch_prefix(&roomid_prefix));
futures.push(
self.db
.keychangeid_userid
.watch_prefix(&roomid_prefix),
);
// Room account data
let mut roomuser_prefix = roomid_prefix.clone();
@@ -66,7 +86,10 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
// EDUs
let typing_room_id = room_id.to_owned();
let typing_wait_for_update = async move {
self.services.typing.wait_for_update(&typing_room_id).await;
self.services
.typing
.wait_for_update(&typing_room_id)
.await;
};
futures.push(typing_wait_for_update.boxed());
@@ -87,7 +110,11 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
);
// More key changes (used when user is not joined to any rooms)
futures.push(self.db.keychangeid_userid.watch_prefix(&userid_prefix));
futures.push(
self.db
.keychangeid_userid
.watch_prefix(&userid_prefix),
);
// One time keys
futures.push(

View File

@@ -34,11 +34,17 @@ pub fn add_txnid(
) {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
key.extend_from_slice(
device_id
.map(DeviceId::as_bytes)
.unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(txn_id.as_bytes());
self.db.userdevicetxnid_response.insert(&key, data);
self.db
.userdevicetxnid_response
.insert(&key, data);
}
// If there's no entry, this is a new transaction

View File

@@ -60,7 +60,12 @@ impl crate::Service for Service {
#[implement(Service)]
pub async fn read_tokens(&self) -> Result<HashSet<String>> {
let mut tokens = HashSet::new();
if let Some(file) = &self.services.config.registration_token_file.as_ref() {
if let Some(file) = &self
.services
.config
.registration_token_file
.as_ref()
{
match std::fs::read_to_string(file) {
| Ok(text) => {
text.split_ascii_whitespace().for_each(|token| {
@@ -91,14 +96,20 @@ pub fn create(
self.set_uiaa_request(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session should be set"),
uiaainfo
.session
.as_ref()
.expect("session should be set"),
json_body,
);
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session should be set"),
uiaainfo
.session
.as_ref()
.expect("session should be set"),
Some(uiaainfo),
);
}
@@ -112,7 +123,8 @@ pub async fn try_auth(
uiaainfo: &UiaaInfo,
) -> Result<(bool, UiaaInfo)> {
let mut uiaainfo = if let Some(session) = auth.session() {
self.get_uiaa_session(user_id, device_id, session).await?
self.get_uiaa_session(user_id, device_id, session)
.await?
} else {
uiaainfo.clone()
};
@@ -180,7 +192,9 @@ pub async fn try_auth(
| AuthData::RegistrationToken(t) => {
let tokens = self.read_tokens().await?;
if tokens.contains(t.token.trim()) {
uiaainfo.completed.push(AuthType::RegistrationToken);
uiaainfo
.completed
.push(AuthType::RegistrationToken);
} else {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {
kind: ErrorKind::forbidden(),
@@ -211,7 +225,10 @@ pub async fn try_auth(
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
uiaainfo
.session
.as_ref()
.expect("session is always set"),
Some(&uiaainfo),
);
@@ -222,7 +239,10 @@ pub async fn try_auth(
self.update_uiaa_session(
user_id,
device_id,
uiaainfo.session.as_ref().expect("session is always set"),
uiaainfo
.session
.as_ref()
.expect("session is always set"),
None,
);
@@ -253,7 +273,9 @@ pub fn get_uiaa_request(
) -> Option<CanonicalJsonValue> {
let key = (
user_id.to_owned(),
device_id.unwrap_or_else(|| EMPTY.into()).to_owned(),
device_id
.unwrap_or_else(|| EMPTY.into())
.to_owned(),
session.to_owned(),
);

View File

@@ -175,7 +175,11 @@ impl Service {
/// Find out which user an access token belongs to.
pub async fn find_from_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
self.db.token_userdeviceid.get(token).await.deserialized()
self.db
.token_userdeviceid
.get(token)
.await
.deserialized()
}
/// Returns an iterator over all users on this homeserver (offered for
@@ -204,7 +208,11 @@ impl Service {
/// Returns the password hash for the given user.
pub async fn password_hash(&self, user_id: &UserId) -> Result<String> {
self.db.userid_password.get(user_id).await.deserialized()
self.db
.userid_password
.get(user_id)
.await
.deserialized()
}
/// Hash and set the user's password to the Argon2 hash
@@ -225,14 +233,20 @@ impl Service {
/// Returns the displayname of a user on this homeserver.
pub async fn displayname(&self, user_id: &UserId) -> Result<String> {
self.db.userid_displayname.get(user_id).await.deserialized()
self.db
.userid_displayname
.get(user_id)
.await
.deserialized()
}
/// Sets a new displayname or removes it if displayname is None. You still
/// need to nofify all rooms of this change.
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) {
if let Some(displayname) = displayname {
self.db.userid_displayname.insert(user_id, displayname);
self.db
.userid_displayname
.insert(user_id, displayname);
} else {
self.db.userid_displayname.remove(user_id);
}
@@ -240,14 +254,20 @@ impl Service {
/// Get the `avatar_url` of a user.
pub async fn avatar_url(&self, user_id: &UserId) -> Result<OwnedMxcUri> {
self.db.userid_avatarurl.get(user_id).await.deserialized()
self.db
.userid_avatarurl
.get(user_id)
.await
.deserialized()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) {
match avatar_url {
| Some(avatar_url) => {
self.db.userid_avatarurl.insert(user_id, &avatar_url);
self.db
.userid_avatarurl
.insert(user_id, &avatar_url);
},
| _ => {
self.db.userid_avatarurl.remove(user_id);
@@ -257,7 +277,11 @@ impl Service {
/// Get the blurhash of a user.
pub async fn blurhash(&self, user_id: &UserId) -> Result<String> {
self.db.userid_blurhash.get(user_id).await.deserialized()
self.db
.userid_blurhash
.get(user_id)
.await
.deserialized()
}
/// Sets a new avatar_url or removes it if avatar_url is None.
@@ -302,7 +326,12 @@ impl Service {
let userdeviceid = (user_id, device_id);
// Remove tokens
if let Ok(old_token) = self.db.userdeviceid_token.qry(&userdeviceid).await {
if let Ok(old_token) = self
.db
.userdeviceid_token
.qry(&userdeviceid)
.await
{
self.db.userdeviceid_token.del(userdeviceid);
self.db.token_userdeviceid.remove(&old_token);
}
@@ -339,7 +368,11 @@ impl Service {
pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
let key = (user_id, device_id);
self.db.userdeviceid_token.qry(&key).await.deserialized()
self.db
.userdeviceid_token
.qry(&key)
.await
.deserialized()
}
/// Replaces the access token of one device.
@@ -350,7 +383,13 @@ impl Service {
token: &str,
) -> Result<()> {
let key = (user_id, device_id);
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
if self
.db
.userdeviceid_metadata
.qry(&key)
.await
.is_err()
{
return Err!(Database(error!(
?user_id,
?device_id,
@@ -382,7 +421,13 @@ impl Service {
// Only existing devices should be able to call this, but we shouldn't assert
// either...
let key = (user_id, device_id);
if self.db.userdeviceid_metadata.qry(&key).await.is_err() {
if self
.db
.userdeviceid_metadata
.qry(&key)
.await
.is_err()
{
return Err!(Database(error!(
?user_id,
?device_id,
@@ -407,7 +452,9 @@ impl Service {
.raw_put(key, Json(one_time_key_value));
let count = self.services.globals.next_count().unwrap();
self.db.userid_lastonetimekeyupdate.raw_put(user_id, count);
self.db
.userid_lastonetimekeyupdate
.raw_put(user_id, count);
Ok(())
}
@@ -428,7 +475,9 @@ impl Service {
key_algorithm: &OneTimeKeyAlgorithm,
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
let count = self.services.globals.next_count()?.to_be_bytes();
self.db.userid_lastonetimekeyupdate.insert(user_id, count);
self.db
.userid_lastonetimekeyupdate
.insert(user_id, count);
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
@@ -542,10 +591,12 @@ impl Service {
.keys
.into_values();
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained no key.",
))?;
let self_signing_key_id = self_signing_key_ids
.next()
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Self signing key contained no key.",
))?;
if self_signing_key_ids.next().is_some() {
return Err(Error::BadRequest(
@@ -625,7 +676,9 @@ impl Service {
.insert(signature.0, signature.1.into());
let key = (target_id, key_id);
self.db.keyid_key.put(key, Json(cross_signing_key));
self.db
.keyid_key
.put(key, Json(cross_signing_key));
self.mark_device_key_update(target_id).await;
@@ -697,7 +750,11 @@ impl Service {
device_id: &DeviceId,
) -> Result<Raw<DeviceKeys>> {
let key_id = (user_id, device_id);
self.db.keyid_key.qry(&key_id).await.deserialized()
self.db
.keyid_key
.qry(&key_id)
.await
.deserialized()
}
pub async fn get_key<F>(
@@ -710,7 +767,12 @@ impl Service {
where
F: Fn(&UserId) -> bool + Send + Sync,
{
let key: serde_json::Value = self.db.keyid_key.get(key_id).await.deserialized()?;
let key: serde_json::Value = self
.db
.keyid_key
.get(key_id)
.await
.deserialized()?;
let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
let raw_value = serde_json::value::to_raw_value(&cleaned)?;
@@ -741,7 +803,11 @@ impl Service {
where
F: Fn(&UserId) -> bool + Send + Sync,
{
let key_id = self.db.userid_selfsigningkeyid.get(user_id).await?;
let key_id = self
.db
.userid_selfsigningkeyid
.get(user_id)
.await?;
self.get_key(&key_id, sender_user, user_id, allowed_signatures)
.await
@@ -834,7 +900,9 @@ impl Service {
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
let key = (user_id, device_id);
self.db.userdeviceid_metadata.put(key, Json(device));
self.db
.userdeviceid_metadata
.put(key, Json(device));
Ok(())
}
@@ -888,7 +956,11 @@ impl Service {
filter_id: &str,
) -> Result<FilterDefinition> {
let key = (user_id, filter_id);
self.db.userfilterid_filter.qry(&key).await.deserialized()
self.db
.userfilterid_filter
.qry(&key)
.await
.deserialized()
}
/// Creates an OpenID token, which can be used to prove that a user has
@@ -911,7 +983,12 @@ impl Service {
/// Find out which user an OpenID access token belongs to.
pub async fn find_from_openid_token(&self, token: &str) -> Result<OwnedUserId> {
let Ok(value) = self.db.openidtoken_expiresatuserid.get(token).await else {
let Ok(value) = self
.db
.openidtoken_expiresatuserid
.get(token)
.await
else {
return Err!(Request(Unauthorized("OpenID token is unrecognised")));
};
@@ -923,7 +1000,9 @@ impl Service {
if expires_at < utils::millis_since_unix_epoch() {
debug_warn!("OpenID token is expired, removing");
self.db.openidtoken_expiresatuserid.remove(token.as_bytes());
self.db
.openidtoken_expiresatuserid
.remove(token.as_bytes());
return Err!(Request(Unauthorized("OpenID token is expired")));
}
@@ -944,7 +1023,9 @@ impl Service {
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in);
let value = (expires_at.0, user_id);
self.db.logintoken_expiresatuserid.raw_put(token, value);
self.db
.logintoken_expiresatuserid
.raw_put(token, value);
expires_in
}
@@ -952,7 +1033,12 @@ impl Service {
/// Find out which user a login token belongs to.
/// Removes the token to prevent double-use attacks.
pub async fn find_from_login_token(&self, token: &str) -> Result<OwnedUserId> {
let Ok(value) = self.db.logintoken_expiresatuserid.get(token).await else {
let Ok(value) = self
.db
.logintoken_expiresatuserid
.get(token)
.await
else {
return Err!(Request(Forbidden("Login token is unrecognised")));
};
let (expires_at, user_id): (u64, OwnedUserId) = value.deserialized()?;
@@ -1010,7 +1096,9 @@ impl Service {
let key = (user_id, profile_key);
if let Some(value) = profile_key_value {
self.db.useridprofilekey_value.put(key, Json(value));
self.db
.useridprofilekey_value
.put(key, Json(value));
} else {
self.db.useridprofilekey_value.del(key);
}
@@ -1038,7 +1126,9 @@ impl Service {
let key = (user_id, "us.cloke.msc4175.tz");
if let Some(timezone) = timezone {
self.db.useridprofilekey_value.put_raw(key, &timezone);
self.db
.useridprofilekey_value
.put_raw(key, &timezone);
} else {
self.db.useridprofilekey_value.del(key);
}