Persist sliding-sync connection state; mitigate initial-sync cost after restart.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-10-25 13:14:23 +00:00
parent 718c3adcb2
commit f66a83763e
4 changed files with 204 additions and 123 deletions

View File

@@ -29,7 +29,7 @@ pub(crate) enum SyncCommand {
#[admin_command]
pub(super) async fn list_connections(&self) -> Result {
let connections = self.services.sync.list_connections();
let connections = self.services.sync.list_loaded_connections().await;
for connection_key in connections {
self.write_str(&format!("{connection_key:?}\n"))
@@ -47,7 +47,11 @@ pub(super) async fn show_connection(
conn_id: Option<String>,
) -> Result {
let key = into_connection_key(user_id, device_id, conn_id);
let cache = self.services.sync.find_connection(&key)?;
let cache = self
.services
.sync
.get_loaded_connection(&key)
.await?;
let out;
{
@@ -65,11 +69,14 @@ pub(super) async fn drop_connections(
device_id: Option<OwnedDeviceId>,
conn_id: Option<String>,
) -> Result {
self.services.sync.clear_connections(
user_id.as_deref(),
device_id.as_deref(),
conn_id.map(Into::into).as_ref(),
);
self.services
.sync
.clear_connections(
user_id.as_deref(),
device_id.as_deref(),
conn_id.map(Into::into).as_ref(),
)
.await;
Ok(())
}

View File

@@ -25,7 +25,7 @@ use tokio::time::{Instant, timeout_at};
use tuwunel_core::{
Err, Result, apply, at, debug,
debug::INFO_SPAN_LEVEL,
err,
debug_warn,
error::inspect_log,
extract_variant,
smallvec::SmallVec,
@@ -115,11 +115,10 @@ pub(crate) async fn sync_events_v5_route(
.unwrap_or(0);
let conn_key = into_connection_key(sender_user, sender_device, request.conn_id.as_deref());
let conn_val = since
.ne(&0)
.then(|| services.sync.find_connection(&conn_key))
.unwrap_or_else(|| Ok(services.sync.init_connection(&conn_key)))
.map_err(|_| err!(Request(UnknownPos("Connection lost; restarting sync stream."))))?;
let conn_val = services
.sync
.load_or_init_connection(&conn_key)
.await;
let conn = conn_val.lock();
let ping_presence = services
@@ -130,10 +129,22 @@ pub(crate) async fn sync_events_v5_route(
let (mut conn, _) = join(conn, ping_presence).await;
if since != 0 && conn.next_batch == 0 {
return Err!(Request(UnknownPos(warn!("Connection lost; restarting sync stream."))));
}
if since == 0 {
*conn = Connection::default();
conn.store(&services.sync, &conn_key);
debug_warn!(?conn_key, "Client cleared cache and reloaded.");
}
let advancing = since == conn.next_batch;
let retarding = since <= conn.globalsince;
let retarding = since != 0 && since <= conn.globalsince;
if !advancing && !retarding {
return Err!(Request(UnknownPos("Requesting unknown or stale stream position.")));
return Err!(Request(UnknownPos(warn!(
"Requesting unknown or invalid stream position."
))));
}
debug_assert!(
@@ -189,6 +200,7 @@ pub(crate) async fn sync_events_v5_route(
if !is_empty_response(&response) {
response.pos = conn.next_batch.to_string().into();
trace!(conn.globalsince, conn.next_batch, "response {response:?}");
conn.store(&services.sync, &conn_key);
return Ok(response);
}
}
@@ -202,6 +214,7 @@ pub(crate) async fn sync_events_v5_route(
{
response.pos = conn.next_batch.to_string().into();
trace!(conn.globalsince, conn.next_batch, "timeout; empty response {response:?}");
conn.store(&services.sync, &conn_key);
return Ok(response);
}

View File

@@ -341,6 +341,11 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "userdeviceid_metadata",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceconnid_conn",
block_size: 1024 * 16,
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userdeviceid_refresh",
..descriptor::RANDOM_SMALL

View File

@@ -1,11 +1,11 @@
mod watch;
use std::{
collections::BTreeMap,
ops::Bound::Included,
sync::{Arc, Mutex as StdMutex},
collections::{BTreeMap, btree_map::Entry},
sync::Arc,
};
use futures::{FutureExt, Stream};
use ruma::{
DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId,
api::client::sync::sync_events::v5::{
@@ -13,9 +13,10 @@ use ruma::{
request::{AccountData, E2EE, Receipts, ToDevice, Typing},
},
};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex as TokioMutex;
use tuwunel_core::{Result, at, err, implement, is_equal_to, smallvec::SmallVec};
use tuwunel_database::Map;
use tuwunel_core::{Result, at, debug, err, implement, is_equal_to, utils::stream::TryIgnore};
use tuwunel_database::{Cbor, Deserialized, Map};
pub struct Service {
services: Arc<crate::services::OnceServices>,
@@ -23,7 +24,8 @@ pub struct Service {
db: Data,
}
pub struct Data {
struct Data {
userdeviceconnid_conn: Arc<Map>,
todeviceid_events: Arc<Map>,
userroomid_joined: Arc<Map>,
userroomid_invitestate: Arc<Map>,
@@ -40,22 +42,22 @@ pub struct Data {
roomuserid_lastnotificationread: Arc<Map>,
}
#[derive(Debug, Default)]
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct Connection {
pub lists: Lists,
pub rooms: Rooms,
pub subscriptions: Subscriptions,
pub extensions: request::Extensions,
pub globalsince: u64,
pub next_batch: u64,
pub lists: Lists,
pub extensions: request::Extensions,
pub subscriptions: Subscriptions,
pub rooms: Rooms,
}
#[derive(Clone, Copy, Debug, Default)]
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
pub struct Room {
pub roomsince: u64,
}
type Connections = StdMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
pub type ConnectionVal = Arc<TokioMutex<Connection>>;
pub type ConnectionKey = (OwnedUserId, OwnedDeviceId, Option<ConnectionId>);
@@ -67,6 +69,7 @@ impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data {
userdeviceconnid_conn: args.db["userdeviceconnid_conn"].clone(),
todeviceid_events: args.db["todeviceid_events"].clone(),
userroomid_joined: args.db["userroomid_joined"].clone(),
userroomid_invitestate: args.db["userroomid_invitestate"].clone(),
@@ -91,7 +94,153 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
/// Drop every cached connection matching the given filters, and delete the
/// persisted copy of each dropped connection from the database. A `None`
/// filter component is a wildcard (matches any value).
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn clear_connections(
	&self,
	user_id: Option<&UserId>,
	device_id: Option<&DeviceId>,
	conn_id: Option<&ConnectionId>,
) {
	self.connections
		.lock()
		.await
		.retain(|(conn_user_id, conn_device_id, conn_conn_id), _| {
			// A connection is dropped when every supplied filter matches
			// its key component; `None` filters match everything.
			let matched = user_id.is_none_or(is_equal_to!(conn_user_id))
				&& device_id.is_none_or(is_equal_to!(conn_device_id))
				&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref());

			if matched {
				// Also delete the persisted state so the connection does
				// not reappear on the next load from the database.
				self.db
					.userdeviceconnid_conn
					.del((conn_user_id, conn_device_id, conn_conn_id));
			}

			// Retain only the entries that did NOT match the filters. The
			// prior form returned `matched` directly, which kept matching
			// connections and deleted everything else — the inverse of the
			// admin `drop_connections` command's intent.
			!matched
		});
}
/// Evict a single connection from the in-memory cache and delete its
/// persisted state; the cache lock is held across both steps so the two
/// stores cannot be observed out of sync.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn drop_connection(&self, key: &ConnectionKey) {
	let mut connections = self.connections.lock().await;
	self.db.userdeviceconnid_conn.del(key);
	connections.remove(key);
}
/// Return the shared handle for `key`, loading the connection from the
/// database on a cache miss; when neither the cache nor the database has
/// one, a fresh default `Connection` is inserted and returned.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
	let mut cache = self.connections.lock().await;
	match cache.entry(key.clone()) {
		// Fast path: connection already resident in the cache.
		| Entry::Occupied(val) => val.get().clone(),
		// Miss: query the persisted CBOR record. Any query or decode
		// failure falls back to a default connection via
		// `unwrap_or_default`, so this path never errors.
		| Entry::Vacant(val) => {
			let conn = self
				.db
				.userdeviceconnid_conn
				.qry(key)
				.boxed()
				.await
				.deserialized::<Cbor<_>>()
				.map(at!(0)) // unwrap the Cbor tuple wrapper
				.map(TokioMutex::new)
				.map(Arc::new)
				.unwrap_or_default();

			val.insert(conn).clone()
		},
	}
}
/// Return the shared handle for `key`, loading the connection from the
/// database on a cache miss. Unlike `load_or_init_connection`, a missing or
/// undecodable record propagates as an error instead of yielding a default.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn load_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
	let mut cache = self.connections.lock().await;
	match cache.entry(key.clone()) {
		// Fast path: connection already resident in the cache.
		| Entry::Occupied(val) => Ok(val.get().clone()),
		// Miss: fetch and decode the persisted CBOR value, inserting it
		// into the cache on success; errors are returned to the caller.
		| Entry::Vacant(val) => self
			.db
			.userdeviceconnid_conn
			.qry(key)
			.await
			.deserialized::<Cbor<_>>()
			.map(at!(0)) // unwrap the Cbor tuple wrapper
			.map(TokioMutex::new)
			.map(Arc::new)
			.map(|conn| val.insert(conn).clone()),
	}
}
/// Look up a connection in the in-memory cache only; a connection that is
/// persisted but not loaded yields a `NotFound` error.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn get_loaded_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
	let connections = self.connections.lock().await;
	match connections.get(key) {
		| Some(conn) => Ok(conn.clone()),
		| None => Err(err!(Request(NotFound("Connection not found.")))),
	}
}
/// Snapshot the keys of every connection currently resident in the
/// in-memory cache (persisted-only connections are not included).
#[implement(Service)]
#[tracing::instrument(level = "trace", skip(self))]
pub async fn list_loaded_connections(&self) -> Vec<ConnectionKey> {
	let connections = self.connections.lock().await;
	connections.keys().cloned().collect()
}
/// Stream every connection key persisted in the database, regardless of
/// whether it is currently loaded into the in-memory cache.
#[implement(Service)]
#[tracing::instrument(level = "trace", skip(self))]
pub fn list_stored_connections(&self) -> impl Stream<Item = ConnectionKey> {
	// Keys that fail to decode are silently skipped by `ignore_err`.
	self.db.userdeviceconnid_conn.keys().ignore_err()
}
/// Whether a connection with `key` is currently resident in the in-memory
/// cache (does not consult the database).
#[implement(Service)]
#[tracing::instrument(level = "trace", skip(self))]
pub async fn is_connection_loaded(&self, key: &ConnectionKey) -> bool {
	let connections = self.connections.lock().await;
	connections.contains_key(key)
}
/// Whether a connection with `key` has been persisted to the database
/// (does not consult the in-memory cache).
#[implement(Service)]
#[tracing::instrument(level = "trace", skip(self))]
pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool {
	self.db.userdeviceconnid_conn.contains(key).await
}
/// Build a `ConnectionKey` triple from owned-or-convertible user, device,
/// and optional connection identifiers.
#[inline]
pub fn into_connection_key<U, D, C>(user_id: U, device_id: D, conn_id: Option<C>) -> ConnectionKey
where
	U: Into<OwnedUserId>,
	D: Into<OwnedDeviceId>,
	C: Into<ConnectionId>,
{
	(user_id.into(), device_id.into(), conn_id.map(Into::into))
}
/// Persist this connection's state to the database under `key`, encoded
/// as CBOR, then log the stream positions that were written.
#[implement(Connection)]
#[tracing::instrument(level = "debug", skip(self, service))]
pub fn store(&self, service: &Service, key: &ConnectionKey) {
	let map = &service.db.userdeviceconnid_conn;
	map.put(key, Cbor(self));

	debug!(
		since = %self.globalsince,
		next_batch = %self.next_batch,
		"Persisted connection state"
	);
}
#[implement(Connection)]
#[tracing::instrument(level = "debug", skip(self))]
pub fn update_rooms_prologue(&mut self, retard_since: Option<u64>) {
self.rooms.values_mut().for_each(|room| {
if let Some(retard_since) = retard_since {
@@ -103,6 +252,7 @@ pub fn update_rooms_prologue(&mut self, retard_since: Option<u64>) {
}
#[implement(Connection)]
#[tracing::instrument(level = "debug", skip_all)]
pub fn update_rooms_epilogue<'a, Rooms>(&mut self, window: Rooms)
where
Rooms: Iterator<Item = &'a RoomId> + Send + 'a,
@@ -115,6 +265,7 @@ where
}
#[implement(Connection)]
#[tracing::instrument(level = "debug", skip_all)]
pub fn update_cache(&mut self, request: &Request) {
Self::update_cache_lists(request, self);
Self::update_cache_subscriptions(request, self);
@@ -215,98 +366,3 @@ fn some_or_sticky<T: Clone>(target: Option<&T>, cached: &mut Option<T>) {
cached.replace(target.clone());
}
}
/// Evict all cached connections matching the given filters; a `None`
/// filter component is a wildcard. (Superseded synchronous, cache-only
/// variant — no persisted state is touched.)
#[implement(Service)]
pub fn clear_connections(
	&self,
	user_id: Option<&UserId>,
	device_id: Option<&DeviceId>,
	conn_id: Option<&ConnectionId>,
) {
	// Keep an entry only when it does NOT match every supplied filter.
	self.connections.lock().expect("locked").retain(
		|(conn_user_id, conn_device_id, conn_conn_id), _| {
			!(user_id.is_none_or(is_equal_to!(conn_user_id))
				&& device_id.is_none_or(is_equal_to!(conn_device_id))
				&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref()))
		},
	);
}
/// Evict a single connection from the in-memory cache by key. (Superseded
/// synchronous, cache-only variant — no persisted state is touched.)
#[implement(Service)]
pub fn drop_connection(&self, key: &ConnectionKey) {
	let mut connections = self.connections.lock().expect("locked");
	connections.remove(key);
}
/// Insert a fresh default connection for `key` and return its shared
/// handle; an already-cached connection is reset to defaults rather than
/// reused.
#[implement(Service)]
pub fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
	self.connections
		.lock()
		.expect("locked")
		.entry(key.clone())
		// Existing state is discarded: clients hitting this path are
		// starting a new stream, not resuming one.
		.and_modify(|existing| *existing = ConnectionVal::default())
		.or_default()
		.clone()
}
#[implement(Service)]
pub fn device_connections(
&self,
user_id: &UserId,
device_id: &DeviceId,
exclude: Option<&ConnectionId>,
) -> impl Iterator<Item = ConnectionVal> + Send {
type Siblings = SmallVec<[ConnectionVal; 4]>;
let key = into_connection_key(user_id, device_id, None::<ConnectionId>);
self.connections
.lock()
.expect("locked")
.range((Included(&key), Included(&key)))
.filter(|((_, _, id), _)| id.as_ref() != exclude)
.map(at!(1))
.cloned()
.collect::<Siblings>()
.into_iter()
}
/// Snapshot the keys of every connection held in the in-memory cache.
/// (Superseded synchronous variant.)
#[implement(Service)]
pub fn list_connections(&self) -> Vec<ConnectionKey> {
	let connections = self.connections.lock().expect("locked");
	connections.keys().cloned().collect()
}
/// Look up a cached connection by key; an absent entry yields a
/// `NotFound` error. (Superseded synchronous variant.)
#[implement(Service)]
pub fn find_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
	let connections = self.connections.lock().expect("locked");
	match connections.get(key) {
		| Some(conn) => Ok(conn.clone()),
		| None => Err(err!(Request(NotFound("Connection not found.")))),
	}
}
/// Whether a connection with `key` is present in the in-memory cache.
/// (Superseded synchronous variant.)
#[implement(Service)]
pub fn contains_connection(&self, key: &ConnectionKey) -> bool {
	let connections = self.connections.lock().expect("locked");
	connections.contains_key(key)
}
/// Build a `ConnectionKey` triple from owned-or-convertible user, device,
/// and optional connection identifiers.
#[inline]
pub fn into_connection_key<U, D, C>(user_id: U, device_id: D, conn_id: Option<C>) -> ConnectionKey
where
	U: Into<OwnedUserId>,
	D: Into<OwnedDeviceId>,
	C: Into<ConnectionId>,
{
	(user_id.into(), device_id.into(), conn_id.map(Into::into))
}