Persist sliding-sync state; mitigate initial-sync.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
mod watch;
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
ops::Bound::Included,
|
||||
sync::{Arc, Mutex as StdMutex},
|
||||
collections::{BTreeMap, btree_map::Entry},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use futures::{FutureExt, Stream};
|
||||
use ruma::{
|
||||
DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId,
|
||||
api::client::sync::sync_events::v5::{
|
||||
@@ -13,9 +13,10 @@ use ruma::{
|
||||
request::{AccountData, E2EE, Receipts, ToDevice, Typing},
|
||||
},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
use tuwunel_core::{Result, at, err, implement, is_equal_to, smallvec::SmallVec};
|
||||
use tuwunel_database::Map;
|
||||
use tuwunel_core::{Result, at, debug, err, implement, is_equal_to, utils::stream::TryIgnore};
|
||||
use tuwunel_database::{Cbor, Deserialized, Map};
|
||||
|
||||
pub struct Service {
|
||||
services: Arc<crate::services::OnceServices>,
|
||||
@@ -23,7 +24,8 @@ pub struct Service {
|
||||
db: Data,
|
||||
}
|
||||
|
||||
pub struct Data {
|
||||
struct Data {
|
||||
userdeviceconnid_conn: Arc<Map>,
|
||||
todeviceid_events: Arc<Map>,
|
||||
userroomid_joined: Arc<Map>,
|
||||
userroomid_invitestate: Arc<Map>,
|
||||
@@ -40,22 +42,22 @@ pub struct Data {
|
||||
roomuserid_lastnotificationread: Arc<Map>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Connection {
|
||||
pub lists: Lists,
|
||||
pub rooms: Rooms,
|
||||
pub subscriptions: Subscriptions,
|
||||
pub extensions: request::Extensions,
|
||||
pub globalsince: u64,
|
||||
pub next_batch: u64,
|
||||
pub lists: Lists,
|
||||
pub extensions: request::Extensions,
|
||||
pub subscriptions: Subscriptions,
|
||||
pub rooms: Rooms,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Room {
|
||||
pub roomsince: u64,
|
||||
}
|
||||
|
||||
type Connections = StdMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
|
||||
type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
|
||||
pub type ConnectionVal = Arc<TokioMutex<Connection>>;
|
||||
pub type ConnectionKey = (OwnedUserId, OwnedDeviceId, Option<ConnectionId>);
|
||||
|
||||
@@ -67,6 +69,7 @@ impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(Self {
|
||||
db: Data {
|
||||
userdeviceconnid_conn: args.db["userdeviceconnid_conn"].clone(),
|
||||
todeviceid_events: args.db["todeviceid_events"].clone(),
|
||||
userroomid_joined: args.db["userroomid_joined"].clone(),
|
||||
userroomid_invitestate: args.db["userroomid_invitestate"].clone(),
|
||||
@@ -91,7 +94,153 @@ impl crate::Service for Service {
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn clear_connections(
|
||||
&self,
|
||||
user_id: Option<&UserId>,
|
||||
device_id: Option<&DeviceId>,
|
||||
conn_id: Option<&ConnectionId>,
|
||||
) {
|
||||
self.connections
|
||||
.lock()
|
||||
.await
|
||||
.retain(|(conn_user_id, conn_device_id, conn_conn_id), _| {
|
||||
let retain = user_id.is_none_or(is_equal_to!(conn_user_id))
|
||||
&& device_id.is_none_or(is_equal_to!(conn_device_id))
|
||||
&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref());
|
||||
|
||||
if !retain {
|
||||
self.db
|
||||
.userdeviceconnid_conn
|
||||
.del((conn_user_id, conn_device_id, conn_conn_id));
|
||||
}
|
||||
|
||||
retain
|
||||
});
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn drop_connection(&self, key: &ConnectionKey) {
|
||||
let mut cache = self.connections.lock().await;
|
||||
|
||||
self.db.userdeviceconnid_conn.del(key);
|
||||
cache.remove(key);
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
|
||||
let mut cache = self.connections.lock().await;
|
||||
|
||||
match cache.entry(key.clone()) {
|
||||
| Entry::Occupied(val) => val.get().clone(),
|
||||
| Entry::Vacant(val) => {
|
||||
let conn = self
|
||||
.db
|
||||
.userdeviceconnid_conn
|
||||
.qry(key)
|
||||
.boxed()
|
||||
.await
|
||||
.deserialized::<Cbor<_>>()
|
||||
.map(at!(0))
|
||||
.map(TokioMutex::new)
|
||||
.map(Arc::new)
|
||||
.unwrap_or_default();
|
||||
|
||||
val.insert(conn).clone()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn load_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
|
||||
let mut cache = self.connections.lock().await;
|
||||
|
||||
match cache.entry(key.clone()) {
|
||||
| Entry::Occupied(val) => Ok(val.get().clone()),
|
||||
| Entry::Vacant(val) => self
|
||||
.db
|
||||
.userdeviceconnid_conn
|
||||
.qry(key)
|
||||
.await
|
||||
.deserialized::<Cbor<_>>()
|
||||
.map(at!(0))
|
||||
.map(TokioMutex::new)
|
||||
.map(Arc::new)
|
||||
.map(|conn| val.insert(conn).clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn get_loaded_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
|
||||
self.connections
|
||||
.lock()
|
||||
.await
|
||||
.get(key)
|
||||
.cloned()
|
||||
.ok_or_else(|| err!(Request(NotFound("Connection not found."))))
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn list_loaded_connections(&self) -> Vec<ConnectionKey> {
|
||||
self.connections
|
||||
.lock()
|
||||
.await
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn list_stored_connections(&self) -> impl Stream<Item = ConnectionKey> {
|
||||
self.db.userdeviceconnid_conn.keys().ignore_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn is_connection_loaded(&self, key: &ConnectionKey) -> bool {
|
||||
self.connections.lock().await.contains_key(key)
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool {
|
||||
self.db.userdeviceconnid_conn.contains(key).await
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_connection_key<U, D, C>(user_id: U, device_id: D, conn_id: Option<C>) -> ConnectionKey
|
||||
where
|
||||
U: Into<OwnedUserId>,
|
||||
D: Into<OwnedDeviceId>,
|
||||
C: Into<ConnectionId>,
|
||||
{
|
||||
(user_id.into(), device_id.into(), conn_id.map(Into::into))
|
||||
}
|
||||
|
||||
#[implement(Connection)]
|
||||
#[tracing::instrument(level = "debug", skip(self, service))]
|
||||
pub fn store(&self, service: &Service, key: &ConnectionKey) {
|
||||
service
|
||||
.db
|
||||
.userdeviceconnid_conn
|
||||
.put(key, Cbor(self));
|
||||
|
||||
debug!(
|
||||
since = %self.globalsince,
|
||||
next_batch = %self.next_batch,
|
||||
"Persisted connection state"
|
||||
);
|
||||
}
|
||||
|
||||
#[implement(Connection)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub fn update_rooms_prologue(&mut self, retard_since: Option<u64>) {
|
||||
self.rooms.values_mut().for_each(|room| {
|
||||
if let Some(retard_since) = retard_since {
|
||||
@@ -103,6 +252,7 @@ pub fn update_rooms_prologue(&mut self, retard_since: Option<u64>) {
|
||||
}
|
||||
|
||||
#[implement(Connection)]
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
pub fn update_rooms_epilogue<'a, Rooms>(&mut self, window: Rooms)
|
||||
where
|
||||
Rooms: Iterator<Item = &'a RoomId> + Send + 'a,
|
||||
@@ -115,6 +265,7 @@ where
|
||||
}
|
||||
|
||||
#[implement(Connection)]
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
pub fn update_cache(&mut self, request: &Request) {
|
||||
Self::update_cache_lists(request, self);
|
||||
Self::update_cache_subscriptions(request, self);
|
||||
@@ -215,98 +366,3 @@ fn some_or_sticky<T: Clone>(target: Option<&T>, cached: &mut Option<T>) {
|
||||
cached.replace(target.clone());
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn clear_connections(
|
||||
&self,
|
||||
user_id: Option<&UserId>,
|
||||
device_id: Option<&DeviceId>,
|
||||
conn_id: Option<&ConnectionId>,
|
||||
) {
|
||||
self.connections.lock().expect("locked").retain(
|
||||
|(conn_user_id, conn_device_id, conn_conn_id), _| {
|
||||
!(user_id.is_none_or(is_equal_to!(conn_user_id))
|
||||
&& device_id.is_none_or(is_equal_to!(conn_device_id))
|
||||
&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref()))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn drop_connection(&self, key: &ConnectionKey) {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.remove(key);
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.entry(key.clone())
|
||||
.and_modify(|existing| *existing = ConnectionVal::default())
|
||||
.or_default()
|
||||
.clone()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn device_connections(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
exclude: Option<&ConnectionId>,
|
||||
) -> impl Iterator<Item = ConnectionVal> + Send {
|
||||
type Siblings = SmallVec<[ConnectionVal; 4]>;
|
||||
|
||||
let key = into_connection_key(user_id, device_id, None::<ConnectionId>);
|
||||
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.range((Included(&key), Included(&key)))
|
||||
.filter(|((_, _, id), _)| id.as_ref() != exclude)
|
||||
.map(at!(1))
|
||||
.cloned()
|
||||
.collect::<Siblings>()
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn list_connections(&self) -> Vec<ConnectionKey> {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn find_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.get(key)
|
||||
.cloned()
|
||||
.ok_or_else(|| err!(Request(NotFound("Connection not found."))))
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn contains_connection(&self, key: &ConnectionKey) -> bool {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.contains_key(key)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_connection_key<U, D, C>(user_id: U, device_id: D, conn_id: Option<C>) -> ConnectionKey
|
||||
where
|
||||
U: Into<OwnedUserId>,
|
||||
D: Into<OwnedDeviceId>,
|
||||
C: Into<ConnectionId>,
|
||||
{
|
||||
(user_id.into(), device_id.into(), conn_id.map(Into::into))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user