diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 715b39d5..e551a62d 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -160,7 +160,7 @@ pub(crate) async fn build_sync_events( ) -> Result> { let (sender_user, sender_device) = body.sender(); - let next_batch = services.globals.current_count()?; + let next_batch = services.globals.current_count(); let since = body .body .since diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index 8acd43ba..20fad035 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -74,7 +74,7 @@ pub(crate) async fn sync_events_v5_route( // Setup watchers, so if there's no response, we can wait for them let watcher = services.sync.watch(sender_user, sender_device); - let next_batch = services.globals.next_count()?; + let next_batch = services.globals.next_count(); let conn_id = body.conn_id.clone(); @@ -136,7 +136,7 @@ pub(crate) async fn sync_events_v5_route( .chain(all_invited_rooms.clone()) .chain(all_knocked_rooms.clone()); - let pos = next_batch.clone().to_string(); + let pos = next_batch.to_string(); let mut todo_rooms: TodoRooms = BTreeMap::new(); @@ -146,7 +146,7 @@ pub(crate) async fn sync_events_v5_route( let e2ee = collect_e2ee(services, sync_info, all_joined_rooms.clone()); - let to_device = collect_to_device(services, sync_info, next_batch).map(Ok); + let to_device = collect_to_device(services, sync_info, *next_batch).map(Ok); let receipts = collect_receipts(services).map(Ok); @@ -189,7 +189,7 @@ pub(crate) async fn sync_events_v5_route( response.rooms = process_rooms( services, sender_user, - next_batch, + *next_batch, all_invited_rooms.clone(), &todo_rooms, &mut response, diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 59171b15..e587fccd 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -41,7 +41,6 @@ pub(crate) async fn send_event_to_device_route( map.insert(target_device_id_maybe.clone(), event.clone()); let mut messages = BTreeMap::new(); messages.insert(target_user_id.clone(), map); - let count = services.globals.next_count()?; let mut buf = EduBuf::new(); serde_json::to_writer( @@ -49,7 +48,7 @@ pub(crate) async fn send_event_to_device_route( &federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { sender: sender_user.to_owned(), ev_type: body.event_type.clone(), - message_id: count.to_string().into(), + message_id: services.globals.next_count().to_string().into(), messages, }), ) diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 7c260f1c..d8d4d033 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -63,8 +63,8 @@ pub async fn update( return Err!(Request(InvalidParam("Account data doesn't have all required fields."))); } - let count = self.services.globals.next_count().unwrap(); - let roomuserdataid = (room_id, user_id, count, &event_type); + let count = self.services.globals.next_count(); + let roomuserdataid = (room_id, user_id, *count, &event_type); self.db .roomuserdataid_accountdata .put(roomuserdataid, Json(data)); diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 5fa828ff..ea64fa03 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,56 +1,51 @@ -use std::sync::{Arc, RwLock}; +use std::sync::Arc; -use tuwunel_core::{Result, utils}; +use tuwunel_core::{ + Result, utils, + utils::two_phase_counter::{Counter as TwoPhaseCounter, Permit as TwoPhasePermit}, +}; use tuwunel_database::{Database, Deserialized, Map}; pub struct Data { global: Arc, - counter: RwLock, + counter: Arc, pub(super) db: Arc, } +pub(super) type Permit = TwoPhasePermit; +type Counter = TwoPhaseCounter; +type Callback = Box Result + Send + Sync>; + const COUNTER: &[u8] = b"c"; impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { - let db = &args.db; + let db = args.db.clone(); Self { - global: db["global"].clone(), - counter: RwLock::new( - Self::stored_count(&db["global"]).expect("initialized global counter"), - ), db: args.db.clone(), + global: args.db["global"].clone(), + counter: Counter::new( + Self::stored_count(&args.db["global"]).expect("initialized global counter"), + Box::new(move |count| Self::store_count(&db, &db["global"], count)), + ), } } - pub fn next_count(&self) -> Result { - let _cork = self.db.cork(); - let mut lock = self.counter.write().expect("locked"); - let counter: &mut u64 = &mut lock; - debug_assert!( - *counter == Self::stored_count(&self.global).expect("database failure"), - "counter mismatch" - ); - - *counter = counter - .checked_add(1) - .expect("counter must not overflow u64"); - - self.global.insert(COUNTER, counter.to_be_bytes()); - - Ok(*counter) + #[inline] + pub fn next_count(&self) -> Permit { + self.counter + .next() + .expect("failed to obtain next sequence number") } #[inline] - pub fn current_count(&self) -> u64 { - let lock = self.counter.read().expect("locked"); - let counter: &u64 = &lock; - debug_assert!( - *counter == Self::stored_count(&self.global).expect("database failure"), - "counter mismatch" - ); + pub fn current_count(&self) -> u64 { self.counter.current() } - *counter + fn store_count(db: &Arc, global: &Arc, count: u64) -> Result { + let _cork = db.cork(); + global.insert(COUNTER, count.to_be_bytes()); + + Ok(()) } fn stored_count(global: &Arc) -> Result { @@ -59,6 +54,12 @@ impl Data { .as_deref() .map_or(Ok(0_u64), utils::u64_from_bytes) } +} + +impl Data { + pub fn bump_database_version(&self, new_version: u64) { + self.global.raw_put(b"version", new_version); + } pub async fn database_version(&self) -> u64 { self.global @@ -67,9 +68,4 @@ impl Data { .deserialized() .unwrap_or(0) } - - #[inline] - pub fn bump_database_version(&self, new_version: u64) { - self.global.raw_put(b"version", new_version); - } } diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index c6a060e7..29ecc0ab 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -103,10 +103,12 @@ impl crate::Service for Service { impl Service { #[inline] - pub fn next_count(&self) -> Result { self.db.next_count() } + #[must_use] + pub fn next_count(&self) -> data::Permit { self.db.next_count() } #[inline] - pub fn current_count(&self) -> Result { Ok(self.db.current_count()) } + #[must_use] + pub fn current_count(&self) -> u64 { self.db.current_count() } #[inline] #[must_use] diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 739376e6..ec541ba5 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -52,17 +52,18 @@ pub fn create_backup( user_id: &UserId, backup_metadata: &Raw, ) -> Result { - let version = self.services.globals.next_count()?.to_string(); - let count = self.services.globals.next_count()?; + let version = self.services.globals.next_count(); + let count = self.services.globals.next_count(); - let key = (user_id, &version); + let version_string = version.to_string(); + let key = (user_id, &version_string); self.db .backupid_algorithm .put(key, Json(backup_metadata)); - self.db.backupid_etag.put(key, count); + self.db.backupid_etag.put(key, *count); - Ok(version) + Ok(version_string) } #[implement(Service)] @@ -100,8 +101,8 @@ pub async fn update_backup<'a>( return Err!(Request(NotFound("Tried to update nonexistent backup."))); } - let count = self.services.globals.next_count().unwrap(); - self.db.backupid_etag.put(key, count); + let count = self.services.globals.next_count(); + self.db.backupid_etag.put(key, *count); self.db .backupid_algorithm .put_raw(key, backup_metadata.json().get()); @@ -175,8 +176,8 @@ pub async fn add_key( return Err!(Request(NotFound("Tried to update nonexistent backup."))); } - let count = self.services.globals.next_count().unwrap(); - self.db.backupid_etag.put(key, count); + let count = self.services.globals.next_count(); + self.db.backupid_etag.put(key, *count); let key = (user_id, version, room_id, session_id); self.db diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index 2cae1da8..53dee45c 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -120,12 +120,12 @@ impl Data { status_msg, ); - let count = self.services.globals.next_count()?; - let key = presenceid_key(count, user_id); + let count = self.services.globals.next_count(); + let key = presenceid_key(*count, user_id); + self.userid_presenceid.raw_put(user_id, *count); self.presenceid_presence .raw_put(key, Json(presence)); - self.userid_presenceid.raw_put(user_id, count); if let Ok((last_count, _)) = last_presence { let key = presenceid_key(last_count, user_id); diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index ddb4606a..0270acb4 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -71,6 +71,8 @@ impl Service { return Err!(Request(Forbidden("Only the server user can set this alias"))); } + let count = self.services.globals.next_count(); + // Comes first as we don't want a stuck alias self.db .alias_userid @@ -82,7 +84,8 @@ impl Service { let mut aliasid = room_id.as_bytes().to_vec(); aliasid.push(0xFF); - aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); + aliasid.extend_from_slice(&count.to_be_bytes()); + self.db .aliasid_alias .insert(&aliasid, alias.as_bytes()); diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 65436362..11fb469b 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -56,8 +56,8 @@ impl Data { .ready_for_each(|key| self.readreceiptid_readreceipt.del(key)) .await; - let count = self.services.globals.next_count().unwrap(); - let latest_id = (room_id, count, user_id); + let count = self.services.globals.next_count(); + let latest_id = (room_id, *count, user_id); self.readreceiptid_readreceipt .put(latest_id, Json(event)); } @@ -89,11 +89,11 @@ impl Data { pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, pdu_count: u64) { let key = (room_id, user_id); - let next_count = self.services.globals.next_count().unwrap(); + let next_count = self.services.globals.next_count(); self.roomuserid_privateread.put(key, pdu_count); self.roomuserid_lastprivatereadupdate - .put(key, next_count); + .put(key, *next_count); } pub(super) async fn private_read_get_count( diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index d968d03d..28f5907c 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -81,18 +81,18 @@ where fn create_shorteventid(&self, event_id: &EventId) -> ShortEventId { const BUFSIZE: usize = size_of::(); - let short = self.services.globals.next_count().unwrap(); - debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); + let short = self.services.globals.next_count(); + debug_assert!(size_of_val(&*short) == BUFSIZE, "buffer requirement changed"); self.db .eventid_shorteventid - .raw_aput::(event_id, short); + .raw_aput::(event_id, *short); self.db .shorteventid_eventid - .aput_raw::(short, event_id); + .aput_raw::(*short, event_id); - short + *short } #[implement(Service)] @@ -120,18 +120,19 @@ pub async fn get_or_create_shortstatekey( } let key = (event_type, state_key); - let shortstatekey = self.services.globals.next_count().unwrap(); - debug_assert!(size_of_val(&shortstatekey) == BUFSIZE, "buffer requirement changed"); + let shortstatekey = self.services.globals.next_count(); + + debug_assert!(size_of_val(&*shortstatekey) == BUFSIZE, "buffer requirement changed"); self.db .statekey_shortstatekey - .put_aput::(key, shortstatekey); + .put_aput::(key, *shortstatekey); self.db .shortstatekey_statekey - .aput_put::(shortstatekey, key); + .aput_put::(*shortstatekey, key); - shortstatekey + *shortstatekey } #[implement(Service)] @@ -226,14 +227,14 @@ pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (ShortSta return (shortstatehash, true); } - let shortstatehash = self.services.globals.next_count().unwrap(); - debug_assert!(size_of_val(&shortstatehash) == BUFSIZE, "buffer requirement changed"); + let shortstatehash = self.services.globals.next_count(); + debug_assert!(size_of_val(&*shortstatehash) == BUFSIZE, "buffer requirement changed"); self.db .statehash_shortstatehash - .raw_aput::(state_hash, shortstatehash); + .raw_aput::(state_hash, *shortstatehash); - (shortstatehash, false) + (*shortstatehash, false) } #[implement(Service)] @@ -255,13 +256,13 @@ pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> ShortRoomId { .unwrap_or_else(|_| { const BUFSIZE: usize = size_of::(); - let short = self.services.globals.next_count().unwrap(); - debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); + let short = self.services.globals.next_count(); + debug_assert!(size_of_val(&*short) == BUFSIZE, "buffer requirement changed"); self.db .roomid_shortroomid - .raw_aput::(room_id, short); + .raw_aput::(room_id, *short); - short + *short }) } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 84963c00..0101d103 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -305,7 +305,7 @@ impl Service { } // TODO: statehash with deterministic inputs - let shortstatehash = self.services.globals.next_count()?; + let shortstatehash = self.services.globals.next_count(); let mut statediffnew = CompressedState::new(); statediffnew.insert(new); @@ -318,14 +318,14 @@ impl Service { self.services .state_compressor .save_state_from_diff( - shortstatehash, + *shortstatehash, Arc::new(statediffnew), Arc::new(statediffremoved), 2, states_parents, )?; - Ok(shortstatehash) + Ok(*shortstatehash) }, | _ => Ok(previous_shortstatehash.expect("first event in room must be a state event")), diff --git a/src/service/rooms/state_cache/update.rs b/src/service/rooms/state_cache/update.rs index a2083fd3..a811f4d2 100644 --- a/src/service/rooms/state_cache/update.rs +++ b/src/service/rooms/state_cache/update.rs @@ -272,6 +272,8 @@ pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { #[implement(super::Service)] #[tracing::instrument(skip(self), level = "debug")] pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { + let count = self.services.globals.next_count(); + let userroom_id = (user_id, room_id); let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); @@ -286,7 +288,7 @@ pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { .raw_put(&userroom_id, Json(leftstate)); self.db .roomuserid_leftcount - .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + .raw_aput::<8, _, _>(&roomuser_id, *count); self.db.userroomid_joined.remove(&userroom_id); self.db.roomuserid_joined.remove(&roomuser_id); @@ -319,6 +321,8 @@ pub fn mark_as_knocked( room_id: &RoomId, knocked_state: Option>>, ) { + let count = self.services.globals.next_count(); + let userroom_id = (user_id, room_id); let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id"); @@ -330,7 +334,7 @@ pub fn mark_as_knocked( .raw_put(&userroom_id, Json(knocked_state.unwrap_or_default())); self.db .roomuserid_knockedcount - .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + .raw_aput::<8, _, _>(&roomuser_id, *count); self.db.userroomid_joined.remove(&userroom_id); self.db.roomuserid_joined.remove(&roomuser_id); @@ -375,6 +379,8 @@ pub async fn mark_as_invited( last_state: Option>>, invite_via: Option>, ) { + let count = self.services.globals.next_count(); + let roomuser_id = (room_id, user_id); let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id"); @@ -386,7 +392,7 @@ pub async fn mark_as_invited( .raw_put(&userroom_id, Json(last_state.unwrap_or_default())); self.db .roomuserid_invitecount - .raw_aput::<8, _, _>(&roomuser_id, self.services.globals.next_count().unwrap()); + .raw_aput::<8, _, _>(&roomuser_id, *count); self.db.userroomid_joined.remove(&userroom_id); self.db.roomuserid_joined.remove(&roomuser_id); diff --git a/src/service/rooms/timeline/append.rs b/src/service/rooms/timeline/append.rs index 57e55464..278e72c8 100644 --- a/src/service/rooms/timeline/append.rs +++ b/src/service/rooms/timeline/append.rs @@ -159,20 +159,20 @@ where .await; let insert_lock = self.mutex_insert.lock(pdu.room_id()).await; - - let count1 = self.services.globals.next_count().unwrap(); + let count1 = self.services.globals.next_count(); + let count2 = self.services.globals.next_count(); // Mark as read first so the sending client doesn't get a notification even if // appending fails self.services .read_receipt - .private_read_set(pdu.room_id(), pdu.sender(), count1); + .private_read_set(pdu.room_id(), pdu.sender(), *count1); self.services .user .reset_notification_counts(pdu.sender(), pdu.room_id()); - let count2 = PduCount::Normal(self.services.globals.next_count().unwrap()); + let count2 = PduCount::Normal(*count2); let pdu_id: RawPduId = PduId { shortroomid, shorteventid: count2 }.into(); // Insert pdu diff --git a/src/service/rooms/timeline/backfill.rs b/src/service/rooms/timeline/backfill.rs index 8702fc76..b4917295 100644 --- a/src/service/rooms/timeline/backfill.rs +++ b/src/service/rooms/timeline/backfill.rs @@ -181,13 +181,9 @@ pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box) -> let insert_lock = self.mutex_insert.lock(&room_id).await; - let count: i64 = self - .services - .globals - .next_count() - .unwrap() - .try_into()?; + let count = self.services.globals.next_count(); + let count: i64 = (*count).try_into()?; let pdu_id: RawPduId = PduId { shortroomid, shorteventid: PduCount::Backfilled(validated!(0 - count)), diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 0e7070c8..90f5b917 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -53,6 +53,7 @@ impl Service { /// roomtyping_remove is called. pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result { debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}"); + // update clients self.typing .write() @@ -61,10 +62,11 @@ impl Service { .or_default() .insert(user_id.to_owned(), timeout); + let count = self.services.globals.next_count(); self.last_typing_update .write() .await - .insert(room_id.to_owned(), self.services.globals.next_count()?); + .insert(room_id.to_owned(), *count); if self .typing_update_sender @@ -86,6 +88,7 @@ impl Service { /// Removes a user from typing before the timeout is reached. pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result { debug_info!("typing stopped {user_id:?} in {room_id:?}"); + // update clients self.typing .write() @@ -94,10 +97,11 @@ impl Service { .or_default() .remove(user_id); + let count = self.services.globals.next_count(); self.last_typing_update .write() .await - .insert(room_id.to_owned(), self.services.globals.next_count()?); + .insert(room_id.to_owned(), *count); if self .typing_update_sender @@ -146,16 +150,18 @@ impl Service { if !removable.is_empty() { let typing = &mut self.typing.write().await; let room = typing.entry(room_id.to_owned()).or_default(); + for user in &removable { debug_info!("typing timeout {user:?} in {room_id:?}"); room.remove(user); } // update clients + let count = self.services.globals.next_count(); self.last_typing_update .write() .await - .insert(room_id.to_owned(), self.services.globals.next_count()?); + .insert(room_id.to_owned(), *count); if self .typing_update_sender diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 61953e4d..bc57d8f0 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -47,6 +47,8 @@ impl crate::Service for Service { #[implement(Service)] pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { + let count = self.services.globals.next_count(); + let userroom_id = (user_id, room_id); self.db .userroomid_highlightcount @@ -56,10 +58,9 @@ pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { .put(userroom_id, 0_u64); let roomuser_id = (room_id, user_id); - let count = self.services.globals.next_count().unwrap(); self.db .roomuserid_lastnotificationread - .put(roomuser_id, count); + .put(roomuser_id, *count); } #[implement(Service)] diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 1f77d964..3a2b3c4e 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -126,8 +126,9 @@ impl Data { if let SendingEvent::Pdu(value) = event { key.extend(value.as_ref()); } else { - let count = self.services.globals.next_count().unwrap(); - key.extend(&count.to_be_bytes()); + let count = self.services.globals.next_count(); + let count = count.to_be_bytes(); + key.extend(&count); } key diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index fb8abd2a..3552def8 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -378,7 +378,7 @@ impl Service { async fn select_edus(&self, server_name: &ServerName) -> Result<(EduVec, u64)> { // selection window let since = self.db.get_latest_educount(server_name).await; - let since_upper = self.services.globals.current_count()?; + let since_upper = self.services.globals.current_count(); let batch = (since, since_upper); debug_assert!(batch.0 <= batch.1, "since range must not be negative"); diff --git a/src/service/users/device.rs b/src/service/users/device.rs index d5afa954..15aed62b 100644 --- a/src/service/users/device.rs +++ b/src/service/users/device.rs @@ -154,9 +154,9 @@ pub async fn add_to_device_event( event_type: &str, content: serde_json::Value, ) { - let count = self.services.globals.next_count().unwrap(); + let count = self.services.globals.next_count(); - let key = (target_user_id, target_device_id, count); + let key = (target_user_id, target_device_id, *count); self.db.todeviceid_events.put( key, Json(json!({ diff --git a/src/service/users/keys.rs b/src/service/users/keys.rs index d940d200..e494bc4f 100644 --- a/src/service/users/keys.rs +++ b/src/service/users/keys.rs @@ -52,14 +52,14 @@ pub async fn add_one_time_key( .as_bytes(), ); + let count = self.services.globals.next_count(); + self.db .onetimekeyid_onetimekeys .raw_put(key, Json(one_time_key_value)); - - let count = self.services.globals.next_count().unwrap(); self.db .userid_lastonetimekeyupdate - .raw_put(user_id, count); + .raw_put(user_id, *count); Ok(()) } @@ -81,10 +81,10 @@ pub async fn take_one_time_key( device_id: &DeviceId, key_algorithm: &OneTimeKeyAlgorithm, ) -> Result<(OwnedKeyId, Raw)> { - let count = self.services.globals.next_count()?.to_be_bytes(); + let count = self.services.globals.next_count(); self.db .userid_lastonetimekeyupdate - .insert(user_id, count); + .insert(user_id, count.to_be_bytes()); let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -340,7 +340,7 @@ fn keys_changed_user_or_room<'a>( #[implement(super::Service)] pub async fn mark_device_key_update(&self, user_id: &UserId) { - let count = self.services.globals.next_count().unwrap(); + let count = self.services.globals.next_count(); self.services .state_cache @@ -348,12 +348,12 @@ pub async fn mark_device_key_update(&self, user_id: &UserId) { // Don't send key updates to unencrypted rooms .filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id)) .ready_for_each(|room_id| { - let key = (room_id, count); + let key = (room_id, *count); self.db.keychangeid_userid.put_raw(key, user_id); }) .await; - let key = (user_id, count); + let key = (user_id, *count); self.db.keychangeid_userid.put_raw(key, user_id); }