diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs index c06a563f..c37d2ade 100644 --- a/src/service/sync/mod.rs +++ b/src/service/sync/mod.rs @@ -2,6 +2,7 @@ mod watch; use std::{ collections::BTreeMap, + ops::Bound::Included, sync::{Arc, Mutex as StdMutex}, }; @@ -13,7 +14,7 @@ use ruma::{ }, }; use tokio::sync::Mutex as TokioMutex; -use tuwunel_core::{Result, err, implement, is_equal_to}; +use tuwunel_core::{Result, at, err, implement, is_equal_to, smallvec::SmallVec}; use tuwunel_database::Map; pub struct Service { @@ -241,16 +242,6 @@ pub fn drop_connection(&self, key: &ConnectionKey) { .remove(key); } -#[implement(Service)] -pub fn list_connections(&self) -> Vec { - self.connections - .lock() - .expect("locked") - .keys() - .cloned() - .collect() -} - #[implement(Service)] pub fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal { self.connections @@ -262,6 +253,38 @@ pub fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal { .clone() } +#[implement(Service)] +pub fn device_connections( + &self, + user_id: &UserId, + device_id: &DeviceId, + exclude: Option<&ConnectionId>, +) -> impl Iterator + Send { + type Siblings = SmallVec<[ConnectionVal; 4]>; + + let key = into_connection_key(user_id, device_id, None::); + + self.connections + .lock() + .expect("locked") + .range((Included(&key), Included(&key))) + .filter(|((_, _, id), _)| id.as_ref() != exclude) + .map(at!(1)) + .cloned() + .collect::() + .into_iter() +} + +#[implement(Service)] +pub fn list_connections(&self) -> Vec { + self.connections + .lock() + .expect("locked") + .keys() + .cloned() + .collect() +} + #[implement(Service)] pub fn find_connection(&self, key: &ConnectionKey) -> Result { self.connections