Refactor counter increment sites for TwoPhaseCounter.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-07-24 00:15:33 +00:00
parent 05bb1f4ac7
commit 0fcb072239
21 changed files with 129 additions and 117 deletions

View File

@@ -160,7 +160,7 @@ pub(crate) async fn build_sync_events(
) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> { ) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> {
let (sender_user, sender_device) = body.sender(); let (sender_user, sender_device) = body.sender();
let next_batch = services.globals.current_count()?; let next_batch = services.globals.current_count();
let since = body let since = body
.body .body
.since .since

View File

@@ -74,7 +74,7 @@ pub(crate) async fn sync_events_v5_route(
// Setup watchers, so if there's no response, we can wait for them // Setup watchers, so if there's no response, we can wait for them
let watcher = services.sync.watch(sender_user, sender_device); let watcher = services.sync.watch(sender_user, sender_device);
let next_batch = services.globals.next_count()?; let next_batch = services.globals.next_count();
let conn_id = body.conn_id.clone(); let conn_id = body.conn_id.clone();
@@ -136,7 +136,7 @@ pub(crate) async fn sync_events_v5_route(
.chain(all_invited_rooms.clone()) .chain(all_invited_rooms.clone())
.chain(all_knocked_rooms.clone()); .chain(all_knocked_rooms.clone());
let pos = next_batch.clone().to_string(); let pos = next_batch.to_string();
let mut todo_rooms: TodoRooms = BTreeMap::new(); let mut todo_rooms: TodoRooms = BTreeMap::new();
@@ -146,7 +146,7 @@ pub(crate) async fn sync_events_v5_route(
let e2ee = collect_e2ee(services, sync_info, all_joined_rooms.clone()); let e2ee = collect_e2ee(services, sync_info, all_joined_rooms.clone());
let to_device = collect_to_device(services, sync_info, next_batch).map(Ok); let to_device = collect_to_device(services, sync_info, *next_batch).map(Ok);
let receipts = collect_receipts(services).map(Ok); let receipts = collect_receipts(services).map(Ok);
@@ -189,7 +189,7 @@ pub(crate) async fn sync_events_v5_route(
response.rooms = process_rooms( response.rooms = process_rooms(
services, services,
sender_user, sender_user,
next_batch, *next_batch,
all_invited_rooms.clone(), all_invited_rooms.clone(),
&todo_rooms, &todo_rooms,
&mut response, &mut response,

View File

@@ -41,7 +41,6 @@ pub(crate) async fn send_event_to_device_route(
map.insert(target_device_id_maybe.clone(), event.clone()); map.insert(target_device_id_maybe.clone(), event.clone());
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages.insert(target_user_id.clone(), map); messages.insert(target_user_id.clone(), map);
let count = services.globals.next_count()?;
let mut buf = EduBuf::new(); let mut buf = EduBuf::new();
serde_json::to_writer( serde_json::to_writer(
@@ -49,7 +48,7 @@ pub(crate) async fn send_event_to_device_route(
&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { &federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
sender: sender_user.to_owned(), sender: sender_user.to_owned(),
ev_type: body.event_type.clone(), ev_type: body.event_type.clone(),
message_id: count.to_string().into(), message_id: services.globals.next_count().to_string().into(),
messages, messages,
}), }),
) )

View File

@@ -63,8 +63,8 @@ pub async fn update(
return Err!(Request(InvalidParam("Account data doesn't have all required fields."))); return Err!(Request(InvalidParam("Account data doesn't have all required fields.")));
} }
let count = self.services.globals.next_count().unwrap(); let count = self.services.globals.next_count();
let roomuserdataid = (room_id, user_id, count, &event_type); let roomuserdataid = (room_id, user_id, *count, &event_type);
self.db self.db
.roomuserdataid_accountdata .roomuserdataid_accountdata
.put(roomuserdataid, Json(data)); .put(roomuserdataid, Json(data));

View File

@@ -1,56 +1,51 @@
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use tuwunel_core::{Result, utils}; use tuwunel_core::{
Result, utils,
utils::two_phase_counter::{Counter as TwoPhaseCounter, Permit as TwoPhasePermit},
};
use tuwunel_database::{Database, Deserialized, Map}; use tuwunel_database::{Database, Deserialized, Map};
pub struct Data { pub struct Data {
global: Arc<Map>, global: Arc<Map>,
counter: RwLock<u64>, counter: Arc<Counter>,
pub(super) db: Arc<Database>, pub(super) db: Arc<Database>,
} }
pub(super) type Permit = TwoPhasePermit<Callback>;
type Counter = TwoPhaseCounter<Callback>;
type Callback = Box<dyn Fn(u64) -> Result + Send + Sync>;
const COUNTER: &[u8] = b"c"; const COUNTER: &[u8] = b"c";
impl Data { impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self { pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db; let db = args.db.clone();
Self { Self {
global: db["global"].clone(),
counter: RwLock::new(
Self::stored_count(&db["global"]).expect("initialized global counter"),
),
db: args.db.clone(), db: args.db.clone(),
global: args.db["global"].clone(),
counter: Counter::new(
Self::stored_count(&args.db["global"]).expect("initialized global counter"),
Box::new(move |count| Self::store_count(&db, &db["global"], count)),
),
} }
} }
pub fn next_count(&self) -> Result<u64> { #[inline]
let _cork = self.db.cork(); pub fn next_count(&self) -> Permit {
let mut lock = self.counter.write().expect("locked"); self.counter
let counter: &mut u64 = &mut lock; .next()
debug_assert!( .expect("failed to obtain next sequence number")
*counter == Self::stored_count(&self.global).expect("database failure"),
"counter mismatch"
);
*counter = counter
.checked_add(1)
.expect("counter must not overflow u64");
self.global.insert(COUNTER, counter.to_be_bytes());
Ok(*counter)
} }
#[inline] #[inline]
pub fn current_count(&self) -> u64 { pub fn current_count(&self) -> u64 { self.counter.current() }
let lock = self.counter.read().expect("locked");
let counter: &u64 = &lock;
debug_assert!(
*counter == Self::stored_count(&self.global).expect("database failure"),
"counter mismatch"
);
*counter fn store_count(db: &Arc<Database>, global: &Arc<Map>, count: u64) -> Result {
let _cork = db.cork();
global.insert(COUNTER, count.to_be_bytes());
Ok(())
} }
fn stored_count(global: &Arc<Map>) -> Result<u64> { fn stored_count(global: &Arc<Map>) -> Result<u64> {
@@ -59,6 +54,12 @@ impl Data {
.as_deref() .as_deref()
.map_or(Ok(0_u64), utils::u64_from_bytes) .map_or(Ok(0_u64), utils::u64_from_bytes)
} }
}
impl Data {
pub fn bump_database_version(&self, new_version: u64) {
self.global.raw_put(b"version", new_version);
}
pub async fn database_version(&self) -> u64 { pub async fn database_version(&self) -> u64 {
self.global self.global
@@ -67,9 +68,4 @@ impl Data {
.deserialized() .deserialized()
.unwrap_or(0) .unwrap_or(0)
} }
#[inline]
pub fn bump_database_version(&self, new_version: u64) {
self.global.raw_put(b"version", new_version);
}
} }

View File

@@ -103,10 +103,12 @@ impl crate::Service for Service {
impl Service { impl Service {
#[inline] #[inline]
pub fn next_count(&self) -> Result<u64> { self.db.next_count() } #[must_use]
pub fn next_count(&self) -> data::Permit { self.db.next_count() }
#[inline] #[inline]
pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) } #[must_use]
pub fn current_count(&self) -> u64 { self.db.current_count() }
#[inline] #[inline]
#[must_use] #[must_use]

View File

@@ -52,17 +52,18 @@ pub fn create_backup(
user_id: &UserId, user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> { ) -> Result<String> {
let version = self.services.globals.next_count()?.to_string(); let version = self.services.globals.next_count();
let count = self.services.globals.next_count()?; let count = self.services.globals.next_count();
let key = (user_id, &version); let version_string = version.to_string();
let key = (user_id, &version_string);
self.db self.db
.backupid_algorithm .backupid_algorithm
.put(key, Json(backup_metadata)); .put(key, Json(backup_metadata));
self.db.backupid_etag.put(key, count); self.db.backupid_etag.put(key, *count);
Ok(version) Ok(version_string)
} }
#[implement(Service)] #[implement(Service)]
@@ -100,8 +101,8 @@ pub async fn update_backup<'a>(
return Err!(Request(NotFound("Tried to update nonexistent backup."))); return Err!(Request(NotFound("Tried to update nonexistent backup.")));
} }
let count = self.services.globals.next_count().unwrap(); let count = self.services.globals.next_count();
self.db.backupid_etag.put(key, count); self.db.backupid_etag.put(key, *count);
self.db self.db
.backupid_algorithm .backupid_algorithm
.put_raw(key, backup_metadata.json().get()); .put_raw(key, backup_metadata.json().get());
@@ -175,8 +176,8 @@ pub async fn add_key(
return Err!(Request(NotFound("Tried to update nonexistent backup."))); return Err!(Request(NotFound("Tried to update nonexistent backup.")));
} }
let count = self.services.globals.next_count().unwrap(); let count = self.services.globals.next_count();
self.db.backupid_etag.put(key, count); self.db.backupid_etag.put(key, *count);
let key = (user_id, version, room_id, session_id); let key = (user_id, version, room_id, session_id);
self.db self.db

View File

@@ -120,12 +120,12 @@ impl Data {
status_msg, status_msg,
); );
let count = self.services.globals.next_count()?; let count = self.services.globals.next_count();
let key = presenceid_key(count, user_id); let key = presenceid_key(*count, user_id);
self.userid_presenceid.raw_put(user_id, *count);
self.presenceid_presence self.presenceid_presence
.raw_put(key, Json(presence)); .raw_put(key, Json(presence));
self.userid_presenceid.raw_put(user_id, count);
if let Ok((last_count, _)) = last_presence { if let Ok((last_count, _)) = last_presence {
let key = presenceid_key(last_count, user_id); let key = presenceid_key(last_count, user_id);

View File

@@ -71,6 +71,8 @@ impl Service {
return Err!(Request(Forbidden("Only the server user can set this alias"))); return Err!(Request(Forbidden("Only the server user can set this alias")));
} }
let count = self.services.globals.next_count();
// Comes first as we don't want a stuck alias // Comes first as we don't want a stuck alias
self.db self.db
.alias_userid .alias_userid
@@ -82,7 +84,8 @@ impl Service {
let mut aliasid = room_id.as_bytes().to_vec(); let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF); aliasid.push(0xFF);
aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); aliasid.extend_from_slice(&count.to_be_bytes());
self.db self.db
.aliasid_alias .aliasid_alias
.insert(&aliasid, alias.as_bytes()); .insert(&aliasid, alias.as_bytes());

View File

@@ -56,8 +56,8 @@ impl Data {
.ready_for_each(|key| self.readreceiptid_readreceipt.del(key)) .ready_for_each(|key| self.readreceiptid_readreceipt.del(key))
.await; .await;
let count = self.services.globals.next_count().unwrap(); let count = self.services.globals.next_count();
let latest_id = (room_id, count, user_id); let latest_id = (room_id, *count, user_id);
self.readreceiptid_readreceipt self.readreceiptid_readreceipt
.put(latest_id, Json(event)); .put(latest_id, Json(event));
} }
@@ -89,11 +89,11 @@ impl Data {
pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, pdu_count: u64) { pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, pdu_count: u64) {
let key = (room_id, user_id); let key = (room_id, user_id);
let next_count = self.services.globals.next_count().unwrap(); let next_count = self.services.globals.next_count();
self.roomuserid_privateread.put(key, pdu_count); self.roomuserid_privateread.put(key, pdu_count);
self.roomuserid_lastprivatereadupdate self.roomuserid_lastprivatereadupdate
.put(key, next_count); .put(key, *next_count);
} }
pub(super) async fn private_read_get_count( pub(super) async fn private_read_get_count(

View File

@@ -81,18 +81,18 @@ where
fn create_shorteventid(&self, event_id: &EventId) -> ShortEventId { fn create_shorteventid(&self, event_id: &EventId) -> ShortEventId {
const BUFSIZE: usize = size_of::<ShortEventId>(); const BUFSIZE: usize = size_of::<ShortEventId>();
let short = self.services.globals.next_count().unwrap(); let short = self.services.globals.next_count();
debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); debug_assert!(size_of_val(&*short) == BUFSIZE, "buffer requirement changed");
self.db self.db
.eventid_shorteventid .eventid_shorteventid
.raw_aput::<BUFSIZE, _, _>(event_id, short); .raw_aput::<BUFSIZE, _, _>(event_id, *short);
self.db self.db
.shorteventid_eventid .shorteventid_eventid
.aput_raw::<BUFSIZE, _, _>(short, event_id); .aput_raw::<BUFSIZE, _, _>(*short, event_id);
short *short
} }
#[implement(Service)] #[implement(Service)]
@@ -120,18 +120,19 @@ pub async fn get_or_create_shortstatekey(
} }
let key = (event_type, state_key); let key = (event_type, state_key);
let shortstatekey = self.services.globals.next_count().unwrap(); let shortstatekey = self.services.globals.next_count();
debug_assert!(size_of_val(&shortstatekey) == BUFSIZE, "buffer requirement changed");
debug_assert!(size_of_val(&*shortstatekey) == BUFSIZE, "buffer requirement changed");
self.db self.db
.statekey_shortstatekey .statekey_shortstatekey
.put_aput::<BUFSIZE, _, _>(key, shortstatekey); .put_aput::<BUFSIZE, _, _>(key, *shortstatekey);
self.db self.db
.shortstatekey_statekey .shortstatekey_statekey
.aput_put::<BUFSIZE, _, _>(shortstatekey, key); .aput_put::<BUFSIZE, _, _>(*shortstatekey, key);
shortstatekey *shortstatekey
} }
#[implement(Service)] #[implement(Service)]
@@ -226,14 +227,14 @@ pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (ShortSta
return (shortstatehash, true); return (shortstatehash, true);
} }
let shortstatehash = self.services.globals.next_count().unwrap(); let shortstatehash = self.services.globals.next_count();
debug_assert!(size_of_val(&shortstatehash) == BUFSIZE, "buffer requirement changed"); debug_assert!(size_of_val(&*shortstatehash) == BUFSIZE, "buffer requirement changed");
self.db self.db
.statehash_shortstatehash .statehash_shortstatehash
.raw_aput::<BUFSIZE, _, _>(state_hash, shortstatehash); .raw_aput::<BUFSIZE, _, _>(state_hash, *shortstatehash);
(shortstatehash, false) (*shortstatehash, false)
} }
#[implement(Service)] #[implement(Service)]
@@ -255,13 +256,13 @@ pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> ShortRoomId {
.unwrap_or_else(|_| { .unwrap_or_else(|_| {
const BUFSIZE: usize = size_of::<ShortRoomId>(); const BUFSIZE: usize = size_of::<ShortRoomId>();
let short = self.services.globals.next_count().unwrap(); let short = self.services.globals.next_count();
debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); debug_assert!(size_of_val(&*short) == BUFSIZE, "buffer requirement changed");
self.db self.db
.roomid_shortroomid .roomid_shortroomid
.raw_aput::<BUFSIZE, _, _>(room_id, short); .raw_aput::<BUFSIZE, _, _>(room_id, *short);
short *short
}) })
} }

View File

@@ -305,7 +305,7 @@ impl Service {
} }
// TODO: statehash with deterministic inputs // TODO: statehash with deterministic inputs
let shortstatehash = self.services.globals.next_count()?; let shortstatehash = self.services.globals.next_count();
let mut statediffnew = CompressedState::new(); let mut statediffnew = CompressedState::new();
statediffnew.insert(new); statediffnew.insert(new);
@@ -318,14 +318,14 @@ impl Service {
self.services self.services
.state_compressor .state_compressor
.save_state_from_diff( .save_state_from_diff(
shortstatehash, *shortstatehash,
Arc::new(statediffnew), Arc::new(statediffnew),
Arc::new(statediffremoved), Arc::new(statediffremoved),
2, 2,
states_parents, states_parents,
)?; )?;
Ok(shortstatehash) Ok(*shortstatehash)
}, },
| _ => | _ =>
Ok(previous_shortstatehash.expect("first event in room must be a state event")), Ok(previous_shortstatehash.expect("first event in room must be a state event")),

View File

@@ -272,6 +272,8 @@ pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) {
#[implement(super::Service)] #[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
let count = self.services.globals.next_count();
let userroom_id = (user_id, room_id); let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
@@ -286,7 +288,7 @@ pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
.raw_put(&userroom_id, Json(leftstate)); .raw_put(&userroom_id, Json(leftstate));
self.db self.db
.roomuserid_leftcount .roomuserid_leftcount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); .raw_aput::<8, _, _>(&roomuser_id, *count);
self.db.userroomid_joined.remove(&userroom_id); self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id); self.db.roomuserid_joined.remove(&roomuser_id);
@@ -319,6 +321,8 @@ pub fn mark_as_knocked(
room_id: &RoomId, room_id: &RoomId,
knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
) { ) {
let count = self.services.globals.next_count();
let userroom_id = (user_id, room_id); let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
@@ -330,7 +334,7 @@ pub fn mark_as_knocked(
.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default())); .raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
self.db self.db
.roomuserid_knockedcount .roomuserid_knockedcount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); .raw_aput::<8, _, _>(&roomuser_id, *count);
self.db.userroomid_joined.remove(&userroom_id); self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id); self.db.roomuserid_joined.remove(&roomuser_id);
@@ -375,6 +379,8 @@ pub async fn mark_as_invited(
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>, invite_via: Option<Vec<OwnedServerName>>,
) { ) {
let count = self.services.globals.next_count();
let roomuser_id = (room_id, user_id); let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
@@ -386,7 +392,7 @@ pub async fn mark_as_invited(
.raw_put(&userroom_id, Json(last_state.unwrap_or_default())); .raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
self.db self.db
.roomuserid_invitecount .roomuserid_invitecount
.raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); .raw_aput::<8, _, _>(&roomuser_id, *count);
self.db.userroomid_joined.remove(&userroom_id); self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id); self.db.roomuserid_joined.remove(&roomuser_id);

View File

@@ -159,20 +159,20 @@ where
.await; .await;
let insert_lock = self.mutex_insert.lock(pdu.room_id()).await; let insert_lock = self.mutex_insert.lock(pdu.room_id()).await;
let count1 = self.services.globals.next_count();
let count1 = self.services.globals.next_count().unwrap(); let count2 = self.services.globals.next_count();
// Mark as read first so the sending client doesn't get a notification even if // Mark as read first so the sending client doesn't get a notification even if
// appending fails // appending fails
self.services self.services
.read_receipt .read_receipt
.private_read_set(pdu.room_id(), pdu.sender(), count1); .private_read_set(pdu.room_id(), pdu.sender(), *count1);
self.services self.services
.user .user
.reset_notification_counts(pdu.sender(), pdu.room_id()); .reset_notification_counts(pdu.sender(), pdu.room_id());
let count2 = PduCount::Normal(self.services.globals.next_count().unwrap()); let count2 = PduCount::Normal(*count2);
let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into(); let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into();
// Insert pdu // Insert pdu

View File

@@ -181,13 +181,9 @@ pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box<RawJsonValue>) ->
let insert_lock = self.mutex_insert.lock(&room_id).await; let insert_lock = self.mutex_insert.lock(&room_id).await;
let count: i64 = self let count = self.services.globals.next_count();
.services
.globals
.next_count()
.unwrap()
.try_into()?;
let count: i64 = (*count).try_into()?;
let pdu_id: RawPduId = PduId { let pdu_id: RawPduId = PduId {
shortroomid, shortroomid,
shorteventid: PduCount::Backfilled(validated!(0 - count)), shorteventid: PduCount::Backfilled(validated!(0 - count)),

View File

@@ -53,6 +53,7 @@ impl Service {
/// roomtyping_remove is called. /// roomtyping_remove is called.
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result { pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result {
debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}"); debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}");
// update clients // update clients
self.typing self.typing
.write() .write()
@@ -61,10 +62,11 @@ impl Service {
.or_default() .or_default()
.insert(user_id.to_owned(), timeout); .insert(user_id.to_owned(), timeout);
let count = self.services.globals.next_count();
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), self.services.globals.next_count()?); .insert(room_id.to_owned(), *count);
if self if self
.typing_update_sender .typing_update_sender
@@ -86,6 +88,7 @@ impl Service {
/// Removes a user from typing before the timeout is reached. /// Removes a user from typing before the timeout is reached.
pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result { pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result {
debug_info!("typing stopped {user_id:?} in {room_id:?}"); debug_info!("typing stopped {user_id:?} in {room_id:?}");
// update clients // update clients
self.typing self.typing
.write() .write()
@@ -94,10 +97,11 @@ impl Service {
.or_default() .or_default()
.remove(user_id); .remove(user_id);
let count = self.services.globals.next_count();
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), self.services.globals.next_count()?); .insert(room_id.to_owned(), *count);
if self if self
.typing_update_sender .typing_update_sender
@@ -146,16 +150,18 @@ impl Service {
if !removable.is_empty() { if !removable.is_empty() {
let typing = &mut self.typing.write().await; let typing = &mut self.typing.write().await;
let room = typing.entry(room_id.to_owned()).or_default(); let room = typing.entry(room_id.to_owned()).or_default();
for user in &removable { for user in &removable {
debug_info!("typing timeout {user:?} in {room_id:?}"); debug_info!("typing timeout {user:?} in {room_id:?}");
room.remove(user); room.remove(user);
} }
// update clients // update clients
let count = self.services.globals.next_count();
self.last_typing_update self.last_typing_update
.write() .write()
.await .await
.insert(room_id.to_owned(), self.services.globals.next_count()?); .insert(room_id.to_owned(), *count);
if self if self
.typing_update_sender .typing_update_sender

View File

@@ -47,6 +47,8 @@ impl crate::Service for Service {
#[implement(Service)] #[implement(Service)]
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
let count = self.services.globals.next_count();
let userroom_id = (user_id, room_id); let userroom_id = (user_id, room_id);
self.db self.db
.userroomid_highlightcount .userroomid_highlightcount
@@ -56,10 +58,9 @@ pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
.put(userroom_id, 0_u64); .put(userroom_id, 0_u64);
let roomuser_id = (room_id, user_id); let roomuser_id = (room_id, user_id);
let count = self.services.globals.next_count().unwrap();
self.db self.db
.roomuserid_lastnotificationread .roomuserid_lastnotificationread
.put(roomuser_id, count); .put(roomuser_id, *count);
} }
#[implement(Service)] #[implement(Service)]

View File

@@ -126,8 +126,9 @@ impl Data {
if let SendingEvent::Pdu(value) = event { if let SendingEvent::Pdu(value) = event {
key.extend(value.as_ref()); key.extend(value.as_ref());
} else { } else {
let count = self.services.globals.next_count().unwrap(); let count = self.services.globals.next_count();
key.extend(&count.to_be_bytes()); let count = count.to_be_bytes();
key.extend(&count);
} }
key key

View File

@@ -378,7 +378,7 @@ impl Service {
async fn select_edus(&self, server_name: &ServerName) -> Result<(EduVec, u64)> { async fn select_edus(&self, server_name: &ServerName) -> Result<(EduVec, u64)> {
// selection window // selection window
let since = self.db.get_latest_educount(server_name).await; let since = self.db.get_latest_educount(server_name).await;
let since_upper = self.services.globals.current_count()?; let since_upper = self.services.globals.current_count();
let batch = (since, since_upper); let batch = (since, since_upper);
debug_assert!(batch.0 <= batch.1, "since range must not be negative"); debug_assert!(batch.0 <= batch.1, "since range must not be negative");

View File

@@ -154,9 +154,9 @@ pub async fn add_to_device_event(
event_type: &str, event_type: &str,
content: serde_json::Value, content: serde_json::Value,
) { ) {
let count = self.services.globals.next_count().unwrap(); let count = self.services.globals.next_count();
let key = (target_user_id, target_device_id, count); let key = (target_user_id, target_device_id, *count);
self.db.todeviceid_events.put( self.db.todeviceid_events.put(
key, key,
Json(json!({ Json(json!({

View File

@@ -52,14 +52,14 @@ pub async fn add_one_time_key(
.as_bytes(), .as_bytes(),
); );
let count = self.services.globals.next_count();
self.db self.db
.onetimekeyid_onetimekeys .onetimekeyid_onetimekeys
.raw_put(key, Json(one_time_key_value)); .raw_put(key, Json(one_time_key_value));
let count = self.services.globals.next_count().unwrap();
self.db self.db
.userid_lastonetimekeyupdate .userid_lastonetimekeyupdate
.raw_put(user_id, count); .raw_put(user_id, *count);
Ok(()) Ok(())
} }
@@ -81,10 +81,10 @@ pub async fn take_one_time_key(
device_id: &DeviceId, device_id: &DeviceId,
key_algorithm: &OneTimeKeyAlgorithm, key_algorithm: &OneTimeKeyAlgorithm,
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> { ) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
let count = self.services.globals.next_count()?.to_be_bytes(); let count = self.services.globals.next_count();
self.db self.db
.userid_lastonetimekeyupdate .userid_lastonetimekeyupdate
.insert(user_id, count); .insert(user_id, count.to_be_bytes());
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
@@ -340,7 +340,7 @@ fn keys_changed_user_or_room<'a>(
#[implement(super::Service)] #[implement(super::Service)]
pub async fn mark_device_key_update(&self, user_id: &UserId) { pub async fn mark_device_key_update(&self, user_id: &UserId) {
let count = self.services.globals.next_count().unwrap(); let count = self.services.globals.next_count();
self.services self.services
.state_cache .state_cache
@@ -348,12 +348,12 @@ pub async fn mark_device_key_update(&self, user_id: &UserId) {
// Don't send key updates to unencrypted rooms // Don't send key updates to unencrypted rooms
.filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id)) .filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id))
.ready_for_each(|room_id| { .ready_for_each(|room_id| {
let key = (room_id, count); let key = (room_id, *count);
self.db.keychangeid_userid.put_raw(key, user_id); self.db.keychangeid_userid.put_raw(key, user_id);
}) })
.await; .await;
let key = (user_id, count); let key = (user_id, *count);
self.db.keychangeid_userid.put_raw(key, user_id); self.db.keychangeid_userid.put_raw(key, user_id);
} }