diff --git a/src/admin/query/room_state_cache.rs b/src/admin/query/room_state_cache.rs index 6a0196e9..ce8c5540 100644 --- a/src/admin/query/room_state_cache.rs +++ b/src/admin/query/room_state_cache.rs @@ -74,6 +74,10 @@ pub(crate) enum RoomStateCacheCommand { user_id: OwnedUserId, room_id: OwnedRoomId, }, + + UserMemberships { + user_id: OwnedUserId, + }, } pub(super) async fn process(subcommand: RoomStateCacheCommand, context: &Context<'_>) -> Result { @@ -316,6 +320,22 @@ pub(super) async fn process(subcommand: RoomStateCacheCommand, context: &Context .await; let query_time = timer.elapsed(); + context + .write_str(&format!( + "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" + )) + .await + }, + | RoomStateCacheCommand::UserMemberships { user_id } => { + let timer = tokio::time::Instant::now(); + let results = services + .state_cache + .all_user_memberships(&user_id) + .map(|(membership, room_id)| (membership, room_id.to_owned())) + .collect::>() + .await; + let query_time = timer.elapsed(); + context .write_str(&format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 16893f76..83445a34 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -6,7 +6,11 @@ use std::{ sync::{Arc, RwLock}, }; -use futures::{Stream, StreamExt, future::join5, pin_mut}; +use futures::{ + Stream, StreamExt, + future::{OptionFuture, join5}, + pin_mut, +}; use ruma::{ OwnedRoomId, RoomId, ServerName, UserId, events::{AnyStrippedStateEvent, AnySyncStateEvent, room::member::MembershipState}, @@ -16,7 +20,10 @@ use tuwunel_core::{ Result, implement, result::LogErr, trace, - utils::stream::{BroadbandExt, ReadyExt, TryIgnore}, + utils::{ + future::OptionStream, + stream::{BroadbandExt, ReadyExt, TryIgnore}, + }, warn, }; use tuwunel_database::{Deserialized, Ignore, Interfix, Map}; @@ -382,6 +389,71 @@ pub async fn get_joined_count(&self, room_id: &RoomId, user_id: &UserId) -> Resu .deserialized() } +/// Returns an iterator over all memberships for a user. +#[implement(Service)] +#[inline] +pub fn all_user_memberships<'a>( + &'a self, + user_id: &'a UserId, +) -> impl Stream + Send + 'a { + use MembershipState::*; + + self.user_memberships(user_id, &[Join, Invite, Knock, Leave]) +} + +/// Returns an iterator over all specified memberships for a user. +#[implement(Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn user_memberships<'a>( + &'a self, + user_id: &'a UserId, + filter: &[MembershipState], +) -> impl Stream + Send + 'a { + use MembershipState::*; + use futures::stream::select; + + let joined: OptionFuture<_> = filter + .contains(&Join) + .then(|| { + self.rooms_joined(user_id) + .map(|room_id| (Join, room_id)) + .into_future() + }) + .into(); + + let invited: OptionFuture<_> = filter + .contains(&Invite) + .then(|| { + self.rooms_invited(user_id) + .map(|room_id| (Invite, room_id)) + .into_future() + }) + .into(); + + let knocked: OptionFuture<_> = filter + .contains(&Knock) + .then(|| { + self.rooms_knocked(user_id) + .map(|room_id| (Knock, room_id)) + .into_future() + }) + .into(); + + let left: OptionFuture<_> = filter + .contains(&Leave) + .then(|| { + self.rooms_left(user_id) + .map(|room_id| (Leave, room_id)) + .into_future() + }) + .into(); + + select( + select(joined.stream(), left.stream()), + select(invited.stream(), knocked.stream()), + ) +} + /// Returns an iterator over all rooms this user joined. #[implement(Service)] #[tracing::instrument(skip(self), level = "debug")]