From f2740822e25bf6ebf251f78d57b5f1197ac86e9c Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 27 Sep 2025 09:52:53 +0000 Subject: [PATCH] De-indent rooms state service definitions. Signed-off-by: Jason Volk --- src/service/rooms/state/mod.rs | 913 +++++++++++++++++---------------- 1 file changed, 462 insertions(+), 451 deletions(-) diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 106de896..07129c29 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -12,7 +12,7 @@ use ruma::{ serde::Raw, }; use tuwunel_core::{ - Event, PduEvent, Result, err, + Event, PduEvent, Result, err, implement, matrix::{RoomVersionRules, StateKey, TypeStateKey, room_version}, result::{AndThenRef, FlatOk}, state_res::{StateMap, auth_types_for_event}, @@ -73,130 +73,210 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Set the room to the given statehash and update caches. - #[tracing::instrument( - name = "force", - level = "debug", - skip_all, - fields( - count = ?self.services.globals.pending_count(), - %shortstatehash, - ) - )] - pub async fn force_state( - &self, - room_id: &RoomId, - shortstatehash: u64, - statediffnew: Arc, - _statediffremoved: Arc, - state_lock: &RoomMutexGuard, - ) -> Result { - let event_ids = statediffnew - .iter() - .stream() - .map(|&new| parse_compressed_state_event(new).1) - .then(|shorteventid| { +/// Set the room to the given statehash and update caches. +#[implement(Service)] +#[tracing::instrument( + name = "force", + level = "debug", + skip_all, + fields( + count = ?self.services.globals.pending_count(), + %shortstatehash, + ) +)] +pub async fn force_state( + &self, + room_id: &RoomId, + shortstatehash: u64, + statediffnew: Arc, + _statediffremoved: Arc, + state_lock: &RoomMutexGuard, +) -> Result { + let event_ids = statediffnew + .iter() + .stream() + .map(|&new| parse_compressed_state_event(new).1) + .then(|shorteventid| { + self.services + .short + .get_eventid_from_short::>(shorteventid) + }) + .ignore_err(); + + pin_mut!(event_ids); + while let Some(event_id) = event_ids.next().await { + let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else { + continue; + }; + + match pdu.kind { + | TimelineEventType::RoomMember => { + let Some(user_id) = pdu + .state_key + .as_ref() + .map(UserId::parse) + .flat_ok() + else { + continue; + }; + + let Ok(membership_event) = pdu.get_content::() else { + continue; + }; + self.services - .short - .get_eventid_from_short::>(shorteventid) - }) - .ignore_err(); - - pin_mut!(event_ids); - while let Some(event_id) = event_ids.next().await { - let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else { - continue; - }; - - match pdu.kind { - | TimelineEventType::RoomMember => { - let Some(user_id) = pdu - .state_key - .as_ref() - .map(UserId::parse) - .flat_ok() - else { - continue; - }; - - let Ok(membership_event) = pdu.get_content::() else { - continue; - }; - - self.services - .state_cache - .update_membership( - room_id, - user_id, - membership_event, - &pdu.sender, - None, - None, - false, - ) - .await?; - }, - | TimelineEventType::SpaceChild => { - self.services - .spaces - .roomid_spacehierarchy_cache - .lock() - .await - .remove(&pdu.room_id); - }, - | _ => continue, - } + .state_cache + .update_membership( + room_id, + user_id, + membership_event, + &pdu.sender, + None, + None, + false, + ) + .await?; + }, + | TimelineEventType::SpaceChild => { + self.services + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .remove(&pdu.room_id); + }, + | _ => continue, } - - self.services - .state_cache - .update_joined_count(room_id) - .await; - - self.set_room_state(room_id, shortstatehash, state_lock); - - Ok(()) } - /// Generates a new StateHash and associates it with the incoming event. - /// - /// This adds all current state events (not including the incoming event) - /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument( - name = "set", - level = "debug", - skip(self, state_ids_compressed), - fields( - count = ?self.services.globals.pending_count(), - ) - )] - pub async fn set_event_state( - &self, - event_id: &EventId, - room_id: &RoomId, - state_ids_compressed: Arc, - ) -> Result { - const KEY_LEN: usize = size_of::(); - const VAL_LEN: usize = size_of::(); + self.services + .state_cache + .update_joined_count(room_id) + .await; - let shorteventid = self - .services - .short - .get_or_create_shorteventid(event_id) - .await; + self.set_room_state(room_id, shortstatehash, state_lock); - let previous_shortstatehash = self.get_room_shortstatehash(room_id).await; + Ok(()) +} - let state_hash = calculate_hash(state_ids_compressed.iter().map(|s| &s[..])); +/// Generates a new StateHash and associates it with the incoming event. +/// +/// This adds all current state events (not including the incoming event) +/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. +#[implement(Service)] +#[tracing::instrument( + name = "set", + level = "debug", + skip(self, state_ids_compressed), + fields( + count = ?self.services.globals.pending_count(), + ) +)] +pub async fn set_event_state( + &self, + event_id: &EventId, + room_id: &RoomId, + state_ids_compressed: Arc, +) -> Result { + const KEY_LEN: usize = size_of::(); + const VAL_LEN: usize = size_of::(); - let (shortstatehash, already_existed) = self - .services - .short - .get_or_create_shortstatehash(&state_hash) - .await; + let shorteventid = self + .services + .short + .get_or_create_shorteventid(event_id) + .await; - if !already_existed { + let previous_shortstatehash = self.get_room_shortstatehash(room_id).await; + + let state_hash = calculate_hash(state_ids_compressed.iter().map(|s| &s[..])); + + let (shortstatehash, already_existed) = self + .services + .short + .get_or_create_shortstatehash(&state_hash) + .await; + + if !already_existed { + let states_parents = match previous_shortstatehash { + | Ok(p) => + self.services + .state_compressor + .load_shortstatehash_info(p) + .await?, + | _ => Vec::new(), + }; + + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew: CompressedState = state_ids_compressed + .difference(&parent_stateinfo.full_state) + .copied() + .collect(); + + let statediffremoved: CompressedState = parent_stateinfo + .full_state + .difference(&state_ids_compressed) + .copied() + .collect(); + + (Arc::new(statediffnew), Arc::new(statediffremoved)) + } else { + (state_ids_compressed, Arc::new(CompressedState::new())) + }; + + self.services + .state_compressor + .save_state_from_diff( + shortstatehash, + statediffnew, + statediffremoved, + 1_000_000, // high number because no state will be based on this one + states_parents, + )?; + } + + self.db + .shorteventid_shortstatehash + .aput::(shorteventid, shortstatehash); + + Ok(shortstatehash) +} + +/// Generates a new StateHash and associates it with the incoming event. +/// +/// This adds all current state events (not including the incoming event) +/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. +#[implement(Service)] +#[tracing::instrument( + name = "set", + level = "debug", + skip(self, new_pdu), + fields( + count = ?self.services.globals.pending_count(), + ) +)] +pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + const BUFSIZE: usize = size_of::(); + + let shorteventid = self + .services + .short + .get_or_create_shorteventid(&new_pdu.event_id) + .await; + + let previous_shortstatehash = self + .get_room_shortstatehash(&new_pdu.room_id) + .await; + + if let Ok(p) = previous_shortstatehash { + self.db + .shorteventid_shortstatehash + .aput::(shorteventid, p); + } + + match &new_pdu.state_key { + | Some(state_key) => { let states_parents = match previous_shortstatehash { | Ok(p) => self.services @@ -206,349 +286,280 @@ impl Service { | _ => Vec::new(), }; - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: CompressedState = state_ids_compressed - .difference(&parent_stateinfo.full_state) - .copied() - .collect(); + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key) + .await; - let statediffremoved: CompressedState = parent_stateinfo - .full_state - .difference(&state_ids_compressed) - .copied() - .collect(); + let new = self + .services + .state_compressor + .compress_state_event(shortstatekey, &new_pdu.event_id) + .await; - (Arc::new(statediffnew), Arc::new(statediffremoved)) - } else { - (state_ids_compressed, Arc::new(CompressedState::new())) - }; + let replaces = states_parents + .last() + .map(|info| { + info.full_state + .iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + }) + .unwrap_or_default(); + + if Some(&new) == replaces { + return Ok(previous_shortstatehash.expect("must exist")); + } + + // TODO: statehash with deterministic inputs + let shortstatehash = self.services.globals.next_count(); + + let mut statediffnew = CompressedState::new(); + statediffnew.insert(new); + + let mut statediffremoved = CompressedState::new(); + if let Some(replaces) = replaces { + statediffremoved.insert(*replaces); + } self.services .state_compressor .save_state_from_diff( - shortstatehash, - statediffnew, - statediffremoved, - 1_000_000, // high number because no state will be based on this one + *shortstatehash, + Arc::new(statediffnew), + Arc::new(statediffremoved), + 2, states_parents, )?; - } - self.db - .shorteventid_shortstatehash - .aput::(shorteventid, shortstatehash); - - Ok(shortstatehash) - } - - /// Generates a new StateHash and associates it with the incoming event. - /// - /// This adds all current state events (not including the incoming event) - /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument( - name = "set", - level = "debug", - skip(self, new_pdu), - fields( - count = ?self.services.globals.pending_count(), - ) - )] - pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result { - const BUFSIZE: usize = size_of::(); - - let shorteventid = self - .services - .short - .get_or_create_shorteventid(&new_pdu.event_id) - .await; - - let previous_shortstatehash = self - .get_room_shortstatehash(&new_pdu.room_id) - .await; - - if let Ok(p) = previous_shortstatehash { - self.db - .shorteventid_shortstatehash - .aput::(shorteventid, p); - } - - match &new_pdu.state_key { - | Some(state_key) => { - let states_parents = match previous_shortstatehash { - | Ok(p) => - self.services - .state_compressor - .load_shortstatehash_info(p) - .await?, - | _ => Vec::new(), - }; - - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key) - .await; - - let new = self - .services - .state_compressor - .compress_state_event(shortstatekey, &new_pdu.event_id) - .await; - - let replaces = states_parents - .last() - .map(|info| { - info.full_state - .iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - }) - .unwrap_or_default(); - - if Some(&new) == replaces { - return Ok(previous_shortstatehash.expect("must exist")); - } - - // TODO: statehash with deterministic inputs - let shortstatehash = self.services.globals.next_count(); - - let mut statediffnew = CompressedState::new(); - statediffnew.insert(new); - - let mut statediffremoved = CompressedState::new(); - if let Some(replaces) = replaces { - statediffremoved.insert(*replaces); - } - - self.services - .state_compressor - .save_state_from_diff( - *shortstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - 2, - states_parents, - )?; - - Ok(*shortstatehash) - }, - | _ => - Ok(previous_shortstatehash.expect("first event in room must be a state event")), - } - } - - /// Set the state hash to a new version, but does not update state_cache. - #[tracing::instrument(skip(self, _mutex_lock), level = "debug")] - pub fn set_room_state( - &self, - room_id: &RoomId, - shortstatehash: u64, - // Take mutex guard to make sure users get the room state mutex - _mutex_lock: &RoomMutexGuard, - ) { - const BUFSIZE: usize = size_of::(); - - self.db - .roomid_shortstatehash - .raw_aput::(room_id, shortstatehash); - } - - /// This fetches auth events from the current state. - #[allow(clippy::too_many_arguments)] - #[tracing::instrument(skip(self, content), level = "debug")] - pub async fn get_auth_events( - &self, - room_id: &RoomId, - kind: &TimelineEventType, - sender: &UserId, - state_key: Option<&str>, - content: &serde_json::value::RawValue, - auth_rules: &AuthorizationRules, - include_create: bool, - ) -> Result> - where - StateEventType: Send + Sync, - StateKey: Send + Sync, - { - let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else { - return Ok(StateMap::new()); - }; - - let sauthevents: HashMap = - auth_types_for_event(kind, sender, state_key, content, auth_rules, include_create)? - .into_iter() - .stream() - .broad_filter_map(async |(event_type, state_key): TypeStateKey| { - self.services - .short - .get_shortstatekey(&event_type, &state_key) - .await - .map(move |sstatekey| (sstatekey, (event_type, state_key))) - .ok() - }) - .collect() - .await; - - self.services - .state_accessor - .state_full_shortids(shortstatehash) - .ready_filter_map(Result::ok) - .ready_filter_map(|(shortstatekey, shorteventid)| { - sauthevents - .get(&shortstatekey) - .map(move |(ty, sk)| ((ty, sk), shorteventid)) - }) - .unzip() - .map(|(state_keys, event_ids): (Vec<_>, Vec<_>)| { - self.services - .short - .multi_get_eventid_from_short(event_ids.into_iter().stream()) - .zip(state_keys.into_iter().stream()) - }) - .flatten_stream() - .ready_filter_map(|(event_id, (ty, sk))| Some(((ty, sk), event_id.ok()?))) - .broad_filter_map(async |((ty, sk), event_id): ((&_, &_), OwnedEventId)| { - let pdu = self.services.timeline.get_pdu(&event_id).await; - - Some(((ty.clone(), sk.clone()), pdu.ok()?)) - }) - .collect() - .map(Ok) - .await - } - - #[tracing::instrument(skip_all, level = "debug")] - pub async fn summary_stripped( - &self, - event: &Pdu, - ) -> Vec> { - let cells = [ - (&StateEventType::RoomCreate, ""), - (&StateEventType::RoomJoinRules, ""), - (&StateEventType::RoomCanonicalAlias, ""), - (&StateEventType::RoomName, ""), - (&StateEventType::RoomAvatar, ""), - (&StateEventType::RoomMember, event.sender().as_str()), // Add recommended events - (&StateEventType::RoomEncryption, ""), - (&StateEventType::RoomTopic, ""), - ]; - - let fetches = cells.into_iter().map(|(event_type, state_key)| { - self.services - .state_accessor - .room_state_get(event.room_id(), event_type, state_key) - }); - - join_all(fetches) - .await - .into_iter() - .filter_map(Result::ok) - .map(Event::into_format) - .chain(once(event.to_format())) - .collect() - } - - /// Returns the room's version rules - #[inline] - pub async fn get_room_version_rules(&self, room_id: &RoomId) -> Result { - self.get_room_version(room_id) - .await - .and_then_ref(room_version::rules) - } - - /// Returns the room's version. - #[tracing::instrument( - level = "trace" - skip(self), - ret, - )] - pub async fn get_room_version(&self, room_id: &RoomId) -> Result { - self.services - .state_accessor - .room_state_get_content(room_id, &StateEventType::RoomCreate, "") - .await - .as_ref() - .map(room_version::from_create_content) - .cloned() - .map_err(|e| err!(Request(NotFound("No create event found: {e:?}")))) - } - - #[tracing::instrument( - level = "trace" - skip(self), - ret, - )] - pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { - self.db - .roomid_shortstatehash - .get(room_id) - .await - .deserialized() - } - - pub fn get_forward_extremities<'a>( - &'a self, - room_id: &'a RoomId, - ) -> impl Stream + Send + '_ { - let prefix = (room_id, Interfix); - - self.db - .roomid_pduleaves - .keys_prefix(&prefix) - .map_ok(|(_, event_id): (Ignore, &EventId)| event_id) - .ignore_err() - } - - #[tracing::instrument( - level = "debug" - skip_all, - fields(%room_id), - )] - pub async fn set_forward_extremities<'a, I>( - &'a self, - room_id: &'a RoomId, - event_ids: I, - _state_lock: &'a RoomMutexGuard, - ) where - I: Iterator + Send + 'a, - { - let prefix = (room_id, Interfix); - self.db - .roomid_pduleaves - .keys_prefix_raw(&prefix) - .ignore_err() - .ready_for_each(|key| self.db.roomid_pduleaves.remove(key)) - .await; - - for event_id in event_ids { - let key = (room_id, event_id); - self.db.roomid_pduleaves.put_raw(key, event_id); - } - } - - pub(super) async fn delete_all_rooms_forward_extremities(&self, room_id: &RoomId) -> Result { - let prefix = (room_id, Interfix); - - self.db - .roomid_pduleaves - .keys_prefix_raw(&prefix) - .ignore_err() - .ready_for_each(|key| { - trace!("Removing key: {key:?}"); - self.db.roomid_pduleaves.remove(key); - }) - .await; - - Ok(()) - } - - pub(super) async fn delete_room_shortstatehash( - &self, - room_id: &RoomId, - _mutex_lock: &Guard, - ) -> Result { - self.db.roomid_shortstatehash.remove(room_id); - - Ok(()) + Ok(*shortstatehash) + }, + | _ => Ok(previous_shortstatehash.expect("first event in room must be a state event")), } } + +/// Set the state hash to a new version, but does not update state_cache. +#[implement(Service)] +#[tracing::instrument(skip(self, _mutex_lock), level = "debug")] +pub fn set_room_state( + &self, + room_id: &RoomId, + shortstatehash: u64, + // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &RoomMutexGuard, +) { + const BUFSIZE: usize = size_of::(); + + self.db + .roomid_shortstatehash + .raw_aput::(room_id, shortstatehash); +} + +/// This fetches auth events from the current state. +#[implement(Service)] +#[allow(clippy::too_many_arguments)] +#[tracing::instrument(skip(self, content), level = "debug")] +pub async fn get_auth_events( + &self, + room_id: &RoomId, + kind: &TimelineEventType, + sender: &UserId, + state_key: Option<&str>, + content: &serde_json::value::RawValue, + auth_rules: &AuthorizationRules, + include_create: bool, +) -> Result> +where + StateEventType: Send + Sync, + StateKey: Send + Sync, +{ + let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else { + return Ok(StateMap::new()); + }; + + let sauthevents: HashMap = + auth_types_for_event(kind, sender, state_key, content, auth_rules, include_create)? + .into_iter() + .stream() + .broad_filter_map(async |(event_type, state_key): TypeStateKey| { + self.services + .short + .get_shortstatekey(&event_type, &state_key) + .await + .map(move |sstatekey| (sstatekey, (event_type, state_key))) + .ok() + }) + .collect() + .await; + + self.services + .state_accessor + .state_full_shortids(shortstatehash) + .ready_filter_map(Result::ok) + .ready_filter_map(|(shortstatekey, shorteventid)| { + sauthevents + .get(&shortstatekey) + .map(move |(ty, sk)| ((ty, sk), shorteventid)) + }) + .unzip() + .map(|(state_keys, event_ids): (Vec<_>, Vec<_>)| { + self.services + .short + .multi_get_eventid_from_short(event_ids.into_iter().stream()) + .zip(state_keys.into_iter().stream()) + }) + .flatten_stream() + .ready_filter_map(|(event_id, (ty, sk))| Some(((ty, sk), event_id.ok()?))) + .broad_filter_map(async |((ty, sk), event_id): ((&_, &_), OwnedEventId)| { + let pdu = self.services.timeline.get_pdu(&event_id).await; + + Some(((ty.clone(), sk.clone()), pdu.ok()?)) + }) + .collect() + .map(Ok) + .await +} + +#[implement(Service)] +#[tracing::instrument(skip_all, level = "debug")] +pub async fn summary_stripped(&self, event: &Pdu) -> Vec> { + let cells = [ + (&StateEventType::RoomCreate, ""), + (&StateEventType::RoomJoinRules, ""), + (&StateEventType::RoomCanonicalAlias, ""), + (&StateEventType::RoomName, ""), + (&StateEventType::RoomAvatar, ""), + (&StateEventType::RoomMember, event.sender().as_str()), // Add recommended events + (&StateEventType::RoomEncryption, ""), + (&StateEventType::RoomTopic, ""), + ]; + + let fetches = cells.into_iter().map(|(event_type, state_key)| { + self.services + .state_accessor + .room_state_get(event.room_id(), event_type, state_key) + }); + + join_all(fetches) + .await + .into_iter() + .filter_map(Result::ok) + .map(Event::into_format) + .chain(once(event.to_format())) + .collect() +} + +/// Returns the room's version rules +#[implement(Service)] +#[inline] +pub async fn get_room_version_rules(&self, room_id: &RoomId) -> Result { + self.get_room_version(room_id) + .await + .and_then_ref(room_version::rules) +} + +/// Returns the room's version. +#[implement(Service)] +#[tracing::instrument( + level = "debug" + skip(self), + ret, +)] +pub async fn get_room_version(&self, room_id: &RoomId) -> Result { + self.services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomCreate, "") + .await + .as_ref() + .map(room_version::from_create_content) + .cloned() + .map_err(|e| err!(Request(NotFound("No create event found: {e:?}")))) +} + +#[implement(Service)] +#[tracing::instrument( + level = "debug" + skip(self), + ret, +)] +pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result { + self.db + .roomid_shortstatehash + .get(room_id) + .await + .deserialized() +} + +#[implement(Service)] +#[tracing::instrument( + level = "trace" + skip(self), +)] +pub fn get_forward_extremities<'a>( + &'a self, + room_id: &'a RoomId, +) -> impl Stream + Send + '_ { + let prefix = (room_id, Interfix); + + self.db + .roomid_pduleaves + .keys_prefix(&prefix) + .map_ok(|(_, event_id): (Ignore, &EventId)| event_id) + .ignore_err() +} + +#[implement(Service)] +#[tracing::instrument( + level = "debug" + skip_all, + fields(%room_id), +)] +pub async fn set_forward_extremities<'a, I>( + &'a self, + room_id: &'a RoomId, + event_ids: I, + _state_lock: &'a RoomMutexGuard, +) where + I: Iterator + Send + 'a, +{ + let prefix = (room_id, Interfix); + self.db + .roomid_pduleaves + .keys_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.roomid_pduleaves.remove(key)) + .await; + + for event_id in event_ids { + let key = (room_id, event_id); + self.db.roomid_pduleaves.put_raw(key, event_id); + } +} + +#[implement(Service)] +pub(super) async fn delete_all_rooms_forward_extremities(&self, room_id: &RoomId) -> Result { + let prefix = (room_id, Interfix); + + self.db + .roomid_pduleaves + .keys_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|key| { + trace!("Removing key: {key:?}"); + self.db.roomid_pduleaves.remove(key); + }) + .await; + + Ok(()) +} + +#[implement(Service)] +pub(super) async fn delete_room_shortstatehash( + &self, + room_id: &RoomId, + _mutex_lock: &Guard, +) -> Result { + self.db.roomid_shortstatehash.remove(room_id); + + Ok(()) +}