diff --git a/src/admin/query/sync.rs b/src/admin/query/sync.rs index d72d7d27..3e71db38 100644 --- a/src/admin/query/sync.rs +++ b/src/admin/query/sync.rs @@ -29,7 +29,7 @@ pub(crate) enum SyncCommand { #[admin_command] pub(super) async fn list_connections(&self) -> Result { - let connections = self.services.sync.list_connections(); + let connections = self.services.sync.list_loaded_connections().await; for connection_key in connections { self.write_str(&format!("{connection_key:?}\n")) @@ -47,7 +47,11 @@ pub(super) async fn show_connection( conn_id: Option, ) -> Result { let key = into_connection_key(user_id, device_id, conn_id); - let cache = self.services.sync.find_connection(&key)?; + let cache = self + .services + .sync + .get_loaded_connection(&key) + .await?; let out; { @@ -65,11 +69,14 @@ pub(super) async fn drop_connections( device_id: Option, conn_id: Option, ) -> Result { - self.services.sync.clear_connections( - user_id.as_deref(), - device_id.as_deref(), - conn_id.map(Into::into).as_ref(), - ); + self.services + .sync + .clear_connections( + user_id.as_deref(), + device_id.as_deref(), + conn_id.map(Into::into).as_ref(), + ) + .await; Ok(()) } diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index 1132e7f0..bda97c47 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -25,7 +25,7 @@ use tokio::time::{Instant, timeout_at}; use tuwunel_core::{ Err, Result, apply, at, debug, debug::INFO_SPAN_LEVEL, - err, + debug_warn, error::inspect_log, extract_variant, smallvec::SmallVec, @@ -115,11 +115,10 @@ pub(crate) async fn sync_events_v5_route( .unwrap_or(0); let conn_key = into_connection_key(sender_user, sender_device, request.conn_id.as_deref()); - let conn_val = since - .ne(&0) - .then(|| services.sync.find_connection(&conn_key)) - .unwrap_or_else(|| Ok(services.sync.init_connection(&conn_key))) - .map_err(|_| err!(Request(UnknownPos("Connection lost; restarting sync stream."))))?; + let conn_val = services + .sync + .load_or_init_connection(&conn_key) + .await; let conn = conn_val.lock(); let ping_presence = services @@ -130,10 +129,22 @@ pub(crate) async fn sync_events_v5_route( let (mut conn, _) = join(conn, ping_presence).await; + if since != 0 && conn.next_batch == 0 { + return Err!(Request(UnknownPos(warn!("Connection lost; restarting sync stream.")))); + } + + if since == 0 { + *conn = Connection::default(); + conn.store(&services.sync, &conn_key); + debug_warn!(?conn_key, "Client cleared cache and reloaded."); + } + let advancing = since == conn.next_batch; - let retarding = since <= conn.globalsince; + let retarding = since != 0 && since <= conn.globalsince; if !advancing && !retarding { - return Err!(Request(UnknownPos("Requesting unknown or stale stream position."))); + return Err!(Request(UnknownPos(warn!( + "Requesting unknown or invalid stream position." + )))); } debug_assert!( @@ -189,6 +200,7 @@ pub(crate) async fn sync_events_v5_route( if !is_empty_response(&response) { response.pos = conn.next_batch.to_string().into(); trace!(conn.globalsince, conn.next_batch, "response {response:?}"); + conn.store(&services.sync, &conn_key); return Ok(response); } } @@ -202,6 +214,7 @@ pub(crate) async fn sync_events_v5_route( { response.pos = conn.next_batch.to_string().into(); trace!(conn.globalsince, conn.next_batch, "timeout; empty response {response:?}"); + conn.store(&services.sync, &conn_key); return Ok(response); } diff --git a/src/database/maps.rs b/src/database/maps.rs index 338cfe2f..0c97d9a8 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -341,6 +341,11 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "userdeviceid_metadata", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "userdeviceconnid_conn", + block_size: 1024 * 16, + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "userdeviceid_refresh", ..descriptor::RANDOM_SMALL diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs index e74360f6..8782ede1 100644 --- a/src/service/sync/mod.rs +++ b/src/service/sync/mod.rs @@ -1,11 +1,11 @@ mod watch; use std::{ - collections::BTreeMap, - ops::Bound::Included, - sync::{Arc, Mutex as StdMutex}, + collections::{BTreeMap, btree_map::Entry}, + sync::Arc, }; +use futures::{FutureExt, Stream}; use ruma::{ DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId, api::client::sync::sync_events::v5::{ @@ -13,9 +13,10 @@ use ruma::{ request::{AccountData, E2EE, Receipts, ToDevice, Typing}, }, }; +use serde::{Deserialize, Serialize}; use tokio::sync::Mutex as TokioMutex; -use tuwunel_core::{Result, at, err, implement, is_equal_to, smallvec::SmallVec}; -use tuwunel_database::Map; +use tuwunel_core::{Result, at, debug, err, implement, is_equal_to, utils::stream::TryIgnore}; +use tuwunel_database::{Cbor, Deserialized, Map}; pub struct Service { services: Arc, @@ -23,7 +24,8 @@ pub struct Service { db: Data, } -pub struct Data { +struct Data { + userdeviceconnid_conn: Arc, todeviceid_events: Arc, userroomid_joined: Arc, userroomid_invitestate: Arc, @@ -40,22 +42,22 @@ pub struct Data { roomuserid_lastnotificationread: Arc, } -#[derive(Debug, Default)] +#[derive(Debug, Default, Deserialize, Serialize)] pub struct Connection { - pub lists: Lists, - pub rooms: Rooms, - pub subscriptions: Subscriptions, - pub extensions: request::Extensions, pub globalsince: u64, pub next_batch: u64, + pub lists: Lists, + pub extensions: request::Extensions, + pub subscriptions: Subscriptions, + pub rooms: Rooms, } -#[derive(Clone, Copy, Debug, Default)] +#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)] pub struct Room { pub roomsince: u64, } -type Connections = StdMutex>; +type Connections = TokioMutex>; pub type ConnectionVal = Arc>; pub type ConnectionKey = (OwnedUserId, OwnedDeviceId, Option); @@ -67,6 +69,7 @@ impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { Ok(Arc::new(Self { db: Data { + userdeviceconnid_conn: args.db["userdeviceconnid_conn"].clone(), todeviceid_events: args.db["todeviceid_events"].clone(), userroomid_joined: args.db["userroomid_joined"].clone(), userroomid_invitestate: args.db["userroomid_invitestate"].clone(), @@ -91,7 +94,153 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } +#[implement(Service)] +#[tracing::instrument(level = "debug", skip(self))] +pub async fn clear_connections( + &self, + user_id: Option<&UserId>, + device_id: Option<&DeviceId>, + conn_id: Option<&ConnectionId>, +) { + self.connections + .lock() + .await + .retain(|(conn_user_id, conn_device_id, conn_conn_id), _| { + let retain = user_id.is_none_or(is_equal_to!(conn_user_id)) + && device_id.is_none_or(is_equal_to!(conn_device_id)) + && (conn_id.is_none() || conn_id == conn_conn_id.as_ref()); + + if !retain { + self.db + .userdeviceconnid_conn + .del((conn_user_id, conn_device_id, conn_conn_id)); + } + + retain + }); +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip(self))] +pub async fn drop_connection(&self, key: &ConnectionKey) { + let mut cache = self.connections.lock().await; + + self.db.userdeviceconnid_conn.del(key); + cache.remove(key); +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip(self))] +pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVal { + let mut cache = self.connections.lock().await; + + match cache.entry(key.clone()) { + | Entry::Occupied(val) => val.get().clone(), + | Entry::Vacant(val) => { + let conn = self + .db + .userdeviceconnid_conn + .qry(key) + .boxed() + .await + .deserialized::>() + .map(at!(0)) + .map(TokioMutex::new) + .map(Arc::new) + .unwrap_or_default(); + + val.insert(conn).clone() + }, + } +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip(self))] +pub async fn load_connection(&self, key: &ConnectionKey) -> Result { + let mut cache = self.connections.lock().await; + + match cache.entry(key.clone()) { + | Entry::Occupied(val) => Ok(val.get().clone()), + | Entry::Vacant(val) => self + .db + .userdeviceconnid_conn + .qry(key) + .await + .deserialized::>() + .map(at!(0)) + .map(TokioMutex::new) + .map(Arc::new) + .map(|conn| val.insert(conn).clone()), + } +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip(self))] +pub async fn get_loaded_connection(&self, key: &ConnectionKey) -> Result { + self.connections + .lock() + .await + .get(key) + .cloned() + .ok_or_else(|| err!(Request(NotFound("Connection not found.")))) +} + +#[implement(Service)] +#[tracing::instrument(level = "trace", skip(self))] +pub async fn list_loaded_connections(&self) -> Vec { + self.connections + .lock() + .await + .keys() + .cloned() + .collect() +} + +#[implement(Service)] +#[tracing::instrument(level = "trace", skip(self))] +pub fn list_stored_connections(&self) -> impl Stream { + self.db.userdeviceconnid_conn.keys().ignore_err() +} + +#[implement(Service)] +#[tracing::instrument(level = "trace", skip(self))] +pub async fn is_connection_loaded(&self, key: &ConnectionKey) -> bool { + self.connections.lock().await.contains_key(key) +} + +#[implement(Service)] +#[tracing::instrument(level = "trace", skip(self))] +pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool { + self.db.userdeviceconnid_conn.contains(key).await +} + +#[inline] +pub fn into_connection_key(user_id: U, device_id: D, conn_id: Option) -> ConnectionKey +where + U: Into, + D: Into, + C: Into, +{ + (user_id.into(), device_id.into(), conn_id.map(Into::into)) +} + #[implement(Connection)] +#[tracing::instrument(level = "debug", skip(self, service))] +pub fn store(&self, service: &Service, key: &ConnectionKey) { + service + .db + .userdeviceconnid_conn + .put(key, Cbor(self)); + + debug!( + since = %self.globalsince, + next_batch = %self.next_batch, + "Persisted connection state" + ); +} + +#[implement(Connection)] +#[tracing::instrument(level = "debug", skip(self))] pub fn update_rooms_prologue(&mut self, retard_since: Option) { self.rooms.values_mut().for_each(|room| { if let Some(retard_since) = retard_since { @@ -103,6 +252,7 @@ pub fn update_rooms_prologue(&mut self, retard_since: Option) { } #[implement(Connection)] +#[tracing::instrument(level = "debug", skip_all)] pub fn update_rooms_epilogue<'a, Rooms>(&mut self, window: Rooms) where Rooms: Iterator + Send + 'a, @@ -115,6 +265,7 @@ where } #[implement(Connection)] +#[tracing::instrument(level = "debug", skip_all)] pub fn update_cache(&mut self, request: &Request) { Self::update_cache_lists(request, self); Self::update_cache_subscriptions(request, self); @@ -215,98 +366,3 @@ fn some_or_sticky(target: Option<&T>, cached: &mut Option) { cached.replace(target.clone()); } } - -#[implement(Service)] -pub fn clear_connections( - &self, - user_id: Option<&UserId>, - device_id: Option<&DeviceId>, - conn_id: Option<&ConnectionId>, -) { - self.connections.lock().expect("locked").retain( - |(conn_user_id, conn_device_id, conn_conn_id), _| { - !(user_id.is_none_or(is_equal_to!(conn_user_id)) - && device_id.is_none_or(is_equal_to!(conn_device_id)) - && (conn_id.is_none() || conn_id == conn_conn_id.as_ref())) - }, - ); -} - -#[implement(Service)] -pub fn drop_connection(&self, key: &ConnectionKey) { - self.connections - .lock() - .expect("locked") - .remove(key); -} - -#[implement(Service)] -pub fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal { - self.connections - .lock() - .expect("locked") - .entry(key.clone()) - .and_modify(|existing| *existing = ConnectionVal::default()) - .or_default() - .clone() -} - -#[implement(Service)] -pub fn device_connections( - &self, - user_id: &UserId, - device_id: &DeviceId, - exclude: Option<&ConnectionId>, -) -> impl Iterator + Send { - type Siblings = SmallVec<[ConnectionVal; 4]>; - - let key = into_connection_key(user_id, device_id, None::); - - self.connections - .lock() - .expect("locked") - .range((Included(&key), Included(&key))) - .filter(|((_, _, id), _)| id.as_ref() != exclude) - .map(at!(1)) - .cloned() - .collect::() - .into_iter() -} - -#[implement(Service)] -pub fn list_connections(&self) -> Vec { - self.connections - .lock() - .expect("locked") - .keys() - .cloned() - .collect() -} - -#[implement(Service)] -pub fn find_connection(&self, key: &ConnectionKey) -> Result { - self.connections - .lock() - .expect("locked") - .get(key) - .cloned() - .ok_or_else(|| err!(Request(NotFound("Connection not found.")))) -} - -#[implement(Service)] -pub fn contains_connection(&self, key: &ConnectionKey) -> bool { - self.connections - .lock() - .expect("locked") - .contains_key(key) -} - -#[inline] -pub fn into_connection_key(user_id: U, device_id: D, conn_id: Option) -> ConnectionKey -where - U: Into, - D: Into, - C: Into, -{ - (user_id.into(), device_id.into(), conn_id.map(Into::into)) -}