diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index c91e9924..a434af99 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -37,30 +37,16 @@ pub(crate) async fn upload_keys_route( ) -> Result { let (sender_user, sender_device) = body.sender(); - for (key_id, one_time_key) in body + let one_time_keys = body .one_time_keys .iter() .take(services.config.one_time_key_limit) - { - if one_time_key - .deserialize() - .inspect_err(|e| { - debug_warn!( - ?key_id, - ?one_time_key, - "Invalid one time key JSON submitted by client, skipping: {e}" - ); - }) - .is_err() - { - continue; - } + .map(|(id, val)| (id.as_ref(), val)); - services - .users - .add_one_time_key(sender_user, sender_device, key_id, one_time_key) - .await?; - } + services + .users + .add_one_time_keys(sender_user, sender_device, one_time_keys) + .await?; if let Some(device_keys) = &body.device_keys { let deser_device_keys = device_keys.deserialize().map_err(|e| { diff --git a/src/service/users/device.rs b/src/service/users/device.rs index 3a7be477..a640623c 100644 --- a/src/service/users/device.rs +++ b/src/service/users/device.rs @@ -391,6 +391,14 @@ pub async fn get_device_metadata( .deserialized() } +#[implement(super::Service)] +pub async fn device_exists(&self, user_id: &UserId, device_id: &DeviceId) -> bool { + self.db + .userdeviceid_metadata + .contains(&(user_id, device_id)) + .await +} + #[implement(super::Service)] pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { self.db diff --git a/src/service/users/keys.rs b/src/service/users/keys.rs index ed358a57..bd7649af 100644 --- a/src/service/users/keys.rs +++ b/src/service/users/keys.rs @@ -9,11 +9,30 @@ use ruma::{ serde::Raw, }; use tuwunel_core::{ - Err, Error, Result, err, implement, + Err, Error, Result, debug_error, err, implement, utils::{ReadyExt, stream::TryIgnore, string::Unquoted}, }; use tuwunel_database::{Deserialized, Ignore, Json}; +#[implement(super::Service)] +pub async fn add_one_time_keys<'a, Keys>( + &self, + user_id: &UserId, + device_id: &DeviceId, + keys: Keys, +) -> Result +where + Keys: Iterator)> + Send + 'a, +{ + for (id, key) in keys { + self.add_one_time_key(user_id, device_id, id, key) + .await + .ok(); + } + + Ok(()) +} + #[implement(super::Service)] pub async fn add_one_time_key( &self, @@ -22,17 +41,7 @@ pub async fn add_one_time_key( one_time_key_key: &KeyId, one_time_key_value: &Raw, ) -> Result { - // All devices have metadata - // Only existing devices should be able to call this, but we shouldn't assert - // either... - let key = (user_id, device_id); - if self - .db - .userdeviceid_metadata - .qry(&key) - .await - .is_err() - { + if !self.device_exists(user_id, device_id).await { return Err!(Database(error!( ?user_id, ?device_id, @@ -40,23 +49,33 @@ pub async fn add_one_time_key( ))); } + if let Err(e) = one_time_key_value + .deserialize() + .map_err(Into::into) + { + debug_error!( + ?one_time_key_key, + ?one_time_key_value, + "Invalid one time key JSON submitted by client, skipping: {e}" + ); + + return Err(e); + } + let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(device_id.as_bytes()); key.push(0xFF); // TODO: Use DeviceKeyId::to_string when it's available (and update everything, // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); + key.extend_from_slice(serde_json::to_string(one_time_key_key)?.as_bytes()); let count = self.services.globals.next_count(); self.db .onetimekeyid_onetimekeys .raw_put(key, Json(one_time_key_value)); + self.db .userid_lastonetimekeyupdate .raw_put(user_id, *count);