diff --git a/src/api/client/context.rs b/src/api/client/context.rs index e2bdc426..7e2ffdd4 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -134,7 +134,7 @@ pub(crate) async fn get_context_route( .map_or_else(|| body.event_id.as_ref(), |pdu| pdu.event_id.as_ref()); let state_ids = services - .state_accessor + .state .pdu_shortstatehash(state_at) .or_else(|_| services.state.get_room_shortstatehash(room_id)) .map_ok(|shortstatehash| { diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index e1364529..0e1a27f2 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -580,7 +580,7 @@ async fn handle_left_room( }; let Ok(left_shortstatehash) = services - .state_accessor + .state .pdu_shortstatehash(&left_event_id) .await else { @@ -703,16 +703,12 @@ async fn load_joined_room( .iter() .map(at!(0)) .map(PduCount::into_unsigned) - .map(|shorteventid| { - services - .state_accessor - .get_shortstatehash(shorteventid) - }) + .map(|shorteventid| services.state.get_shortstatehash(shorteventid)) .next() .into(); let current_shortstatehash = services - .state_accessor + .state .get_shortstatehash(last_timeline_count.into_unsigned()) .or_else(|_| services.state.get_room_shortstatehash(room_id)); diff --git a/src/api/server/state.rs b/src/api/server/state.rs index bd4061cb..d7ba63bc 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -34,7 +34,7 @@ pub(crate) async fn get_room_state_route( .await?; let shortstatehash = services - .state_accessor + .state .pdu_shortstatehash(&body.event_id) .await .map_err(|_| err!(Request(NotFound("PDU state not found."))))?; diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 4c579435..093ad02a 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -26,7 +26,7 @@ pub(crate) async fn get_room_state_ids_route( .await?; let shortstatehash = services - .state_accessor + .state .pdu_shortstatehash(&body.event_id) .await .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index 0ed2d946..9c854095 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -40,7 +40,7 @@ where let Ok(prev_event_sstatehash) = self .services - .state_accessor + .state .pdu_shortstatehash(prev_event_id) .await else { @@ -105,7 +105,7 @@ where let sstatehash = self .services - .state_accessor + .state .pdu_shortstatehash(prev_event_id); try_join(sstatehash, prev_event) diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 07129c29..a28cb887 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,7 +1,9 @@ use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc}; use async_trait::async_trait; -use futures::{FutureExt, Stream, StreamExt, TryStreamExt, future::join_all, pin_mut}; +use futures::{ + FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future::join_all, pin_mut, +}; use ruma::{ EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId, events::{ @@ -490,6 +492,44 @@ pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result Result { + self.services + .short + .get_shorteventid(event_id) + .and_then(|shorteventid| self.get_shortstatehash(shorteventid)) + .await +} + +/// Returns the state hash at this event. +#[implement(Service)] +#[tracing::instrument( + level = "debug" + skip(self), + ret, +)] +pub async fn get_shortstatehash(&self, shorteventid: ShortEventId) -> Result { + const BUFSIZE: usize = size_of::(); + + self.db + .shorteventid_shortstatehash + .aqry::(&shorteventid) + .await + .deserialized() +} + +#[implement(Service)] +pub(super) async fn delete_room_shortstatehash( + &self, + room_id: &RoomId, + _mutex_lock: &Guard, +) -> Result { + self.db.roomid_shortstatehash.remove(room_id); + + Ok(()) +} + #[implement(Service)] #[tracing::instrument( level = "trace" @@ -552,14 +592,3 @@ pub(super) async fn delete_all_rooms_forward_extremities(&self, room_id: &RoomId Ok(()) } - -#[implement(Service)] -pub(super) async fn delete_room_shortstatehash( - &self, - room_id: &RoomId, - _mutex_lock: &Guard, -) -> Result { - self.db.roomid_shortstatehash.remove(room_id); - - Ok(()) -} diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 8c15c451..86f8c2b4 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -31,26 +31,15 @@ use tuwunel_core::{ Result, err, matrix::{Event, room_version, state_res::events::RoomCreateEvent}, }; -use tuwunel_database::Map; pub struct Service { services: Arc, - db: Data, -} - -struct Data { - shorteventid_shortstatehash: Arc, } #[async_trait] impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { - Ok(Arc::new(Self { - services: args.services.clone(), - db: Data { - shorteventid_shortstatehash: args.db["shorteventid_shortstatehash"].clone(), - }, - })) + Ok(Arc::new(Self { services: args.services.clone() })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } diff --git a/src/service/rooms/state_accessor/server_can.rs b/src/service/rooms/state_accessor/server_can.rs index 224018b0..3c4d23cb 100644 --- a/src/service/rooms/state_accessor/server_can.rs +++ b/src/service/rooms/state_accessor/server_can.rs @@ -18,7 +18,12 @@ pub async fn server_can_see_event( room_id: &RoomId, event_id: &EventId, ) -> bool { - let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { + let Ok(shortstatehash) = self + .services + .state + .pdu_shortstatehash(event_id) + .await + else { return true; }; diff --git a/src/service/rooms/state_accessor/state.rs b/src/service/rooms/state_accessor/state.rs index f6c099f0..95dbd615 100644 --- a/src/service/rooms/state_accessor/state.rs +++ b/src/service/rooms/state_accessor/state.rs @@ -18,7 +18,6 @@ use tuwunel_core::{ stream::{BroadbandExt, IterStream, ReadyExt, TryIgnore}, }, }; -use tuwunel_database::Deserialized; use crate::rooms::{ short::{ShortEventId, ShortStateHash, ShortStateKey}, @@ -421,25 +420,3 @@ async fn load_full_state(&self, shortstatehash: ShortStateHash) -> Result Result { - self.services - .short - .get_shorteventid(event_id) - .and_then(|shorteventid| self.get_shortstatehash(shorteventid)) - .await -} - -/// Returns the state hash at this event. -#[implement(super::Service)] -pub async fn get_shortstatehash(&self, shorteventid: ShortEventId) -> Result { - const BUFSIZE: usize = size_of::(); - - self.db - .shorteventid_shortstatehash - .aqry::(&shorteventid) - .await - .deserialized() -} diff --git a/src/service/rooms/state_accessor/user_can.rs b/src/service/rooms/state_accessor/user_can.rs index 5d574562..a8f77bfc 100644 --- a/src/service/rooms/state_accessor/user_can.rs +++ b/src/service/rooms/state_accessor/user_can.rs @@ -83,7 +83,12 @@ pub async fn user_can_see_event( room_id: &RoomId, event_id: &EventId, ) -> bool { - let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { + let Ok(shortstatehash) = self + .services + .state + .pdu_shortstatehash(event_id) + .await + else { return true; }; diff --git a/src/service/rooms/timeline/append.rs b/src/service/rooms/timeline/append.rs index 43001018..42a51c3e 100644 --- a/src/service/rooms/timeline/append.rs +++ b/src/service/rooms/timeline/append.rs @@ -117,7 +117,7 @@ where { if let Ok(shortstatehash) = self .services - .state_accessor + .state .pdu_shortstatehash(pdu.event_id()) .await { diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index f48e41a8..6b4c344c 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -179,7 +179,7 @@ pub async fn prev_shortstatehash( let prev = self.prev_timeline_count(room_id, before).await?; self.services - .state_accessor + .state .get_shortstatehash(prev.into_unsigned()) .await } @@ -197,7 +197,7 @@ pub async fn next_shortstatehash( let next = self.next_timeline_count(room_id, after).await?; self.services - .state_accessor + .state .get_shortstatehash(next.into_unsigned()) .await }