diff --git a/src/core/utils/set.rs b/src/core/utils/set.rs index d3485bb8..4b2e64a7 100644 --- a/src/core/utils/set.rs +++ b/src/core/utils/set.rs @@ -1,12 +1,17 @@ use std::{ cmp::{Eq, Ord}, + convert::identity, pin::Pin, sync::Arc, }; -use futures::{Stream, StreamExt}; +use futures::{ + Stream, StreamExt, + stream::{Peekable, unfold}, +}; +use tokio::sync::Mutex; -use crate::{is_equal_to, is_less_than}; +use crate::{is_equal_to, is_less_than, utils::stream::ReadyExt}; /// Intersection of sets /// @@ -57,30 +62,31 @@ where /// Intersection of sets /// /// Outputs the set of elements common to both streams. Streams must be sorted. -pub fn intersection_sorted_stream2(a: A, b: B) -> impl Stream + Send +pub fn intersection_sorted_stream2(a: S, b: S) -> impl Stream + Send where - A: Stream + Send, - B: Stream + Send + Unpin, + S: Stream + Send + Unpin, Item: Eq + PartialOrd + Send + Sync, { - use tokio::sync::Mutex; + struct State { + a: S, + b: Peekable, + } - let b = Arc::new(Mutex::new(b.peekable())); - a.map(move |ai| (ai, b.clone())) - .filter_map(async move |(ai, b)| { - let mut lock = b.lock().await; - while let Some(bi) = Pin::new(&mut *lock) - .next_if(|bi| *bi <= ai) - .await - .as_ref() - { - if ai == *bi { - return Some(ai); - } + unfold(State { a, b: b.peekable() }, async |mut state| { + let ai = state.a.next().await?; + while let Some(bi) = Pin::new(&mut state.b) + .next_if(|bi| *bi <= ai) + .await + .as_ref() + { + if ai == *bi { + return Some((Some(ai), state)); } + } - None - }) + Some((None, state)) + }) + .ready_filter_map(identity) } /// Difference of sets @@ -93,8 +99,6 @@ where B: Stream + Send + Unpin, Item: Eq + PartialOrd + Send + Sync, { - use tokio::sync::Mutex; - let b = Arc::new(Mutex::new(b.peekable())); a.map(move |ai| (ai, b.clone())) .filter_map(async move |(ai, b)| { diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index b63e678c..2f091bb8 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -17,7 +17,7 @@ use tuwunel_core::{ result::LogErr, trace, utils::{ - BoolExt, + self, BoolExt, future::OptionStream, stream::{BroadbandExt, ReadyExt, TryIgnore}, }, @@ -207,12 +207,10 @@ pub fn get_shared_rooms<'a>( user_a: &'a UserId, user_b: &'a UserId, ) -> impl Stream + Send + 'a { - use tuwunel_core::utils::set; - - let a = self.rooms_joined(user_a); + let a = self.rooms_joined(user_a).boxed(); let b = self.rooms_joined(user_b).boxed(); - set::intersection_sorted_stream2(a, b) + utils::set::intersection_sorted_stream2(a, b) } /// Returns an iterator of all joined members of a room.