Add bulk one_time_keys adder to interface.

Add device_exists to interface.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-10-27 05:49:55 +00:00
parent a30c043386
commit 8959d9e2c1
3 changed files with 50 additions and 37 deletions

View File

@@ -37,30 +37,16 @@ pub(crate) async fn upload_keys_route(
) -> Result<upload_keys::v3::Response> { ) -> Result<upload_keys::v3::Response> {
let (sender_user, sender_device) = body.sender(); let (sender_user, sender_device) = body.sender();
for (key_id, one_time_key) in body let one_time_keys = body
.one_time_keys .one_time_keys
.iter() .iter()
.take(services.config.one_time_key_limit) .take(services.config.one_time_key_limit)
{ .map(|(id, val)| (id.as_ref(), val));
if one_time_key
.deserialize()
.inspect_err(|e| {
debug_warn!(
?key_id,
?one_time_key,
"Invalid one time key JSON submitted by client, skipping: {e}"
);
})
.is_err()
{
continue;
}
services services
.users .users
.add_one_time_key(sender_user, sender_device, key_id, one_time_key) .add_one_time_keys(sender_user, sender_device, one_time_keys)
.await?; .await?;
}
if let Some(device_keys) = &body.device_keys { if let Some(device_keys) = &body.device_keys {
let deser_device_keys = device_keys.deserialize().map_err(|e| { let deser_device_keys = device_keys.deserialize().map_err(|e| {

View File

@@ -391,6 +391,14 @@ pub async fn get_device_metadata(
.deserialized() .deserialized()
} }
#[implement(super::Service)]
pub async fn device_exists(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
self.db
.userdeviceid_metadata
.contains(&(user_id, device_id))
.await
}
#[implement(super::Service)] #[implement(super::Service)]
pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> { pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
self.db self.db

View File

@@ -9,11 +9,30 @@ use ruma::{
serde::Raw, serde::Raw,
}; };
use tuwunel_core::{ use tuwunel_core::{
Err, Error, Result, err, implement, Err, Error, Result, debug_error, err, implement,
utils::{ReadyExt, stream::TryIgnore, string::Unquoted}, utils::{ReadyExt, stream::TryIgnore, string::Unquoted},
}; };
use tuwunel_database::{Deserialized, Ignore, Json}; use tuwunel_database::{Deserialized, Ignore, Json};
#[implement(super::Service)]
pub async fn add_one_time_keys<'a, Keys>(
&self,
user_id: &UserId,
device_id: &DeviceId,
keys: Keys,
) -> Result
where
Keys: Iterator<Item = (&'a OneTimeKeyId, &'a Raw<OneTimeKey>)> + Send + 'a,
{
for (id, key) in keys {
self.add_one_time_key(user_id, device_id, id, key)
.await
.ok();
}
Ok(())
}
#[implement(super::Service)] #[implement(super::Service)]
pub async fn add_one_time_key( pub async fn add_one_time_key(
&self, &self,
@@ -22,17 +41,7 @@ pub async fn add_one_time_key(
one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
one_time_key_value: &Raw<OneTimeKey>, one_time_key_value: &Raw<OneTimeKey>,
) -> Result { ) -> Result {
// All devices have metadata if !self.device_exists(user_id, device_id).await {
// Only existing devices should be able to call this, but we shouldn't assert
// either...
let key = (user_id, device_id);
if self
.db
.userdeviceid_metadata
.qry(&key)
.await
.is_err()
{
return Err!(Database(error!( return Err!(Database(error!(
?user_id, ?user_id,
?device_id, ?device_id,
@@ -40,23 +49,33 @@ pub async fn add_one_time_key(
))); )));
} }
if let Err(e) = one_time_key_value
.deserialize()
.map_err(Into::into)
{
debug_error!(
?one_time_key_key,
?one_time_key_value,
"Invalid one time key JSON submitted by client, skipping: {e}"
);
return Err(e);
}
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(device_id.as_bytes()); key.extend_from_slice(device_id.as_bytes());
key.push(0xFF); key.push(0xFF);
// TODO: Use DeviceKeyId::to_string when it's available (and update everything, // TODO: Use DeviceKeyId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore) // because there are no wrapping quotation marks anymore)
key.extend_from_slice( key.extend_from_slice(serde_json::to_string(one_time_key_key)?.as_bytes());
serde_json::to_string(one_time_key_key)
.expect("DeviceKeyId::to_string always works")
.as_bytes(),
);
let count = self.services.globals.next_count(); let count = self.services.globals.next_count();
self.db self.db
.onetimekeyid_onetimekeys .onetimekeyid_onetimekeys
.raw_put(key, Json(one_time_key_value)); .raw_put(key, Json(one_time_key_value));
self.db self.db
.userid_lastonetimekeyupdate .userid_lastonetimekeyupdate
.raw_put(user_id, *count); .raw_put(user_id, *count);