Persist sliding-sync state; mitigate initial-sync.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -29,7 +29,7 @@ pub(crate) enum SyncCommand {
|
||||
|
||||
#[admin_command]
|
||||
pub(super) async fn list_connections(&self) -> Result {
|
||||
let connections = self.services.sync.list_connections();
|
||||
let connections = self.services.sync.list_loaded_connections().await;
|
||||
|
||||
for connection_key in connections {
|
||||
self.write_str(&format!("{connection_key:?}\n"))
|
||||
@@ -47,7 +47,11 @@ pub(super) async fn show_connection(
|
||||
conn_id: Option<String>,
|
||||
) -> Result {
|
||||
let key = into_connection_key(user_id, device_id, conn_id);
|
||||
let cache = self.services.sync.find_connection(&key)?;
|
||||
let cache = self
|
||||
.services
|
||||
.sync
|
||||
.get_loaded_connection(&key)
|
||||
.await?;
|
||||
|
||||
let out;
|
||||
{
|
||||
@@ -65,11 +69,14 @@ pub(super) async fn drop_connections(
|
||||
device_id: Option<OwnedDeviceId>,
|
||||
conn_id: Option<String>,
|
||||
) -> Result {
|
||||
self.services.sync.clear_connections(
|
||||
user_id.as_deref(),
|
||||
device_id.as_deref(),
|
||||
conn_id.map(Into::into).as_ref(),
|
||||
);
|
||||
self.services
|
||||
.sync
|
||||
.clear_connections(
|
||||
user_id.as_deref(),
|
||||
device_id.as_deref(),
|
||||
conn_id.map(Into::into).as_ref(),
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ use tokio::time::{Instant, timeout_at};
|
||||
use tuwunel_core::{
|
||||
Err, Result, apply, at, debug,
|
||||
debug::INFO_SPAN_LEVEL,
|
||||
err,
|
||||
debug_warn,
|
||||
error::inspect_log,
|
||||
extract_variant,
|
||||
smallvec::SmallVec,
|
||||
@@ -115,11 +115,10 @@ pub(crate) async fn sync_events_v5_route(
|
||||
.unwrap_or(0);
|
||||
|
||||
let conn_key = into_connection_key(sender_user, sender_device, request.conn_id.as_deref());
|
||||
let conn_val = since
|
||||
.ne(&0)
|
||||
.then(|| services.sync.find_connection(&conn_key))
|
||||
.unwrap_or_else(|| Ok(services.sync.init_connection(&conn_key)))
|
||||
.map_err(|_| err!(Request(UnknownPos("Connection lost; restarting sync stream."))))?;
|
||||
let conn_val = services
|
||||
.sync
|
||||
.load_or_init_connection(&conn_key)
|
||||
.await;
|
||||
|
||||
let conn = conn_val.lock();
|
||||
let ping_presence = services
|
||||
@@ -130,10 +129,22 @@ pub(crate) async fn sync_events_v5_route(
|
||||
|
||||
let (mut conn, _) = join(conn, ping_presence).await;
|
||||
|
||||
if since != 0 && conn.next_batch == 0 {
|
||||
return Err!(Request(UnknownPos(warn!("Connection lost; restarting sync stream."))));
|
||||
}
|
||||
|
||||
if since == 0 {
|
||||
*conn = Connection::default();
|
||||
conn.store(&services.sync, &conn_key);
|
||||
debug_warn!(?conn_key, "Client cleared cache and reloaded.");
|
||||
}
|
||||
|
||||
let advancing = since == conn.next_batch;
|
||||
let retarding = since <= conn.globalsince;
|
||||
let retarding = since != 0 && since <= conn.globalsince;
|
||||
if !advancing && !retarding {
|
||||
return Err!(Request(UnknownPos("Requesting unknown or stale stream position.")));
|
||||
return Err!(Request(UnknownPos(warn!(
|
||||
"Requesting unknown or invalid stream position."
|
||||
))));
|
||||
}
|
||||
|
||||
debug_assert!(
|
||||
@@ -189,6 +200,7 @@ pub(crate) async fn sync_events_v5_route(
|
||||
if !is_empty_response(&response) {
|
||||
response.pos = conn.next_batch.to_string().into();
|
||||
trace!(conn.globalsince, conn.next_batch, "response {response:?}");
|
||||
conn.store(&services.sync, &conn_key);
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
@@ -202,6 +214,7 @@ pub(crate) async fn sync_events_v5_route(
|
||||
{
|
||||
response.pos = conn.next_batch.to_string().into();
|
||||
trace!(conn.globalsince, conn.next_batch, "timeout; empty response {response:?}");
|
||||
conn.store(&services.sync, &conn_key);
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
|
||||
@@ -341,6 +341,11 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "userdeviceid_metadata",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "userdeviceconnid_conn",
|
||||
block_size: 1024 * 16,
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "userdeviceid_refresh",
|
||||
..descriptor::RANDOM_SMALL
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
mod watch;
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
ops::Bound::Included,
|
||||
sync::{Arc, Mutex as StdMutex},
|
||||
collections::{BTreeMap, btree_map::Entry},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use futures::{FutureExt, Stream};
|
||||
use ruma::{
|
||||
DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId,
|
||||
api::client::sync::sync_events::v5::{
|
||||
@@ -13,9 +13,10 @@ use ruma::{
|
||||
request::{AccountData, E2EE, Receipts, ToDevice, Typing},
|
||||
},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
use tuwunel_core::{Result, at, err, implement, is_equal_to, smallvec::SmallVec};
|
||||
use tuwunel_database::Map;
|
||||
use tuwunel_core::{Result, at, debug, err, implement, is_equal_to, utils::stream::TryIgnore};
|
||||
use tuwunel_database::{Cbor, Deserialized, Map};
|
||||
|
||||
pub struct Service {
|
||||
services: Arc<crate::services::OnceServices>,
|
||||
@@ -23,7 +24,8 @@ pub struct Service {
|
||||
db: Data,
|
||||
}
|
||||
|
||||
pub struct Data {
|
||||
struct Data {
|
||||
userdeviceconnid_conn: Arc<Map>,
|
||||
todeviceid_events: Arc<Map>,
|
||||
userroomid_joined: Arc<Map>,
|
||||
userroomid_invitestate: Arc<Map>,
|
||||
@@ -40,22 +42,22 @@ pub struct Data {
|
||||
roomuserid_lastnotificationread: Arc<Map>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Connection {
|
||||
pub lists: Lists,
|
||||
pub rooms: Rooms,
|
||||
pub subscriptions: Subscriptions,
|
||||
pub extensions: request::Extensions,
|
||||
pub globalsince: u64,
|
||||
pub next_batch: u64,
|
||||
pub lists: Lists,
|
||||
pub extensions: request::Extensions,
|
||||
pub subscriptions: Subscriptions,
|
||||
pub rooms: Rooms,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Room {
|
||||
pub roomsince: u64,
|
||||
}
|
||||
|
||||
type Connections = StdMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
|
||||
type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
|
||||
pub type ConnectionVal = Arc<TokioMutex<Connection>>;
|
||||
pub type ConnectionKey = (OwnedUserId, OwnedDeviceId, Option<ConnectionId>);
|
||||
|
||||
@@ -67,6 +69,7 @@ impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(Self {
|
||||
db: Data {
|
||||
userdeviceconnid_conn: args.db["userdeviceconnid_conn"].clone(),
|
||||
todeviceid_events: args.db["todeviceid_events"].clone(),
|
||||
userroomid_joined: args.db["userroomid_joined"].clone(),
|
||||
userroomid_invitestate: args.db["userroomid_invitestate"].clone(),
|
||||
@@ -91,7 +94,153 @@ impl crate::Service for Service {
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn clear_connections(
|
||||
&self,
|
||||
user_id: Option<&UserId>,
|
||||
device_id: Option<&DeviceId>,
|
||||
conn_id: Option<&ConnectionId>,
|
||||
) {
|
||||
self.connections
|
||||
.lock()
|
||||
.await
|
||||
.retain(|(conn_user_id, conn_device_id, conn_conn_id), _| {
|
||||
let retain = user_id.is_none_or(is_equal_to!(conn_user_id))
|
||||
&& device_id.is_none_or(is_equal_to!(conn_device_id))
|
||||
&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref());
|
||||
|
||||
if !retain {
|
||||
self.db
|
||||
.userdeviceconnid_conn
|
||||
.del((conn_user_id, conn_device_id, conn_conn_id));
|
||||
}
|
||||
|
||||
retain
|
||||
});
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn drop_connection(&self, key: &ConnectionKey) {
|
||||
let mut cache = self.connections.lock().await;
|
||||
|
||||
self.db.userdeviceconnid_conn.del(key);
|
||||
cache.remove(key);
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
|
||||
let mut cache = self.connections.lock().await;
|
||||
|
||||
match cache.entry(key.clone()) {
|
||||
| Entry::Occupied(val) => val.get().clone(),
|
||||
| Entry::Vacant(val) => {
|
||||
let conn = self
|
||||
.db
|
||||
.userdeviceconnid_conn
|
||||
.qry(key)
|
||||
.boxed()
|
||||
.await
|
||||
.deserialized::<Cbor<_>>()
|
||||
.map(at!(0))
|
||||
.map(TokioMutex::new)
|
||||
.map(Arc::new)
|
||||
.unwrap_or_default();
|
||||
|
||||
val.insert(conn).clone()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn load_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
|
||||
let mut cache = self.connections.lock().await;
|
||||
|
||||
match cache.entry(key.clone()) {
|
||||
| Entry::Occupied(val) => Ok(val.get().clone()),
|
||||
| Entry::Vacant(val) => self
|
||||
.db
|
||||
.userdeviceconnid_conn
|
||||
.qry(key)
|
||||
.await
|
||||
.deserialized::<Cbor<_>>()
|
||||
.map(at!(0))
|
||||
.map(TokioMutex::new)
|
||||
.map(Arc::new)
|
||||
.map(|conn| val.insert(conn).clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn get_loaded_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
|
||||
self.connections
|
||||
.lock()
|
||||
.await
|
||||
.get(key)
|
||||
.cloned()
|
||||
.ok_or_else(|| err!(Request(NotFound("Connection not found."))))
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn list_loaded_connections(&self) -> Vec<ConnectionKey> {
|
||||
self.connections
|
||||
.lock()
|
||||
.await
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn list_stored_connections(&self) -> impl Stream<Item = ConnectionKey> {
|
||||
self.db.userdeviceconnid_conn.keys().ignore_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn is_connection_loaded(&self, key: &ConnectionKey) -> bool {
|
||||
self.connections.lock().await.contains_key(key)
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool {
|
||||
self.db.userdeviceconnid_conn.contains(key).await
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_connection_key<U, D, C>(user_id: U, device_id: D, conn_id: Option<C>) -> ConnectionKey
|
||||
where
|
||||
U: Into<OwnedUserId>,
|
||||
D: Into<OwnedDeviceId>,
|
||||
C: Into<ConnectionId>,
|
||||
{
|
||||
(user_id.into(), device_id.into(), conn_id.map(Into::into))
|
||||
}
|
||||
|
||||
#[implement(Connection)]
|
||||
#[tracing::instrument(level = "debug", skip(self, service))]
|
||||
pub fn store(&self, service: &Service, key: &ConnectionKey) {
|
||||
service
|
||||
.db
|
||||
.userdeviceconnid_conn
|
||||
.put(key, Cbor(self));
|
||||
|
||||
debug!(
|
||||
since = %self.globalsince,
|
||||
next_batch = %self.next_batch,
|
||||
"Persisted connection state"
|
||||
);
|
||||
}
|
||||
|
||||
#[implement(Connection)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub fn update_rooms_prologue(&mut self, retard_since: Option<u64>) {
|
||||
self.rooms.values_mut().for_each(|room| {
|
||||
if let Some(retard_since) = retard_since {
|
||||
@@ -103,6 +252,7 @@ pub fn update_rooms_prologue(&mut self, retard_since: Option<u64>) {
|
||||
}
|
||||
|
||||
#[implement(Connection)]
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
pub fn update_rooms_epilogue<'a, Rooms>(&mut self, window: Rooms)
|
||||
where
|
||||
Rooms: Iterator<Item = &'a RoomId> + Send + 'a,
|
||||
@@ -115,6 +265,7 @@ where
|
||||
}
|
||||
|
||||
#[implement(Connection)]
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
pub fn update_cache(&mut self, request: &Request) {
|
||||
Self::update_cache_lists(request, self);
|
||||
Self::update_cache_subscriptions(request, self);
|
||||
@@ -215,98 +366,3 @@ fn some_or_sticky<T: Clone>(target: Option<&T>, cached: &mut Option<T>) {
|
||||
cached.replace(target.clone());
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn clear_connections(
|
||||
&self,
|
||||
user_id: Option<&UserId>,
|
||||
device_id: Option<&DeviceId>,
|
||||
conn_id: Option<&ConnectionId>,
|
||||
) {
|
||||
self.connections.lock().expect("locked").retain(
|
||||
|(conn_user_id, conn_device_id, conn_conn_id), _| {
|
||||
!(user_id.is_none_or(is_equal_to!(conn_user_id))
|
||||
&& device_id.is_none_or(is_equal_to!(conn_device_id))
|
||||
&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref()))
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn drop_connection(&self, key: &ConnectionKey) {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.remove(key);
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.entry(key.clone())
|
||||
.and_modify(|existing| *existing = ConnectionVal::default())
|
||||
.or_default()
|
||||
.clone()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn device_connections(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
exclude: Option<&ConnectionId>,
|
||||
) -> impl Iterator<Item = ConnectionVal> + Send {
|
||||
type Siblings = SmallVec<[ConnectionVal; 4]>;
|
||||
|
||||
let key = into_connection_key(user_id, device_id, None::<ConnectionId>);
|
||||
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.range((Included(&key), Included(&key)))
|
||||
.filter(|((_, _, id), _)| id.as_ref() != exclude)
|
||||
.map(at!(1))
|
||||
.cloned()
|
||||
.collect::<Siblings>()
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn list_connections(&self) -> Vec<ConnectionKey> {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn find_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.get(key)
|
||||
.cloned()
|
||||
.ok_or_else(|| err!(Request(NotFound("Connection not found."))))
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn contains_connection(&self, key: &ConnectionKey) -> bool {
|
||||
self.connections
|
||||
.lock()
|
||||
.expect("locked")
|
||||
.contains_key(key)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_connection_key<U, D, C>(user_id: U, device_id: D, conn_id: Option<C>) -> ConnectionKey
|
||||
where
|
||||
U: Into<OwnedUserId>,
|
||||
D: Into<OwnedDeviceId>,
|
||||
C: Into<ConnectionId>,
|
||||
{
|
||||
(user_id.into(), device_id.into(), conn_id.map(Into::into))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user