Refactor counter increment sites for TwoPhaseCounter.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -160,7 +160,7 @@ pub(crate) async fn build_sync_events(
|
|||||||
) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> {
|
) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> {
|
||||||
let (sender_user, sender_device) = body.sender();
|
let (sender_user, sender_device) = body.sender();
|
||||||
|
|
||||||
let next_batch = services.globals.current_count()?;
|
let next_batch = services.globals.current_count();
|
||||||
let since = body
|
let since = body
|
||||||
.body
|
.body
|
||||||
.since
|
.since
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ pub(crate) async fn sync_events_v5_route(
|
|||||||
// Setup watchers, so if there's no response, we can wait for them
|
// Setup watchers, so if there's no response, we can wait for them
|
||||||
let watcher = services.sync.watch(sender_user, sender_device);
|
let watcher = services.sync.watch(sender_user, sender_device);
|
||||||
|
|
||||||
let next_batch = services.globals.next_count()?;
|
let next_batch = services.globals.next_count();
|
||||||
|
|
||||||
let conn_id = body.conn_id.clone();
|
let conn_id = body.conn_id.clone();
|
||||||
|
|
||||||
@@ -136,7 +136,7 @@ pub(crate) async fn sync_events_v5_route(
|
|||||||
.chain(all_invited_rooms.clone())
|
.chain(all_invited_rooms.clone())
|
||||||
.chain(all_knocked_rooms.clone());
|
.chain(all_knocked_rooms.clone());
|
||||||
|
|
||||||
let pos = next_batch.clone().to_string();
|
let pos = next_batch.to_string();
|
||||||
|
|
||||||
let mut todo_rooms: TodoRooms = BTreeMap::new();
|
let mut todo_rooms: TodoRooms = BTreeMap::new();
|
||||||
|
|
||||||
@@ -146,7 +146,7 @@ pub(crate) async fn sync_events_v5_route(
|
|||||||
|
|
||||||
let e2ee = collect_e2ee(services, sync_info, all_joined_rooms.clone());
|
let e2ee = collect_e2ee(services, sync_info, all_joined_rooms.clone());
|
||||||
|
|
||||||
let to_device = collect_to_device(services, sync_info, next_batch).map(Ok);
|
let to_device = collect_to_device(services, sync_info, *next_batch).map(Ok);
|
||||||
|
|
||||||
let receipts = collect_receipts(services).map(Ok);
|
let receipts = collect_receipts(services).map(Ok);
|
||||||
|
|
||||||
@@ -189,7 +189,7 @@ pub(crate) async fn sync_events_v5_route(
|
|||||||
response.rooms = process_rooms(
|
response.rooms = process_rooms(
|
||||||
services,
|
services,
|
||||||
sender_user,
|
sender_user,
|
||||||
next_batch,
|
*next_batch,
|
||||||
all_invited_rooms.clone(),
|
all_invited_rooms.clone(),
|
||||||
&todo_rooms,
|
&todo_rooms,
|
||||||
&mut response,
|
&mut response,
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ pub(crate) async fn send_event_to_device_route(
|
|||||||
map.insert(target_device_id_maybe.clone(), event.clone());
|
map.insert(target_device_id_maybe.clone(), event.clone());
|
||||||
let mut messages = BTreeMap::new();
|
let mut messages = BTreeMap::new();
|
||||||
messages.insert(target_user_id.clone(), map);
|
messages.insert(target_user_id.clone(), map);
|
||||||
let count = services.globals.next_count()?;
|
|
||||||
|
|
||||||
let mut buf = EduBuf::new();
|
let mut buf = EduBuf::new();
|
||||||
serde_json::to_writer(
|
serde_json::to_writer(
|
||||||
@@ -49,7 +48,7 @@ pub(crate) async fn send_event_to_device_route(
|
|||||||
&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
|
&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
|
||||||
sender: sender_user.to_owned(),
|
sender: sender_user.to_owned(),
|
||||||
ev_type: body.event_type.clone(),
|
ev_type: body.event_type.clone(),
|
||||||
message_id: count.to_string().into(),
|
message_id: services.globals.next_count().to_string().into(),
|
||||||
messages,
|
messages,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -63,8 +63,8 @@ pub async fn update(
|
|||||||
return Err!(Request(InvalidParam("Account data doesn't have all required fields.")));
|
return Err!(Request(InvalidParam("Account data doesn't have all required fields.")));
|
||||||
}
|
}
|
||||||
|
|
||||||
let count = self.services.globals.next_count().unwrap();
|
let count = self.services.globals.next_count();
|
||||||
let roomuserdataid = (room_id, user_id, count, &event_type);
|
let roomuserdataid = (room_id, user_id, *count, &event_type);
|
||||||
self.db
|
self.db
|
||||||
.roomuserdataid_accountdata
|
.roomuserdataid_accountdata
|
||||||
.put(roomuserdataid, Json(data));
|
.put(roomuserdataid, Json(data));
|
||||||
|
|||||||
@@ -1,56 +1,51 @@
|
|||||||
use std::sync::{Arc, RwLock};
|
use std::sync::Arc;
|
||||||
|
|
||||||
use tuwunel_core::{Result, utils};
|
use tuwunel_core::{
|
||||||
|
Result, utils,
|
||||||
|
utils::two_phase_counter::{Counter as TwoPhaseCounter, Permit as TwoPhasePermit},
|
||||||
|
};
|
||||||
use tuwunel_database::{Database, Deserialized, Map};
|
use tuwunel_database::{Database, Deserialized, Map};
|
||||||
|
|
||||||
pub struct Data {
|
pub struct Data {
|
||||||
global: Arc<Map>,
|
global: Arc<Map>,
|
||||||
counter: RwLock<u64>,
|
counter: Arc<Counter>,
|
||||||
pub(super) db: Arc<Database>,
|
pub(super) db: Arc<Database>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) type Permit = TwoPhasePermit<Callback>;
|
||||||
|
type Counter = TwoPhaseCounter<Callback>;
|
||||||
|
type Callback = Box<dyn Fn(u64) -> Result + Send + Sync>;
|
||||||
|
|
||||||
const COUNTER: &[u8] = b"c";
|
const COUNTER: &[u8] = b"c";
|
||||||
|
|
||||||
impl Data {
|
impl Data {
|
||||||
pub(super) fn new(args: &crate::Args<'_>) -> Self {
|
pub(super) fn new(args: &crate::Args<'_>) -> Self {
|
||||||
let db = &args.db;
|
let db = args.db.clone();
|
||||||
Self {
|
Self {
|
||||||
global: db["global"].clone(),
|
|
||||||
counter: RwLock::new(
|
|
||||||
Self::stored_count(&db["global"]).expect("initialized global counter"),
|
|
||||||
),
|
|
||||||
db: args.db.clone(),
|
db: args.db.clone(),
|
||||||
|
global: args.db["global"].clone(),
|
||||||
|
counter: Counter::new(
|
||||||
|
Self::stored_count(&args.db["global"]).expect("initialized global counter"),
|
||||||
|
Box::new(move |count| Self::store_count(&db, &db["global"], count)),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn next_count(&self) -> Result<u64> {
|
#[inline]
|
||||||
let _cork = self.db.cork();
|
pub fn next_count(&self) -> Permit {
|
||||||
let mut lock = self.counter.write().expect("locked");
|
self.counter
|
||||||
let counter: &mut u64 = &mut lock;
|
.next()
|
||||||
debug_assert!(
|
.expect("failed to obtain next sequence number")
|
||||||
*counter == Self::stored_count(&self.global).expect("database failure"),
|
|
||||||
"counter mismatch"
|
|
||||||
);
|
|
||||||
|
|
||||||
*counter = counter
|
|
||||||
.checked_add(1)
|
|
||||||
.expect("counter must not overflow u64");
|
|
||||||
|
|
||||||
self.global.insert(COUNTER, counter.to_be_bytes());
|
|
||||||
|
|
||||||
Ok(*counter)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn current_count(&self) -> u64 {
|
pub fn current_count(&self) -> u64 { self.counter.current() }
|
||||||
let lock = self.counter.read().expect("locked");
|
|
||||||
let counter: &u64 = &lock;
|
|
||||||
debug_assert!(
|
|
||||||
*counter == Self::stored_count(&self.global).expect("database failure"),
|
|
||||||
"counter mismatch"
|
|
||||||
);
|
|
||||||
|
|
||||||
*counter
|
fn store_count(db: &Arc<Database>, global: &Arc<Map>, count: u64) -> Result {
|
||||||
|
let _cork = db.cork();
|
||||||
|
global.insert(COUNTER, count.to_be_bytes());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn stored_count(global: &Arc<Map>) -> Result<u64> {
|
fn stored_count(global: &Arc<Map>) -> Result<u64> {
|
||||||
@@ -59,6 +54,12 @@ impl Data {
|
|||||||
.as_deref()
|
.as_deref()
|
||||||
.map_or(Ok(0_u64), utils::u64_from_bytes)
|
.map_or(Ok(0_u64), utils::u64_from_bytes)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Data {
|
||||||
|
pub fn bump_database_version(&self, new_version: u64) {
|
||||||
|
self.global.raw_put(b"version", new_version);
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn database_version(&self) -> u64 {
|
pub async fn database_version(&self) -> u64 {
|
||||||
self.global
|
self.global
|
||||||
@@ -67,9 +68,4 @@ impl Data {
|
|||||||
.deserialized()
|
.deserialized()
|
||||||
.unwrap_or(0)
|
.unwrap_or(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn bump_database_version(&self, new_version: u64) {
|
|
||||||
self.global.raw_put(b"version", new_version);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,10 +103,12 @@ impl crate::Service for Service {
|
|||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
|
#[must_use]
|
||||||
|
pub fn next_count(&self) -> data::Permit { self.db.next_count() }
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) }
|
#[must_use]
|
||||||
|
pub fn current_count(&self) -> u64 { self.db.current_count() }
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
|
|||||||
@@ -52,17 +52,18 @@ pub fn create_backup(
|
|||||||
user_id: &UserId,
|
user_id: &UserId,
|
||||||
backup_metadata: &Raw<BackupAlgorithm>,
|
backup_metadata: &Raw<BackupAlgorithm>,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let version = self.services.globals.next_count()?.to_string();
|
let version = self.services.globals.next_count();
|
||||||
let count = self.services.globals.next_count()?;
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
let key = (user_id, &version);
|
let version_string = version.to_string();
|
||||||
|
let key = (user_id, &version_string);
|
||||||
self.db
|
self.db
|
||||||
.backupid_algorithm
|
.backupid_algorithm
|
||||||
.put(key, Json(backup_metadata));
|
.put(key, Json(backup_metadata));
|
||||||
|
|
||||||
self.db.backupid_etag.put(key, count);
|
self.db.backupid_etag.put(key, *count);
|
||||||
|
|
||||||
Ok(version)
|
Ok(version_string)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
@@ -100,8 +101,8 @@ pub async fn update_backup<'a>(
|
|||||||
return Err!(Request(NotFound("Tried to update nonexistent backup.")));
|
return Err!(Request(NotFound("Tried to update nonexistent backup.")));
|
||||||
}
|
}
|
||||||
|
|
||||||
let count = self.services.globals.next_count().unwrap();
|
let count = self.services.globals.next_count();
|
||||||
self.db.backupid_etag.put(key, count);
|
self.db.backupid_etag.put(key, *count);
|
||||||
self.db
|
self.db
|
||||||
.backupid_algorithm
|
.backupid_algorithm
|
||||||
.put_raw(key, backup_metadata.json().get());
|
.put_raw(key, backup_metadata.json().get());
|
||||||
@@ -175,8 +176,8 @@ pub async fn add_key(
|
|||||||
return Err!(Request(NotFound("Tried to update nonexistent backup.")));
|
return Err!(Request(NotFound("Tried to update nonexistent backup.")));
|
||||||
}
|
}
|
||||||
|
|
||||||
let count = self.services.globals.next_count().unwrap();
|
let count = self.services.globals.next_count();
|
||||||
self.db.backupid_etag.put(key, count);
|
self.db.backupid_etag.put(key, *count);
|
||||||
|
|
||||||
let key = (user_id, version, room_id, session_id);
|
let key = (user_id, version, room_id, session_id);
|
||||||
self.db
|
self.db
|
||||||
|
|||||||
@@ -120,12 +120,12 @@ impl Data {
|
|||||||
status_msg,
|
status_msg,
|
||||||
);
|
);
|
||||||
|
|
||||||
let count = self.services.globals.next_count()?;
|
let count = self.services.globals.next_count();
|
||||||
let key = presenceid_key(count, user_id);
|
let key = presenceid_key(*count, user_id);
|
||||||
|
|
||||||
|
self.userid_presenceid.raw_put(user_id, *count);
|
||||||
self.presenceid_presence
|
self.presenceid_presence
|
||||||
.raw_put(key, Json(presence));
|
.raw_put(key, Json(presence));
|
||||||
self.userid_presenceid.raw_put(user_id, count);
|
|
||||||
|
|
||||||
if let Ok((last_count, _)) = last_presence {
|
if let Ok((last_count, _)) = last_presence {
|
||||||
let key = presenceid_key(last_count, user_id);
|
let key = presenceid_key(last_count, user_id);
|
||||||
|
|||||||
@@ -71,6 +71,8 @@ impl Service {
|
|||||||
return Err!(Request(Forbidden("Only the server user can set this alias")));
|
return Err!(Request(Forbidden("Only the server user can set this alias")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
// Comes first as we don't want a stuck alias
|
// Comes first as we don't want a stuck alias
|
||||||
self.db
|
self.db
|
||||||
.alias_userid
|
.alias_userid
|
||||||
@@ -82,7 +84,8 @@ impl Service {
|
|||||||
|
|
||||||
let mut aliasid = room_id.as_bytes().to_vec();
|
let mut aliasid = room_id.as_bytes().to_vec();
|
||||||
aliasid.push(0xFF);
|
aliasid.push(0xFF);
|
||||||
aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
|
aliasid.extend_from_slice(&count.to_be_bytes());
|
||||||
|
|
||||||
self.db
|
self.db
|
||||||
.aliasid_alias
|
.aliasid_alias
|
||||||
.insert(&aliasid, alias.as_bytes());
|
.insert(&aliasid, alias.as_bytes());
|
||||||
|
|||||||
@@ -56,8 +56,8 @@ impl Data {
|
|||||||
.ready_for_each(|key| self.readreceiptid_readreceipt.del(key))
|
.ready_for_each(|key| self.readreceiptid_readreceipt.del(key))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let count = self.services.globals.next_count().unwrap();
|
let count = self.services.globals.next_count();
|
||||||
let latest_id = (room_id, count, user_id);
|
let latest_id = (room_id, *count, user_id);
|
||||||
self.readreceiptid_readreceipt
|
self.readreceiptid_readreceipt
|
||||||
.put(latest_id, Json(event));
|
.put(latest_id, Json(event));
|
||||||
}
|
}
|
||||||
@@ -89,11 +89,11 @@ impl Data {
|
|||||||
|
|
||||||
pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, pdu_count: u64) {
|
pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, pdu_count: u64) {
|
||||||
let key = (room_id, user_id);
|
let key = (room_id, user_id);
|
||||||
let next_count = self.services.globals.next_count().unwrap();
|
let next_count = self.services.globals.next_count();
|
||||||
|
|
||||||
self.roomuserid_privateread.put(key, pdu_count);
|
self.roomuserid_privateread.put(key, pdu_count);
|
||||||
self.roomuserid_lastprivatereadupdate
|
self.roomuserid_lastprivatereadupdate
|
||||||
.put(key, next_count);
|
.put(key, *next_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) async fn private_read_get_count(
|
pub(super) async fn private_read_get_count(
|
||||||
|
|||||||
@@ -81,18 +81,18 @@ where
|
|||||||
fn create_shorteventid(&self, event_id: &EventId) -> ShortEventId {
|
fn create_shorteventid(&self, event_id: &EventId) -> ShortEventId {
|
||||||
const BUFSIZE: usize = size_of::<ShortEventId>();
|
const BUFSIZE: usize = size_of::<ShortEventId>();
|
||||||
|
|
||||||
let short = self.services.globals.next_count().unwrap();
|
let short = self.services.globals.next_count();
|
||||||
debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed");
|
debug_assert!(size_of_val(&*short) == BUFSIZE, "buffer requirement changed");
|
||||||
|
|
||||||
self.db
|
self.db
|
||||||
.eventid_shorteventid
|
.eventid_shorteventid
|
||||||
.raw_aput::<BUFSIZE, _, _>(event_id, short);
|
.raw_aput::<BUFSIZE, _, _>(event_id, *short);
|
||||||
|
|
||||||
self.db
|
self.db
|
||||||
.shorteventid_eventid
|
.shorteventid_eventid
|
||||||
.aput_raw::<BUFSIZE, _, _>(short, event_id);
|
.aput_raw::<BUFSIZE, _, _>(*short, event_id);
|
||||||
|
|
||||||
short
|
*short
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
@@ -120,18 +120,19 @@ pub async fn get_or_create_shortstatekey(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let key = (event_type, state_key);
|
let key = (event_type, state_key);
|
||||||
let shortstatekey = self.services.globals.next_count().unwrap();
|
let shortstatekey = self.services.globals.next_count();
|
||||||
debug_assert!(size_of_val(&shortstatekey) == BUFSIZE, "buffer requirement changed");
|
|
||||||
|
debug_assert!(size_of_val(&*shortstatekey) == BUFSIZE, "buffer requirement changed");
|
||||||
|
|
||||||
self.db
|
self.db
|
||||||
.statekey_shortstatekey
|
.statekey_shortstatekey
|
||||||
.put_aput::<BUFSIZE, _, _>(key, shortstatekey);
|
.put_aput::<BUFSIZE, _, _>(key, *shortstatekey);
|
||||||
|
|
||||||
self.db
|
self.db
|
||||||
.shortstatekey_statekey
|
.shortstatekey_statekey
|
||||||
.aput_put::<BUFSIZE, _, _>(shortstatekey, key);
|
.aput_put::<BUFSIZE, _, _>(*shortstatekey, key);
|
||||||
|
|
||||||
shortstatekey
|
*shortstatekey
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
@@ -226,14 +227,14 @@ pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (ShortSta
|
|||||||
return (shortstatehash, true);
|
return (shortstatehash, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
let shortstatehash = self.services.globals.next_count().unwrap();
|
let shortstatehash = self.services.globals.next_count();
|
||||||
debug_assert!(size_of_val(&shortstatehash) == BUFSIZE, "buffer requirement changed");
|
debug_assert!(size_of_val(&*shortstatehash) == BUFSIZE, "buffer requirement changed");
|
||||||
|
|
||||||
self.db
|
self.db
|
||||||
.statehash_shortstatehash
|
.statehash_shortstatehash
|
||||||
.raw_aput::<BUFSIZE, _, _>(state_hash, shortstatehash);
|
.raw_aput::<BUFSIZE, _, _>(state_hash, *shortstatehash);
|
||||||
|
|
||||||
(shortstatehash, false)
|
(*shortstatehash, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
@@ -255,13 +256,13 @@ pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> ShortRoomId {
|
|||||||
.unwrap_or_else(|_| {
|
.unwrap_or_else(|_| {
|
||||||
const BUFSIZE: usize = size_of::<ShortRoomId>();
|
const BUFSIZE: usize = size_of::<ShortRoomId>();
|
||||||
|
|
||||||
let short = self.services.globals.next_count().unwrap();
|
let short = self.services.globals.next_count();
|
||||||
debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed");
|
debug_assert!(size_of_val(&*short) == BUFSIZE, "buffer requirement changed");
|
||||||
|
|
||||||
self.db
|
self.db
|
||||||
.roomid_shortroomid
|
.roomid_shortroomid
|
||||||
.raw_aput::<BUFSIZE, _, _>(room_id, short);
|
.raw_aput::<BUFSIZE, _, _>(room_id, *short);
|
||||||
|
|
||||||
short
|
*short
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -305,7 +305,7 @@ impl Service {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: statehash with deterministic inputs
|
// TODO: statehash with deterministic inputs
|
||||||
let shortstatehash = self.services.globals.next_count()?;
|
let shortstatehash = self.services.globals.next_count();
|
||||||
|
|
||||||
let mut statediffnew = CompressedState::new();
|
let mut statediffnew = CompressedState::new();
|
||||||
statediffnew.insert(new);
|
statediffnew.insert(new);
|
||||||
@@ -318,14 +318,14 @@ impl Service {
|
|||||||
self.services
|
self.services
|
||||||
.state_compressor
|
.state_compressor
|
||||||
.save_state_from_diff(
|
.save_state_from_diff(
|
||||||
shortstatehash,
|
*shortstatehash,
|
||||||
Arc::new(statediffnew),
|
Arc::new(statediffnew),
|
||||||
Arc::new(statediffremoved),
|
Arc::new(statediffremoved),
|
||||||
2,
|
2,
|
||||||
states_parents,
|
states_parents,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(shortstatehash)
|
Ok(*shortstatehash)
|
||||||
},
|
},
|
||||||
| _ =>
|
| _ =>
|
||||||
Ok(previous_shortstatehash.expect("first event in room must be a state event")),
|
Ok(previous_shortstatehash.expect("first event in room must be a state event")),
|
||||||
|
|||||||
@@ -272,6 +272,8 @@ pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) {
|
|||||||
#[implement(super::Service)]
|
#[implement(super::Service)]
|
||||||
#[tracing::instrument(skip(self), level = "debug")]
|
#[tracing::instrument(skip(self), level = "debug")]
|
||||||
pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
|
pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
let userroom_id = (user_id, room_id);
|
let userroom_id = (user_id, room_id);
|
||||||
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
|
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
|
||||||
|
|
||||||
@@ -286,7 +288,7 @@ pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
|
|||||||
.raw_put(&userroom_id, Json(leftstate));
|
.raw_put(&userroom_id, Json(leftstate));
|
||||||
self.db
|
self.db
|
||||||
.roomuserid_leftcount
|
.roomuserid_leftcount
|
||||||
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
|
.raw_aput::<8, _, _>(&roomuser_id, *count);
|
||||||
|
|
||||||
self.db.userroomid_joined.remove(&userroom_id);
|
self.db.userroomid_joined.remove(&userroom_id);
|
||||||
self.db.roomuserid_joined.remove(&roomuser_id);
|
self.db.roomuserid_joined.remove(&roomuser_id);
|
||||||
@@ -319,6 +321,8 @@ pub fn mark_as_knocked(
|
|||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||||
) {
|
) {
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
let userroom_id = (user_id, room_id);
|
let userroom_id = (user_id, room_id);
|
||||||
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
|
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
|
||||||
|
|
||||||
@@ -330,7 +334,7 @@ pub fn mark_as_knocked(
|
|||||||
.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
|
.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
|
||||||
self.db
|
self.db
|
||||||
.roomuserid_knockedcount
|
.roomuserid_knockedcount
|
||||||
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
|
.raw_aput::<8, _, _>(&roomuser_id, *count);
|
||||||
|
|
||||||
self.db.userroomid_joined.remove(&userroom_id);
|
self.db.userroomid_joined.remove(&userroom_id);
|
||||||
self.db.roomuserid_joined.remove(&roomuser_id);
|
self.db.roomuserid_joined.remove(&roomuser_id);
|
||||||
@@ -375,6 +379,8 @@ pub async fn mark_as_invited(
|
|||||||
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||||
invite_via: Option<Vec<OwnedServerName>>,
|
invite_via: Option<Vec<OwnedServerName>>,
|
||||||
) {
|
) {
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
let roomuser_id = (room_id, user_id);
|
let roomuser_id = (room_id, user_id);
|
||||||
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
|
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
|
||||||
|
|
||||||
@@ -386,7 +392,7 @@ pub async fn mark_as_invited(
|
|||||||
.raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
|
.raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
|
||||||
self.db
|
self.db
|
||||||
.roomuserid_invitecount
|
.roomuserid_invitecount
|
||||||
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap());
|
.raw_aput::<8, _, _>(&roomuser_id, *count);
|
||||||
|
|
||||||
self.db.userroomid_joined.remove(&userroom_id);
|
self.db.userroomid_joined.remove(&userroom_id);
|
||||||
self.db.roomuserid_joined.remove(&roomuser_id);
|
self.db.roomuserid_joined.remove(&roomuser_id);
|
||||||
|
|||||||
@@ -159,20 +159,20 @@ where
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
let insert_lock = self.mutex_insert.lock(pdu.room_id()).await;
|
let insert_lock = self.mutex_insert.lock(pdu.room_id()).await;
|
||||||
|
let count1 = self.services.globals.next_count();
|
||||||
let count1 = self.services.globals.next_count().unwrap();
|
let count2 = self.services.globals.next_count();
|
||||||
|
|
||||||
// Mark as read first so the sending client doesn't get a notification even if
|
// Mark as read first so the sending client doesn't get a notification even if
|
||||||
// appending fails
|
// appending fails
|
||||||
self.services
|
self.services
|
||||||
.read_receipt
|
.read_receipt
|
||||||
.private_read_set(pdu.room_id(), pdu.sender(), count1);
|
.private_read_set(pdu.room_id(), pdu.sender(), *count1);
|
||||||
|
|
||||||
self.services
|
self.services
|
||||||
.user
|
.user
|
||||||
.reset_notification_counts(pdu.sender(), pdu.room_id());
|
.reset_notification_counts(pdu.sender(), pdu.room_id());
|
||||||
|
|
||||||
let count2 = PduCount::Normal(self.services.globals.next_count().unwrap());
|
let count2 = PduCount::Normal(*count2);
|
||||||
let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into();
|
let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into();
|
||||||
|
|
||||||
// Insert pdu
|
// Insert pdu
|
||||||
|
|||||||
@@ -181,13 +181,9 @@ pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) ->
|
|||||||
|
|
||||||
let insert_lock = self.mutex_insert.lock(&room_id).await;
|
let insert_lock = self.mutex_insert.lock(&room_id).await;
|
||||||
|
|
||||||
let count: i64 = self
|
let count = self.services.globals.next_count();
|
||||||
.services
|
|
||||||
.globals
|
|
||||||
.next_count()
|
|
||||||
.unwrap()
|
|
||||||
.try_into()?;
|
|
||||||
|
|
||||||
|
let count: i64 = (*count).try_into()?;
|
||||||
let pdu_id: RawPduId = PduId {
|
let pdu_id: RawPduId = PduId {
|
||||||
shortroomid,
|
shortroomid,
|
||||||
shorteventid: PduCount::Backfilled(validated!(0 - count)),
|
shorteventid: PduCount::Backfilled(validated!(0 - count)),
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ impl Service {
|
|||||||
/// roomtyping_remove is called.
|
/// roomtyping_remove is called.
|
||||||
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result {
|
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result {
|
||||||
debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}");
|
debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}");
|
||||||
|
|
||||||
// update clients
|
// update clients
|
||||||
self.typing
|
self.typing
|
||||||
.write()
|
.write()
|
||||||
@@ -61,10 +62,11 @@ impl Service {
|
|||||||
.or_default()
|
.or_default()
|
||||||
.insert(user_id.to_owned(), timeout);
|
.insert(user_id.to_owned(), timeout);
|
||||||
|
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
self.last_typing_update
|
self.last_typing_update
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert(room_id.to_owned(), self.services.globals.next_count()?);
|
.insert(room_id.to_owned(), *count);
|
||||||
|
|
||||||
if self
|
if self
|
||||||
.typing_update_sender
|
.typing_update_sender
|
||||||
@@ -86,6 +88,7 @@ impl Service {
|
|||||||
/// Removes a user from typing before the timeout is reached.
|
/// Removes a user from typing before the timeout is reached.
|
||||||
pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result {
|
pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result {
|
||||||
debug_info!("typing stopped {user_id:?} in {room_id:?}");
|
debug_info!("typing stopped {user_id:?} in {room_id:?}");
|
||||||
|
|
||||||
// update clients
|
// update clients
|
||||||
self.typing
|
self.typing
|
||||||
.write()
|
.write()
|
||||||
@@ -94,10 +97,11 @@ impl Service {
|
|||||||
.or_default()
|
.or_default()
|
||||||
.remove(user_id);
|
.remove(user_id);
|
||||||
|
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
self.last_typing_update
|
self.last_typing_update
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert(room_id.to_owned(), self.services.globals.next_count()?);
|
.insert(room_id.to_owned(), *count);
|
||||||
|
|
||||||
if self
|
if self
|
||||||
.typing_update_sender
|
.typing_update_sender
|
||||||
@@ -146,16 +150,18 @@ impl Service {
|
|||||||
if !removable.is_empty() {
|
if !removable.is_empty() {
|
||||||
let typing = &mut self.typing.write().await;
|
let typing = &mut self.typing.write().await;
|
||||||
let room = typing.entry(room_id.to_owned()).or_default();
|
let room = typing.entry(room_id.to_owned()).or_default();
|
||||||
|
|
||||||
for user in &removable {
|
for user in &removable {
|
||||||
debug_info!("typing timeout {user:?} in {room_id:?}");
|
debug_info!("typing timeout {user:?} in {room_id:?}");
|
||||||
room.remove(user);
|
room.remove(user);
|
||||||
}
|
}
|
||||||
|
|
||||||
// update clients
|
// update clients
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
self.last_typing_update
|
self.last_typing_update
|
||||||
.write()
|
.write()
|
||||||
.await
|
.await
|
||||||
.insert(room_id.to_owned(), self.services.globals.next_count()?);
|
.insert(room_id.to_owned(), *count);
|
||||||
|
|
||||||
if self
|
if self
|
||||||
.typing_update_sender
|
.typing_update_sender
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ impl crate::Service for Service {
|
|||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
|
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
let userroom_id = (user_id, room_id);
|
let userroom_id = (user_id, room_id);
|
||||||
self.db
|
self.db
|
||||||
.userroomid_highlightcount
|
.userroomid_highlightcount
|
||||||
@@ -56,10 +58,9 @@ pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
|
|||||||
.put(userroom_id, 0_u64);
|
.put(userroom_id, 0_u64);
|
||||||
|
|
||||||
let roomuser_id = (room_id, user_id);
|
let roomuser_id = (room_id, user_id);
|
||||||
let count = self.services.globals.next_count().unwrap();
|
|
||||||
self.db
|
self.db
|
||||||
.roomuserid_lastnotificationread
|
.roomuserid_lastnotificationread
|
||||||
.put(roomuser_id, count);
|
.put(roomuser_id, *count);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[implement(Service)]
|
#[implement(Service)]
|
||||||
|
|||||||
@@ -126,8 +126,9 @@ impl Data {
|
|||||||
if let SendingEvent::Pdu(value) = event {
|
if let SendingEvent::Pdu(value) = event {
|
||||||
key.extend(value.as_ref());
|
key.extend(value.as_ref());
|
||||||
} else {
|
} else {
|
||||||
let count = self.services.globals.next_count().unwrap();
|
let count = self.services.globals.next_count();
|
||||||
key.extend(&count.to_be_bytes());
|
let count = count.to_be_bytes();
|
||||||
|
key.extend(&count);
|
||||||
}
|
}
|
||||||
|
|
||||||
key
|
key
|
||||||
|
|||||||
@@ -378,7 +378,7 @@ impl Service {
|
|||||||
async fn select_edus(&self, server_name: &ServerName) -> Result<(EduVec, u64)> {
|
async fn select_edus(&self, server_name: &ServerName) -> Result<(EduVec, u64)> {
|
||||||
// selection window
|
// selection window
|
||||||
let since = self.db.get_latest_educount(server_name).await;
|
let since = self.db.get_latest_educount(server_name).await;
|
||||||
let since_upper = self.services.globals.current_count()?;
|
let since_upper = self.services.globals.current_count();
|
||||||
let batch = (since, since_upper);
|
let batch = (since, since_upper);
|
||||||
debug_assert!(batch.0 <= batch.1, "since range must not be negative");
|
debug_assert!(batch.0 <= batch.1, "since range must not be negative");
|
||||||
|
|
||||||
|
|||||||
@@ -154,9 +154,9 @@ pub async fn add_to_device_event(
|
|||||||
event_type: &str,
|
event_type: &str,
|
||||||
content: serde_json::Value,
|
content: serde_json::Value,
|
||||||
) {
|
) {
|
||||||
let count = self.services.globals.next_count().unwrap();
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
let key = (target_user_id, target_device_id, count);
|
let key = (target_user_id, target_device_id, *count);
|
||||||
self.db.todeviceid_events.put(
|
self.db.todeviceid_events.put(
|
||||||
key,
|
key,
|
||||||
Json(json!({
|
Json(json!({
|
||||||
|
|||||||
@@ -52,14 +52,14 @@ pub async fn add_one_time_key(
|
|||||||
.as_bytes(),
|
.as_bytes(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
self.db
|
self.db
|
||||||
.onetimekeyid_onetimekeys
|
.onetimekeyid_onetimekeys
|
||||||
.raw_put(key, Json(one_time_key_value));
|
.raw_put(key, Json(one_time_key_value));
|
||||||
|
|
||||||
let count = self.services.globals.next_count().unwrap();
|
|
||||||
self.db
|
self.db
|
||||||
.userid_lastonetimekeyupdate
|
.userid_lastonetimekeyupdate
|
||||||
.raw_put(user_id, count);
|
.raw_put(user_id, *count);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -81,10 +81,10 @@ pub async fn take_one_time_key(
|
|||||||
device_id: &DeviceId,
|
device_id: &DeviceId,
|
||||||
key_algorithm: &OneTimeKeyAlgorithm,
|
key_algorithm: &OneTimeKeyAlgorithm,
|
||||||
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
|
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
|
||||||
let count = self.services.globals.next_count()?.to_be_bytes();
|
let count = self.services.globals.next_count();
|
||||||
self.db
|
self.db
|
||||||
.userid_lastonetimekeyupdate
|
.userid_lastonetimekeyupdate
|
||||||
.insert(user_id, count);
|
.insert(user_id, count.to_be_bytes());
|
||||||
|
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xFF);
|
prefix.push(0xFF);
|
||||||
@@ -340,7 +340,7 @@ fn keys_changed_user_or_room<'a>(
|
|||||||
|
|
||||||
#[implement(super::Service)]
|
#[implement(super::Service)]
|
||||||
pub async fn mark_device_key_update(&self, user_id: &UserId) {
|
pub async fn mark_device_key_update(&self, user_id: &UserId) {
|
||||||
let count = self.services.globals.next_count().unwrap();
|
let count = self.services.globals.next_count();
|
||||||
|
|
||||||
self.services
|
self.services
|
||||||
.state_cache
|
.state_cache
|
||||||
@@ -348,12 +348,12 @@ pub async fn mark_device_key_update(&self, user_id: &UserId) {
|
|||||||
// Don't send key updates to unencrypted rooms
|
// Don't send key updates to unencrypted rooms
|
||||||
.filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id))
|
.filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id))
|
||||||
.ready_for_each(|room_id| {
|
.ready_for_each(|room_id| {
|
||||||
let key = (room_id, count);
|
let key = (room_id, *count);
|
||||||
self.db.keychangeid_userid.put_raw(key, user_id);
|
self.db.keychangeid_userid.put_raw(key, user_id);
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let key = (user_id, count);
|
let key = (user_id, *count);
|
||||||
self.db.keychangeid_userid.put_raw(key, user_id);
|
self.db.keychangeid_userid.put_raw(key, user_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user