Split tuwunel_service::users into units.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
219
src/service/users/device.rs
Normal file
219
src/service/users/device.rs
Normal file
@@ -0,0 +1,219 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{Stream, StreamExt};
|
||||
use ruma::{
|
||||
DeviceId, MilliSecondsSinceUnixEpoch, UserId, api::client::device::Device,
|
||||
events::AnyToDeviceEvent, serde::Raw,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tuwunel_core::{
|
||||
Err, Result, at, implement,
|
||||
utils::{self, ReadyExt, stream::TryIgnore},
|
||||
};
|
||||
use tuwunel_database::{Deserialized, Ignore, Interfix, Json, Map};
|
||||
|
||||
/// Adds a new device to a user.
|
||||
#[implement(super::Service)]
|
||||
pub async fn create_device(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
token: &str,
|
||||
initial_device_display_name: Option<String>,
|
||||
client_ip: Option<String>,
|
||||
) -> Result {
|
||||
if !self.exists(user_id).await {
|
||||
return Err!(Request(InvalidParam(error!(
|
||||
"Called create_device for non-existent user {user_id}"
|
||||
))));
|
||||
}
|
||||
|
||||
let key = (user_id, device_id);
|
||||
let val = Device {
|
||||
device_id: device_id.into(),
|
||||
display_name: initial_device_display_name,
|
||||
last_seen_ip: client_ip,
|
||||
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
|
||||
};
|
||||
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
self.db.userdeviceid_metadata.put(key, Json(val));
|
||||
self.set_token(user_id, device_id, token).await
|
||||
}
|
||||
|
||||
/// Removes a device from a user.
|
||||
#[implement(super::Service)]
|
||||
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
|
||||
let userdeviceid = (user_id, device_id);
|
||||
|
||||
// Remove tokens
|
||||
if let Ok(old_token) = self
|
||||
.db
|
||||
.userdeviceid_token
|
||||
.qry(&userdeviceid)
|
||||
.await
|
||||
{
|
||||
self.db.userdeviceid_token.del(userdeviceid);
|
||||
self.db.token_userdeviceid.remove(&old_token);
|
||||
}
|
||||
|
||||
// Remove todevice events
|
||||
let prefix = (user_id, device_id, Interfix);
|
||||
self.db
|
||||
.todeviceid_events
|
||||
.keys_prefix_raw(&prefix)
|
||||
.ignore_err()
|
||||
.ready_for_each(|key| self.db.todeviceid_events.remove(key))
|
||||
.await;
|
||||
|
||||
// TODO: Remove onetimekeys
|
||||
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
|
||||
self.db.userdeviceid_metadata.del(userdeviceid);
|
||||
self.mark_device_key_update(user_id).await;
|
||||
}
|
||||
|
||||
/// Returns an iterator over all device ids of this user.
|
||||
#[implement(super::Service)]
|
||||
pub fn all_device_ids<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = &DeviceId> + Send + 'a {
|
||||
let prefix = (user_id, Interfix);
|
||||
self.db
|
||||
.userdeviceid_metadata
|
||||
.keys_prefix(&prefix)
|
||||
.ignore_err()
|
||||
.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn add_to_device_event(
|
||||
&self,
|
||||
sender: &UserId,
|
||||
target_user_id: &UserId,
|
||||
target_device_id: &DeviceId,
|
||||
event_type: &str,
|
||||
content: serde_json::Value,
|
||||
) {
|
||||
let count = self.services.globals.next_count().unwrap();
|
||||
|
||||
let key = (target_user_id, target_device_id, count);
|
||||
self.db.todeviceid_events.put(
|
||||
key,
|
||||
Json(json!({
|
||||
"type": event_type,
|
||||
"sender": sender,
|
||||
"content": content,
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub fn get_to_device_events<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
device_id: &'a DeviceId,
|
||||
since: Option<u64>,
|
||||
to: Option<u64>,
|
||||
) -> impl Stream<Item = Raw<AnyToDeviceEvent>> + Send + 'a {
|
||||
type Key<'a> = (&'a UserId, &'a DeviceId, u64);
|
||||
|
||||
let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1)));
|
||||
|
||||
self.db
|
||||
.todeviceid_events
|
||||
.stream_from(&from)
|
||||
.ignore_err()
|
||||
.ready_take_while(move |((user_id_, device_id_, count), _): &(Key<'_>, _)| {
|
||||
user_id == *user_id_ && device_id == *device_id_ && to.is_none_or(|to| *count <= to)
|
||||
})
|
||||
.map(at!(1))
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn remove_to_device_events<Until>(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
until: Until,
|
||||
) where
|
||||
Until: Into<Option<u64>> + Send,
|
||||
{
|
||||
type Key<'a> = (&'a UserId, &'a DeviceId, u64);
|
||||
|
||||
let until = until.into().unwrap_or(u64::MAX);
|
||||
let from = (user_id, device_id, until);
|
||||
self.db
|
||||
.todeviceid_events
|
||||
.rev_keys_from(&from)
|
||||
.ignore_err()
|
||||
.ready_take_while(move |(user_id_, device_id_, _): &Key<'_>| {
|
||||
user_id == *user_id_ && device_id == *device_id_
|
||||
})
|
||||
.ready_for_each(|key: Key<'_>| {
|
||||
self.db.todeviceid_events.del(key);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn update_device_metadata(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
device: &Device,
|
||||
) -> Result {
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
|
||||
let key = (user_id, device_id);
|
||||
self.db
|
||||
.userdeviceid_metadata
|
||||
.put(key, Json(device));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get device metadata.
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_device_metadata(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Result<Device> {
|
||||
self.db
|
||||
.userdeviceid_metadata
|
||||
.qry(&(user_id, device_id))
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result<u64> {
|
||||
self.db
|
||||
.userid_devicelistversion
|
||||
.get(user_id)
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub fn all_devices_metadata<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = Device> + Send + 'a {
|
||||
let key = (user_id, Interfix);
|
||||
self.db
|
||||
.userdeviceid_metadata
|
||||
.stream_prefix(&key)
|
||||
.ignore_err()
|
||||
.map(|(_, val): (Ignore, Device)| val)
|
||||
}
|
||||
|
||||
//TODO: this is an ABA
|
||||
fn increment(db: &Arc<Map>, key: &[u8]) {
|
||||
let old = db.get_blocking(key);
|
||||
let new = utils::increment(old.ok().as_deref());
|
||||
db.insert(key, new);
|
||||
}
|
||||
515
src/service/users/keys.rs
Normal file
515
src/service/users/keys.rs
Normal file
@@ -0,0 +1,515 @@
|
||||
use std::{collections::BTreeMap, mem};
|
||||
|
||||
use futures::{Stream, StreamExt, TryFutureExt};
|
||||
use ruma::{
|
||||
DeviceId, KeyId, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedKeyId, RoomId, UInt,
|
||||
UserId,
|
||||
api::client::error::ErrorKind,
|
||||
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
|
||||
serde::Raw,
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Err, Error, Result, err, implement,
|
||||
utils::{ReadyExt, stream::TryIgnore, string::Unquoted},
|
||||
};
|
||||
use tuwunel_database::{Deserialized, Ignore, Json};
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn add_one_time_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
one_time_key_key: &KeyId<OneTimeKeyAlgorithm, OneTimeKeyName>,
|
||||
one_time_key_value: &Raw<OneTimeKey>,
|
||||
) -> Result {
|
||||
// All devices have metadata
|
||||
// Only existing devices should be able to call this, but we shouldn't assert
|
||||
// either...
|
||||
let key = (user_id, device_id);
|
||||
if self
|
||||
.db
|
||||
.userdeviceid_metadata
|
||||
.qry(&key)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err!(Database(error!(
|
||||
?user_id,
|
||||
?device_id,
|
||||
"User does not exist or device has no metadata."
|
||||
)));
|
||||
}
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
key.push(0xFF);
|
||||
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
|
||||
// because there are no wrapping quotation marks anymore)
|
||||
key.extend_from_slice(
|
||||
serde_json::to_string(one_time_key_key)
|
||||
.expect("DeviceKeyId::to_string always works")
|
||||
.as_bytes(),
|
||||
);
|
||||
|
||||
self.db
|
||||
.onetimekeyid_onetimekeys
|
||||
.raw_put(key, Json(one_time_key_value));
|
||||
|
||||
let count = self.services.globals.next_count().unwrap();
|
||||
self.db
|
||||
.userid_lastonetimekeyupdate
|
||||
.raw_put(user_id, count);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 {
|
||||
self.db
|
||||
.userid_lastonetimekeyupdate
|
||||
.get(user_id)
|
||||
.await
|
||||
.deserialized()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn take_one_time_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
key_algorithm: &OneTimeKeyAlgorithm,
|
||||
) -> Result<(OwnedKeyId<OneTimeKeyAlgorithm, OneTimeKeyName>, Raw<OneTimeKey>)> {
|
||||
let count = self.services.globals.next_count()?.to_be_bytes();
|
||||
self.db
|
||||
.userid_lastonetimekeyupdate
|
||||
.insert(user_id, count);
|
||||
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xFF);
|
||||
prefix.push(b'"'); // Annoying quotation mark
|
||||
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
|
||||
prefix.push(b':');
|
||||
|
||||
let one_time_key = self
|
||||
.db
|
||||
.onetimekeyid_onetimekeys
|
||||
.raw_stream_prefix(&prefix)
|
||||
.ignore_err()
|
||||
.map(|(key, val)| {
|
||||
self.db.onetimekeyid_onetimekeys.remove(key);
|
||||
|
||||
let key = key
|
||||
.rsplit(|&b| b == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))
|
||||
.unwrap();
|
||||
|
||||
let key = serde_json::from_slice(key)
|
||||
.map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))
|
||||
.unwrap();
|
||||
|
||||
let val = serde_json::from_slice(val)
|
||||
.map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))
|
||||
.unwrap();
|
||||
|
||||
(key, val)
|
||||
})
|
||||
.next()
|
||||
.await;
|
||||
|
||||
one_time_key.ok_or_else(|| err!(Request(NotFound("No one-time-key found"))))
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn count_one_time_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> BTreeMap<OneTimeKeyAlgorithm, UInt> {
|
||||
type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore);
|
||||
|
||||
let mut algorithm_counts = BTreeMap::<OneTimeKeyAlgorithm, _>::new();
|
||||
let query = (user_id, device_id);
|
||||
self.db
|
||||
.onetimekeyid_onetimekeys
|
||||
.stream_prefix(&query)
|
||||
.ignore_err()
|
||||
.ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| {
|
||||
let one_time_key_id: &OneTimeKeyId = device_key_id
|
||||
.as_str()
|
||||
.try_into()
|
||||
.expect("Invalid DeviceKeyID in database");
|
||||
|
||||
let count: &mut UInt = algorithm_counts
|
||||
.entry(one_time_key_id.algorithm())
|
||||
.or_default();
|
||||
|
||||
*count = count.saturating_add(1_u32.into());
|
||||
})
|
||||
.await;
|
||||
|
||||
algorithm_counts
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn add_device_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
device_keys: &Raw<DeviceKeys>,
|
||||
) {
|
||||
let key = (user_id, device_id);
|
||||
|
||||
self.db.keyid_key.put(key, Json(device_keys));
|
||||
self.mark_device_key_update(user_id).await;
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn add_cross_signing_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
master_key: &Option<Raw<CrossSigningKey>>,
|
||||
self_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
user_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
notify: bool,
|
||||
) -> Result {
|
||||
// TODO: Check signatures
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
if let Some(master_key) = master_key {
|
||||
let (master_key_key, _) = parse_master_key(user_id, master_key)?;
|
||||
|
||||
self.db
|
||||
.keyid_key
|
||||
.insert(&master_key_key, master_key.json().get().as_bytes());
|
||||
|
||||
self.db
|
||||
.userid_masterkeyid
|
||||
.insert(user_id.as_bytes(), &master_key_key);
|
||||
}
|
||||
|
||||
// Self-signing key
|
||||
if let Some(self_signing_key) = self_signing_key {
|
||||
let mut self_signing_key_ids = self_signing_key
|
||||
.deserialize()
|
||||
.map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))?
|
||||
.keys
|
||||
.into_values();
|
||||
|
||||
let self_signing_key_id = self_signing_key_ids
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Self signing key contained no key.",
|
||||
))?;
|
||||
|
||||
if self_signing_key_ids.next().is_some() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Self signing key contained more than one key.",
|
||||
));
|
||||
}
|
||||
|
||||
let mut self_signing_key_key = prefix.clone();
|
||||
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
|
||||
|
||||
self.db
|
||||
.keyid_key
|
||||
.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes());
|
||||
|
||||
self.db
|
||||
.userid_selfsigningkeyid
|
||||
.insert(user_id.as_bytes(), &self_signing_key_key);
|
||||
}
|
||||
|
||||
// User-signing key
|
||||
if let Some(user_signing_key) = user_signing_key {
|
||||
let user_signing_key_id = parse_user_signing_key(user_signing_key)?;
|
||||
|
||||
let user_signing_key_key = (user_id, &user_signing_key_id);
|
||||
self.db
|
||||
.keyid_key
|
||||
.put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes());
|
||||
|
||||
self.db
|
||||
.userid_usersigningkeyid
|
||||
.raw_put(user_id, user_signing_key_key);
|
||||
}
|
||||
|
||||
if notify {
|
||||
self.mark_device_key_update(user_id).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn sign_key(
|
||||
&self,
|
||||
target_id: &UserId,
|
||||
key_id: &str,
|
||||
signature: (String, String),
|
||||
sender_id: &UserId,
|
||||
) -> Result {
|
||||
let key = (target_id, key_id);
|
||||
|
||||
let mut cross_signing_key: serde_json::Value = self
|
||||
.db
|
||||
.keyid_key
|
||||
.qry(&key)
|
||||
.await
|
||||
.map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))?
|
||||
.deserialized()
|
||||
.map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?;
|
||||
|
||||
let signatures = cross_signing_key
|
||||
.get_mut("signatures")
|
||||
.ok_or_else(|| err!(Database(debug_warn!("key in keyid_key has no signatures field"))))?
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| {
|
||||
err!(Database(debug_warn!("key in keyid_key has invalid signatures field.")))
|
||||
})?
|
||||
.entry(sender_id.to_string())
|
||||
.or_insert_with(|| serde_json::Map::new().into());
|
||||
|
||||
signatures
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| {
|
||||
err!(Database(debug_warn!("signatures in keyid_key for a user is invalid.")))
|
||||
})?
|
||||
.insert(signature.0, signature.1.into());
|
||||
|
||||
let key = (target_id, key_id);
|
||||
self.db
|
||||
.keyid_key
|
||||
.put(key, Json(cross_signing_key));
|
||||
|
||||
self.mark_device_key_update(target_id).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
#[inline]
|
||||
pub fn keys_changed<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
) -> impl Stream<Item = &UserId> + Send + 'a {
|
||||
self.keys_changed_user_or_room(user_id.as_str(), from, to)
|
||||
.map(|(user_id, ..)| user_id)
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
#[inline]
|
||||
pub fn room_keys_changed<'a>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
|
||||
self.keys_changed_user_or_room(room_id.as_str(), from, to)
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
fn keys_changed_user_or_room<'a>(
|
||||
&'a self,
|
||||
user_or_room_id: &'a str,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
) -> impl Stream<Item = (&UserId, u64)> + Send + 'a {
|
||||
type KeyVal<'a> = ((&'a str, u64), &'a UserId);
|
||||
|
||||
let to = to.unwrap_or(u64::MAX);
|
||||
let start = (user_or_room_id, from.saturating_add(1));
|
||||
self.db
|
||||
.keychangeid_userid
|
||||
.stream_from(&start)
|
||||
.ignore_err()
|
||||
.ready_take_while(move |((prefix, count), _): &KeyVal<'_>| {
|
||||
*prefix == user_or_room_id && *count <= to
|
||||
})
|
||||
.map(|((_, count), user_id): KeyVal<'_>| (user_id, count))
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn mark_device_key_update(&self, user_id: &UserId) {
|
||||
let count = self.services.globals.next_count().unwrap();
|
||||
|
||||
self.services
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
// Don't send key updates to unencrypted rooms
|
||||
.filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id))
|
||||
.ready_for_each(|room_id| {
|
||||
let key = (room_id, count);
|
||||
self.db.keychangeid_userid.put_raw(key, user_id);
|
||||
})
|
||||
.await;
|
||||
|
||||
let key = (user_id, count);
|
||||
self.db.keychangeid_userid.put_raw(key, user_id);
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_device_keys<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Result<Raw<DeviceKeys>> {
|
||||
let key_id = (user_id, device_id);
|
||||
self.db
|
||||
.keyid_key
|
||||
.qry(&key_id)
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_key<F>(
|
||||
&self,
|
||||
key_id: &[u8],
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &F,
|
||||
) -> Result<Raw<CrossSigningKey>>
|
||||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
{
|
||||
let key: serde_json::Value = self
|
||||
.db
|
||||
.keyid_key
|
||||
.get(key_id)
|
||||
.await
|
||||
.deserialized()?;
|
||||
|
||||
let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?;
|
||||
let raw_value = serde_json::value::to_raw_value(&cleaned)?;
|
||||
Ok(Raw::from_json(raw_value))
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_master_key<F>(
|
||||
&self,
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &F,
|
||||
) -> Result<Raw<CrossSigningKey>>
|
||||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
{
|
||||
let key_id = self.db.userid_masterkeyid.get(user_id).await?;
|
||||
|
||||
self.get_key(&key_id, sender_user, user_id, allowed_signatures)
|
||||
.await
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_self_signing_key<F>(
|
||||
&self,
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &F,
|
||||
) -> Result<Raw<CrossSigningKey>>
|
||||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
{
|
||||
let key_id = self
|
||||
.db
|
||||
.userid_selfsigningkeyid
|
||||
.get(user_id)
|
||||
.await?;
|
||||
|
||||
self.get_key(&key_id, sender_user, user_id, allowed_signatures)
|
||||
.await
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result<Raw<CrossSigningKey>> {
|
||||
self.db
|
||||
.userid_usersigningkeyid
|
||||
.get(user_id)
|
||||
.and_then(|key_id| self.db.keyid_key.get(&*key_id))
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
pub fn parse_master_key(
|
||||
user_id: &UserId,
|
||||
master_key: &Raw<CrossSigningKey>,
|
||||
) -> Result<(Vec<u8>, CrossSigningKey)> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
|
||||
let master_key = master_key
|
||||
.deserialize()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
|
||||
let mut master_key_ids = master_key.keys.values();
|
||||
let master_key_id = master_key_ids
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
|
||||
if master_key_ids.next().is_some() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Master key contained more than one key.",
|
||||
));
|
||||
}
|
||||
let mut master_key_key = prefix.clone();
|
||||
master_key_key.extend_from_slice(master_key_id.as_bytes());
|
||||
Ok((master_key_key, master_key))
|
||||
}
|
||||
|
||||
pub(super) fn parse_user_signing_key(user_signing_key: &Raw<CrossSigningKey>) -> Result<String> {
|
||||
let mut user_signing_key_ids = user_signing_key
|
||||
.deserialize()
|
||||
.map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))?
|
||||
.keys
|
||||
.into_values();
|
||||
|
||||
let user_signing_key_id = user_signing_key_ids
|
||||
.next()
|
||||
.ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?;
|
||||
|
||||
if user_signing_key_ids.next().is_some() {
|
||||
return Err!(Request(InvalidParam("User signing key contained more than one key.")));
|
||||
}
|
||||
|
||||
Ok(user_signing_key_id)
|
||||
}
|
||||
|
||||
/// Ensure that a user only sees signatures from themselves and the target user
|
||||
fn clean_signatures<F>(
|
||||
mut cross_signing_key: serde_json::Value,
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &F,
|
||||
) -> Result<serde_json::Value>
|
||||
where
|
||||
F: Fn(&UserId) -> bool + Send + Sync,
|
||||
{
|
||||
if let Some(signatures) = cross_signing_key
|
||||
.get_mut("signatures")
|
||||
.and_then(|v| v.as_object_mut())
|
||||
{
|
||||
// Don't allocate for the full size of the current signatures, but require
|
||||
// at most one resize if nothing is dropped
|
||||
let new_capacity = signatures.len() / 2;
|
||||
for (user, signature) in
|
||||
mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
|
||||
{
|
||||
let sid = <&UserId>::try_from(user.as_str())
|
||||
.map_err(|_| Error::bad_database("Invalid user ID in database."))?;
|
||||
if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
|
||||
signatures.insert(user, signature);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(cross_signing_key)
|
||||
}
|
||||
153
src/service/users/ldap.rs
Normal file
153
src/service/users/ldap.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
#![cfg(feature = "ldap")]
|
||||
|
||||
use itertools::Itertools;
|
||||
use ldap3::{LdapConnAsync, Scope, SearchEntry};
|
||||
use ruma::UserId;
|
||||
use tuwunel_core::{Result, debug, err, error, implement, result::LogErr, trace};
|
||||
|
||||
/// Performs a LDAP search for the given user.
|
||||
///
|
||||
/// Returns the list of matching users, with a boolean for each result set
|
||||
/// to true if the user is an admin.
|
||||
#[implement(super::Service)]
|
||||
pub async fn search_ldap(&self, user_id: &UserId) -> Result<Vec<(String, bool)>> {
|
||||
let localpart = user_id.localpart().to_owned();
|
||||
let lowercased_localpart = localpart.to_lowercase();
|
||||
|
||||
let config = &self.services.server.config.ldap;
|
||||
let uri = config
|
||||
.uri
|
||||
.as_ref()
|
||||
.ok_or_else(|| err!(Ldap(error!("LDAP URI is not configured."))))?;
|
||||
|
||||
debug!(?uri, "LDAP creating connection...");
|
||||
let (conn, mut ldap) = LdapConnAsync::new(uri.as_str())
|
||||
.await
|
||||
.map_err(|e| err!(Ldap(error!(?user_id, "LDAP connection setup error: {e}"))))?;
|
||||
|
||||
let driver = self.services.server.runtime().spawn(async move {
|
||||
match conn.drive().await {
|
||||
| Err(e) => error!("LDAP connection error: {e}"),
|
||||
| Ok(()) => debug!("LDAP connection completed."),
|
||||
}
|
||||
});
|
||||
|
||||
match (&config.bind_dn, &config.bind_password_file) {
|
||||
| (Some(bind_dn), Some(bind_password_file)) => {
|
||||
let bind_pw = String::from_utf8(std::fs::read(bind_password_file)?)?;
|
||||
ldap.simple_bind(bind_dn, bind_pw.trim())
|
||||
.await
|
||||
.and_then(ldap3::LdapResult::success)
|
||||
.map_err(|e| err!(Ldap(error!("LDAP bind error: {e}"))))?;
|
||||
},
|
||||
| (..) => {},
|
||||
}
|
||||
|
||||
let attr = [&config.uid_attribute, &config.name_attribute];
|
||||
|
||||
let user_filter = &config
|
||||
.filter
|
||||
.replace("{username}", &lowercased_localpart);
|
||||
|
||||
let (entries, _result) = ldap
|
||||
.search(&config.base_dn, Scope::Subtree, user_filter, &attr)
|
||||
.await
|
||||
.and_then(ldap3::SearchResult::success)
|
||||
.inspect(|(entries, result)| trace!(?entries, ?result, "LDAP Search"))
|
||||
.map_err(|e| err!(Ldap(error!(?attr, ?user_filter, "LDAP search error: {e}"))))?;
|
||||
|
||||
let mut dns = entries
|
||||
.into_iter()
|
||||
.filter_map(|entry| {
|
||||
let search_entry = SearchEntry::construct(entry);
|
||||
debug!(?search_entry, "LDAP search entry");
|
||||
search_entry
|
||||
.attrs
|
||||
.get(&config.uid_attribute)
|
||||
.into_iter()
|
||||
.chain(search_entry.attrs.get(&config.name_attribute))
|
||||
.any(|ids| ids.contains(&localpart) || ids.contains(&lowercased_localpart))
|
||||
.then_some((search_entry.dn, false))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
if !config.admin_base_dn.is_empty() {
|
||||
let admin_base_dn = if config.admin_base_dn.is_empty() {
|
||||
&config.base_dn
|
||||
} else {
|
||||
&config.admin_base_dn
|
||||
};
|
||||
|
||||
let admin_filter = &config
|
||||
.admin_filter
|
||||
.replace("{username}", &lowercased_localpart);
|
||||
|
||||
let (admin_entries, _result) = ldap
|
||||
.search(admin_base_dn, Scope::Subtree, admin_filter, &attr)
|
||||
.await
|
||||
.and_then(ldap3::SearchResult::success)
|
||||
.inspect(|(entries, result)| trace!(?entries, ?result, "LDAP Admin Search"))
|
||||
.map_err(|e| {
|
||||
err!(Ldap(error!(?attr, ?user_filter, "Ldap admin search error: {e}")))
|
||||
})?;
|
||||
|
||||
let mut admin_dns = admin_entries
|
||||
.into_iter()
|
||||
.filter_map(|entry| {
|
||||
let search_entry = SearchEntry::construct(entry);
|
||||
debug!(?search_entry, "LDAP search entry");
|
||||
search_entry
|
||||
.attrs
|
||||
.get(&config.uid_attribute)
|
||||
.into_iter()
|
||||
.chain(search_entry.attrs.get(&config.name_attribute))
|
||||
.any(|ids| ids.contains(&localpart) || ids.contains(&lowercased_localpart))
|
||||
.then_some((search_entry.dn, true))
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
dns.append(&mut admin_dns);
|
||||
}
|
||||
|
||||
ldap.unbind()
|
||||
.await
|
||||
.map_err(|e| err!(Ldap(error!("LDAP unbind error: {e}"))))?;
|
||||
|
||||
driver.await.log_err().ok();
|
||||
|
||||
Ok(dns)
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn auth_ldap(&self, user_dn: &str, password: &str) -> Result {
|
||||
let config = &self.services.server.config.ldap;
|
||||
let uri = config
|
||||
.uri
|
||||
.as_ref()
|
||||
.ok_or_else(|| err!(Ldap(error!("LDAP URI is not configured."))))?;
|
||||
|
||||
debug!(?uri, "LDAP creating connection...");
|
||||
let (conn, mut ldap) = LdapConnAsync::new(uri.as_str())
|
||||
.await
|
||||
.map_err(|e| err!(Ldap(error!(?user_dn, "LDAP connection setup error: {e}"))))?;
|
||||
|
||||
let driver = self.services.server.runtime().spawn(async move {
|
||||
match conn.drive().await {
|
||||
| Err(e) => error!("LDAP connection error: {e}"),
|
||||
| Ok(()) => debug!("LDAP connection completed."),
|
||||
}
|
||||
});
|
||||
|
||||
ldap.simple_bind(user_dn, password)
|
||||
.await
|
||||
.and_then(ldap3::LdapResult::success)
|
||||
.map_err(|e| err!(Request(Forbidden(debug_error!("LDAP authentication error: {e}")))))?;
|
||||
|
||||
ldap.unbind()
|
||||
.await
|
||||
.map_err(|e| err!(Ldap(error!("LDAP unbind error: {e}"))))?;
|
||||
|
||||
driver.await.log_err().ok();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
87
src/service/users/profile.rs
Normal file
87
src/service/users/profile.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use futures::{Stream, StreamExt, TryFutureExt};
|
||||
use ruma::UserId;
|
||||
use tuwunel_core::{Result, implement, utils::stream::TryIgnore};
|
||||
use tuwunel_database::{Deserialized, Ignore, Interfix, Json};
|
||||
|
||||
/// Gets a specific user profile key
|
||||
#[implement(super::Service)]
|
||||
pub async fn profile_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
profile_key: &str,
|
||||
) -> Result<serde_json::Value> {
|
||||
let key = (user_id, profile_key);
|
||||
self.db
|
||||
.useridprofilekey_value
|
||||
.qry(&key)
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
/// Gets all the user's profile keys and values in an iterator
|
||||
#[implement(super::Service)]
|
||||
pub fn all_profile_keys<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
) -> impl Stream<Item = (String, serde_json::Value)> + 'a + Send {
|
||||
type KeyVal = ((Ignore, String), serde_json::Value);
|
||||
|
||||
let prefix = (user_id, Interfix);
|
||||
self.db
|
||||
.useridprofilekey_value
|
||||
.stream_prefix(&prefix)
|
||||
.ignore_err()
|
||||
.map(|((_, key), val): KeyVal| (key, val))
|
||||
}
|
||||
|
||||
/// Sets a new profile key value, removes the key if value is None
|
||||
#[implement(super::Service)]
|
||||
pub fn set_profile_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
profile_key: &str,
|
||||
profile_key_value: Option<serde_json::Value>,
|
||||
) {
|
||||
// TODO: insert to the stable MSC4175 key when it's stable
|
||||
let key = (user_id, profile_key);
|
||||
|
||||
if let Some(value) = profile_key_value {
|
||||
self.db
|
||||
.useridprofilekey_value
|
||||
.put(key, Json(value));
|
||||
} else {
|
||||
self.db.useridprofilekey_value.del(key);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the timezone of a user.
|
||||
#[implement(super::Service)]
|
||||
pub async fn timezone(&self, user_id: &UserId) -> Result<String> {
|
||||
// TODO: transparently migrate unstable key usage to the stable key once MSC4133
|
||||
// and MSC4175 are stable, likely a remove/insert in this block.
|
||||
|
||||
// first check the unstable prefix then check the stable prefix
|
||||
let unstable_key = (user_id, "us.cloke.msc4175.tz");
|
||||
let stable_key = (user_id, "m.tz");
|
||||
self.db
|
||||
.useridprofilekey_value
|
||||
.qry(&unstable_key)
|
||||
.or_else(|_| self.db.useridprofilekey_value.qry(&stable_key))
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
/// Sets a new timezone or removes it if timezone is None.
|
||||
#[implement(super::Service)]
|
||||
pub fn set_timezone(&self, user_id: &UserId, timezone: Option<String>) {
|
||||
// TODO: insert to the stable MSC4175 key when it's stable
|
||||
let key = (user_id, "us.cloke.msc4175.tz");
|
||||
|
||||
if let Some(timezone) = timezone {
|
||||
self.db
|
||||
.useridprofilekey_value
|
||||
.put_raw(key, &timezone);
|
||||
} else {
|
||||
self.db.useridprofilekey_value.del(key);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user