From 8a83c23537f8629ef4c0ea7fc5d4a2f51c4ebfb5 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 6 Jun 2025 09:33:50 +0000 Subject: [PATCH] Split tuwunel_service::users into units. Signed-off-by: Jason Volk --- src/service/users/device.rs | 219 ++++++++ src/service/users/keys.rs | 515 +++++++++++++++++++ src/service/users/ldap.rs | 153 ++++++ src/service/users/mod.rs | 959 +---------------------------------- src/service/users/profile.rs | 87 ++++ 5 files changed, 991 insertions(+), 942 deletions(-) create mode 100644 src/service/users/device.rs create mode 100644 src/service/users/keys.rs create mode 100644 src/service/users/ldap.rs create mode 100644 src/service/users/profile.rs diff --git a/src/service/users/device.rs b/src/service/users/device.rs new file mode 100644 index 00000000..96403ff5 --- /dev/null +++ b/src/service/users/device.rs @@ -0,0 +1,219 @@ +use std::sync::Arc; + +use futures::{Stream, StreamExt}; +use ruma::{ + DeviceId, MilliSecondsSinceUnixEpoch, UserId, api::client::device::Device, + events::AnyToDeviceEvent, serde::Raw, +}; +use serde_json::json; +use tuwunel_core::{ + Err, Result, at, implement, + utils::{self, ReadyExt, stream::TryIgnore}, +}; +use tuwunel_database::{Deserialized, Ignore, Interfix, Json, Map}; + +/// Adds a new device to a user. +#[implement(super::Service)] +pub async fn create_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + token: &str, + initial_device_display_name: Option, + client_ip: Option, +) -> Result { + if !self.exists(user_id).await { + return Err!(Request(InvalidParam(error!( + "Called create_device for non-existent user {user_id}" + )))); + } + + let key = (user_id, device_id); + let val = Device { + device_id: device_id.into(), + display_name: initial_device_display_name, + last_seen_ip: client_ip, + last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), + }; + + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + self.db.userdeviceid_metadata.put(key, Json(val)); + self.set_token(user_id, device_id, token).await +} + +/// Removes a device from a user. +#[implement(super::Service)] +pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) { + let userdeviceid = (user_id, device_id); + + // Remove tokens + if let Ok(old_token) = self + .db + .userdeviceid_token + .qry(&userdeviceid) + .await + { + self.db.userdeviceid_token.del(userdeviceid); + self.db.token_userdeviceid.remove(&old_token); + } + + // Remove todevice events + let prefix = (user_id, device_id, Interfix); + self.db + .todeviceid_events + .keys_prefix_raw(&prefix) + .ignore_err() + .ready_for_each(|key| self.db.todeviceid_events.remove(key)) + .await; + + // TODO: Remove onetimekeys + + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + + self.db.userdeviceid_metadata.del(userdeviceid); + self.mark_device_key_update(user_id).await; +} + +/// Returns an iterator over all device ids of this user. +#[implement(super::Service)] +pub fn all_device_ids<'a>( + &'a self, + user_id: &'a UserId, +) -> impl Stream + Send + 'a { + let prefix = (user_id, Interfix); + self.db + .userdeviceid_metadata + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, device_id): (Ignore, &DeviceId)| device_id) +} + +#[implement(super::Service)] +pub async fn add_to_device_event( + &self, + sender: &UserId, + target_user_id: &UserId, + target_device_id: &DeviceId, + event_type: &str, + content: serde_json::Value, +) { + let count = self.services.globals.next_count().unwrap(); + + let key = (target_user_id, target_device_id, count); + self.db.todeviceid_events.put( + key, + Json(json!({ + "type": event_type, + "sender": sender, + "content": content, + })), + ); +} + +#[implement(super::Service)] +pub fn get_to_device_events<'a>( + &'a self, + user_id: &'a UserId, + device_id: &'a DeviceId, + since: Option, + to: Option, +) -> impl Stream> + Send + 'a { + type Key<'a> = (&'a UserId, &'a DeviceId, u64); + + let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1))); + + self.db + .todeviceid_events + .stream_from(&from) + .ignore_err() + .ready_take_while(move |((user_id_, device_id_, count), _): &(Key<'_>, _)| { + user_id == *user_id_ && device_id == *device_id_ && to.is_none_or(|to| *count <= to) + }) + .map(at!(1)) +} + +#[implement(super::Service)] +pub async fn remove_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + until: Until, +) where + Until: Into> + Send, +{ + type Key<'a> = (&'a UserId, &'a DeviceId, u64); + + let until = until.into().unwrap_or(u64::MAX); + let from = (user_id, device_id, until); + self.db + .todeviceid_events + .rev_keys_from(&from) + .ignore_err() + .ready_take_while(move |(user_id_, device_id_, _): &Key<'_>| { + user_id == *user_id_ && device_id == *device_id_ + }) + .ready_for_each(|key: Key<'_>| { + self.db.todeviceid_events.del(key); + }) + .await; +} + +#[implement(super::Service)] +pub async fn update_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + device: &Device, +) -> Result { + increment(&self.db.userid_devicelistversion, user_id.as_bytes()); + + let key = (user_id, device_id); + self.db + .userdeviceid_metadata + .put(key, Json(device)); + + Ok(()) +} + +/// Get device metadata. +#[implement(super::Service)] +pub async fn get_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, +) -> Result { + self.db + .userdeviceid_metadata + .qry(&(user_id, device_id)) + .await + .deserialized() +} + +#[implement(super::Service)] +pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { + self.db + .userid_devicelistversion + .get(user_id) + .await + .deserialized() +} + +#[implement(super::Service)] +pub fn all_devices_metadata<'a>( + &'a self, + user_id: &'a UserId, +) -> impl Stream + Send + 'a { + let key = (user_id, Interfix); + self.db + .userdeviceid_metadata + .stream_prefix(&key) + .ignore_err() + .map(|(_, val): (Ignore, Device)| val) +} + +//TODO: this is an ABA +fn increment(db: &Arc, key: &[u8]) { + let old = db.get_blocking(key); + let new = utils::increment(old.ok().as_deref()); + db.insert(key, new); +} diff --git a/src/service/users/keys.rs b/src/service/users/keys.rs new file mode 100644 index 00000000..d940d200 --- /dev/null +++ b/src/service/users/keys.rs @@ -0,0 +1,515 @@ +use std::{collections::BTreeMap, mem}; + +use futures::{Stream, StreamExt, TryFutureExt}; +use ruma::{ + DeviceId, KeyId, OneTimeKeyAlgorithm, OneTimeKeyId, OneTimeKeyName, OwnedKeyId, RoomId, UInt, + UserId, + api::client::error::ErrorKind, + encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, + serde::Raw, +}; +use tuwunel_core::{ + Err, Error, Result, err, implement, + utils::{ReadyExt, stream::TryIgnore, string::Unquoted}, +}; +use tuwunel_database::{Deserialized, Ignore, Json}; + +#[implement(super::Service)] +pub async fn add_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + one_time_key_key: &KeyId, + one_time_key_value: &Raw, +) -> Result { + // All devices have metadata + // Only existing devices should be able to call this, but we shouldn't assert + // either... + let key = (user_id, device_id); + if self + .db + .userdeviceid_metadata + .qry(&key) + .await + .is_err() + { + return Err!(Database(error!( + ?user_id, + ?device_id, + "User does not exist or device has no metadata." + ))); + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + key.push(0xFF); + // TODO: Use DeviceKeyId::to_string when it's available (and update everything, + // because there are no wrapping quotation marks anymore) + key.extend_from_slice( + serde_json::to_string(one_time_key_key) + .expect("DeviceKeyId::to_string always works") + .as_bytes(), + ); + + self.db + .onetimekeyid_onetimekeys + .raw_put(key, Json(one_time_key_value)); + + let count = self.services.globals.next_count().unwrap(); + self.db + .userid_lastonetimekeyupdate + .raw_put(user_id, count); + + Ok(()) +} + +#[implement(super::Service)] +pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 { + self.db + .userid_lastonetimekeyupdate + .get(user_id) + .await + .deserialized() + .unwrap_or(0) +} + +#[implement(super::Service)] +pub async fn take_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + key_algorithm: &OneTimeKeyAlgorithm, +) -> Result<(OwnedKeyId, Raw)> { + let count = self.services.globals.next_count()?.to_be_bytes(); + self.db + .userid_lastonetimekeyupdate + .insert(user_id, count); + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.push(b'"'); // Annoying quotation mark + prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); + prefix.push(b':'); + + let one_time_key = self + .db + .onetimekeyid_onetimekeys + .raw_stream_prefix(&prefix) + .ignore_err() + .map(|(key, val)| { + self.db.onetimekeyid_onetimekeys.remove(key); + + let key = key + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid."))) + .unwrap(); + + let key = serde_json::from_slice(key) + .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}"))) + .unwrap(); + + let val = serde_json::from_slice(val) + .map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}"))) + .unwrap(); + + (key, val) + }) + .next() + .await; + + one_time_key.ok_or_else(|| err!(Request(NotFound("No one-time-key found")))) +} + +#[implement(super::Service)] +pub async fn count_one_time_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, +) -> BTreeMap { + type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore); + + let mut algorithm_counts = BTreeMap::::new(); + let query = (user_id, device_id); + self.db + .onetimekeyid_onetimekeys + .stream_prefix(&query) + .ignore_err() + .ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| { + let one_time_key_id: &OneTimeKeyId = device_key_id + .as_str() + .try_into() + .expect("Invalid DeviceKeyID in database"); + + let count: &mut UInt = algorithm_counts + .entry(one_time_key_id.algorithm()) + .or_default(); + + *count = count.saturating_add(1_u32.into()); + }) + .await; + + algorithm_counts +} + +#[implement(super::Service)] +pub async fn add_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + device_keys: &Raw, +) { + let key = (user_id, device_id); + + self.db.keyid_key.put(key, Json(device_keys)); + self.mark_device_key_update(user_id).await; +} + +#[implement(super::Service)] +pub async fn add_cross_signing_keys( + &self, + user_id: &UserId, + master_key: &Option>, + self_signing_key: &Option>, + user_signing_key: &Option>, + notify: bool, +) -> Result { + // TODO: Check signatures + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + if let Some(master_key) = master_key { + let (master_key_key, _) = parse_master_key(user_id, master_key)?; + + self.db + .keyid_key + .insert(&master_key_key, master_key.json().get().as_bytes()); + + self.db + .userid_masterkeyid + .insert(user_id.as_bytes(), &master_key_key); + } + + // Self-signing key + if let Some(self_signing_key) = self_signing_key { + let mut self_signing_key_ids = self_signing_key + .deserialize() + .map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))? + .keys + .into_values(); + + let self_signing_key_id = self_signing_key_ids + .next() + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained no key.", + ))?; + + if self_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained more than one key.", + )); + } + + let mut self_signing_key_key = prefix.clone(); + self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); + + self.db + .keyid_key + .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes()); + + self.db + .userid_selfsigningkeyid + .insert(user_id.as_bytes(), &self_signing_key_key); + } + + // User-signing key + if let Some(user_signing_key) = user_signing_key { + let user_signing_key_id = parse_user_signing_key(user_signing_key)?; + + let user_signing_key_key = (user_id, &user_signing_key_id); + self.db + .keyid_key + .put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes()); + + self.db + .userid_usersigningkeyid + .raw_put(user_id, user_signing_key_key); + } + + if notify { + self.mark_device_key_update(user_id).await; + } + + Ok(()) +} + +#[implement(super::Service)] +pub async fn sign_key( + &self, + target_id: &UserId, + key_id: &str, + signature: (String, String), + sender_id: &UserId, +) -> Result { + let key = (target_id, key_id); + + let mut cross_signing_key: serde_json::Value = self + .db + .keyid_key + .qry(&key) + .await + .map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))? + .deserialized() + .map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?; + + let signatures = cross_signing_key + .get_mut("signatures") + .ok_or_else(|| err!(Database(debug_warn!("key in keyid_key has no signatures field"))))? + .as_object_mut() + .ok_or_else(|| { + err!(Database(debug_warn!("key in keyid_key has invalid signatures field."))) + })? + .entry(sender_id.to_string()) + .or_insert_with(|| serde_json::Map::new().into()); + + signatures + .as_object_mut() + .ok_or_else(|| { + err!(Database(debug_warn!("signatures in keyid_key for a user is invalid."))) + })? + .insert(signature.0, signature.1.into()); + + let key = (target_id, key_id); + self.db + .keyid_key + .put(key, Json(cross_signing_key)); + + self.mark_device_key_update(target_id).await; + + Ok(()) +} + +#[implement(super::Service)] +#[inline] +pub fn keys_changed<'a>( + &'a self, + user_id: &'a UserId, + from: u64, + to: Option, +) -> impl Stream + Send + 'a { + self.keys_changed_user_or_room(user_id.as_str(), from, to) + .map(|(user_id, ..)| user_id) +} + +#[implement(super::Service)] +#[inline] +pub fn room_keys_changed<'a>( + &'a self, + room_id: &'a RoomId, + from: u64, + to: Option, +) -> impl Stream + Send + 'a { + self.keys_changed_user_or_room(room_id.as_str(), from, to) +} + +#[implement(super::Service)] +fn keys_changed_user_or_room<'a>( + &'a self, + user_or_room_id: &'a str, + from: u64, + to: Option, +) -> impl Stream + Send + 'a { + type KeyVal<'a> = ((&'a str, u64), &'a UserId); + + let to = to.unwrap_or(u64::MAX); + let start = (user_or_room_id, from.saturating_add(1)); + self.db + .keychangeid_userid + .stream_from(&start) + .ignore_err() + .ready_take_while(move |((prefix, count), _): &KeyVal<'_>| { + *prefix == user_or_room_id && *count <= to + }) + .map(|((_, count), user_id): KeyVal<'_>| (user_id, count)) +} + +#[implement(super::Service)] +pub async fn mark_device_key_update(&self, user_id: &UserId) { + let count = self.services.globals.next_count().unwrap(); + + self.services + .state_cache + .rooms_joined(user_id) + // Don't send key updates to unencrypted rooms + .filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id)) + .ready_for_each(|room_id| { + let key = (room_id, count); + self.db.keychangeid_userid.put_raw(key, user_id); + }) + .await; + + let key = (user_id, count); + self.db.keychangeid_userid.put_raw(key, user_id); +} + +#[implement(super::Service)] +pub async fn get_device_keys<'a>( + &'a self, + user_id: &'a UserId, + device_id: &DeviceId, +) -> Result> { + let key_id = (user_id, device_id); + self.db + .keyid_key + .qry(&key_id) + .await + .deserialized() +} + +#[implement(super::Service)] +pub async fn get_key( + &self, + key_id: &[u8], + sender_user: Option<&UserId>, + user_id: &UserId, + allowed_signatures: &F, +) -> Result> +where + F: Fn(&UserId) -> bool + Send + Sync, +{ + let key: serde_json::Value = self + .db + .keyid_key + .get(key_id) + .await + .deserialized()?; + + let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?; + let raw_value = serde_json::value::to_raw_value(&cleaned)?; + Ok(Raw::from_json(raw_value)) +} + +#[implement(super::Service)] +pub async fn get_master_key( + &self, + sender_user: Option<&UserId>, + user_id: &UserId, + allowed_signatures: &F, +) -> Result> +where + F: Fn(&UserId) -> bool + Send + Sync, +{ + let key_id = self.db.userid_masterkeyid.get(user_id).await?; + + self.get_key(&key_id, sender_user, user_id, allowed_signatures) + .await +} + +#[implement(super::Service)] +pub async fn get_self_signing_key( + &self, + sender_user: Option<&UserId>, + user_id: &UserId, + allowed_signatures: &F, +) -> Result> +where + F: Fn(&UserId) -> bool + Send + Sync, +{ + let key_id = self + .db + .userid_selfsigningkeyid + .get(user_id) + .await?; + + self.get_key(&key_id, sender_user, user_id, allowed_signatures) + .await +} + +#[implement(super::Service)] +pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result> { + self.db + .userid_usersigningkeyid + .get(user_id) + .and_then(|key_id| self.db.keyid_key.get(&*key_id)) + .await + .deserialized() +} + +pub fn parse_master_key( + user_id: &UserId, + master_key: &Raw, +) -> Result<(Vec, CrossSigningKey)> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let master_key = master_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; + let mut master_key_ids = master_key.keys.values(); + let master_key_id = master_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; + if master_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained more than one key.", + )); + } + let mut master_key_key = prefix.clone(); + master_key_key.extend_from_slice(master_key_id.as_bytes()); + Ok((master_key_key, master_key)) +} + +pub(super) fn parse_user_signing_key(user_signing_key: &Raw) -> Result { + let mut user_signing_key_ids = user_signing_key + .deserialize() + .map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))? + .keys + .into_values(); + + let user_signing_key_id = user_signing_key_ids + .next() + .ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?; + + if user_signing_key_ids.next().is_some() { + return Err!(Request(InvalidParam("User signing key contained more than one key."))); + } + + Ok(user_signing_key_id) +} + +/// Ensure that a user only sees signatures from themselves and the target user +fn clean_signatures( + mut cross_signing_key: serde_json::Value, + sender_user: Option<&UserId>, + user_id: &UserId, + allowed_signatures: &F, +) -> Result +where + F: Fn(&UserId) -> bool + Send + Sync, +{ + if let Some(signatures) = cross_signing_key + .get_mut("signatures") + .and_then(|v| v.as_object_mut()) + { + // Don't allocate for the full size of the current signatures, but require + // at most one resize if nothing is dropped + let new_capacity = signatures.len() / 2; + for (user, signature) in + mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) + { + let sid = <&UserId>::try_from(user.as_str()) + .map_err(|_| Error::bad_database("Invalid user ID in database."))?; + if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) { + signatures.insert(user, signature); + } + } + } + + Ok(cross_signing_key) +} diff --git a/src/service/users/ldap.rs b/src/service/users/ldap.rs new file mode 100644 index 00000000..e3449ce7 --- /dev/null +++ b/src/service/users/ldap.rs @@ -0,0 +1,153 @@ +#![cfg(feature = "ldap")] + +use itertools::Itertools; +use ldap3::{LdapConnAsync, Scope, SearchEntry}; +use ruma::UserId; +use tuwunel_core::{Result, debug, err, error, implement, result::LogErr, trace}; + +/// Performs a LDAP search for the given user. +/// +/// Returns the list of matching users, with a boolean for each result set +/// to true if the user is an admin. +#[implement(super::Service)] +pub async fn search_ldap(&self, user_id: &UserId) -> Result> { + let localpart = user_id.localpart().to_owned(); + let lowercased_localpart = localpart.to_lowercase(); + + let config = &self.services.server.config.ldap; + let uri = config + .uri + .as_ref() + .ok_or_else(|| err!(Ldap(error!("LDAP URI is not configured."))))?; + + debug!(?uri, "LDAP creating connection..."); + let (conn, mut ldap) = LdapConnAsync::new(uri.as_str()) + .await + .map_err(|e| err!(Ldap(error!(?user_id, "LDAP connection setup error: {e}"))))?; + + let driver = self.services.server.runtime().spawn(async move { + match conn.drive().await { + | Err(e) => error!("LDAP connection error: {e}"), + | Ok(()) => debug!("LDAP connection completed."), + } + }); + + match (&config.bind_dn, &config.bind_password_file) { + | (Some(bind_dn), Some(bind_password_file)) => { + let bind_pw = String::from_utf8(std::fs::read(bind_password_file)?)?; + ldap.simple_bind(bind_dn, bind_pw.trim()) + .await + .and_then(ldap3::LdapResult::success) + .map_err(|e| err!(Ldap(error!("LDAP bind error: {e}"))))?; + }, + | (..) => {}, + } + + let attr = [&config.uid_attribute, &config.name_attribute]; + + let user_filter = &config + .filter + .replace("{username}", &lowercased_localpart); + + let (entries, _result) = ldap + .search(&config.base_dn, Scope::Subtree, user_filter, &attr) + .await + .and_then(ldap3::SearchResult::success) + .inspect(|(entries, result)| trace!(?entries, ?result, "LDAP Search")) + .map_err(|e| err!(Ldap(error!(?attr, ?user_filter, "LDAP search error: {e}"))))?; + + let mut dns = entries + .into_iter() + .filter_map(|entry| { + let search_entry = SearchEntry::construct(entry); + debug!(?search_entry, "LDAP search entry"); + search_entry + .attrs + .get(&config.uid_attribute) + .into_iter() + .chain(search_entry.attrs.get(&config.name_attribute)) + .any(|ids| ids.contains(&localpart) || ids.contains(&lowercased_localpart)) + .then_some((search_entry.dn, false)) + }) + .collect_vec(); + + if !config.admin_base_dn.is_empty() { + let admin_base_dn = if config.admin_base_dn.is_empty() { + &config.base_dn + } else { + &config.admin_base_dn + }; + + let admin_filter = &config + .admin_filter + .replace("{username}", &lowercased_localpart); + + let (admin_entries, _result) = ldap + .search(admin_base_dn, Scope::Subtree, admin_filter, &attr) + .await + .and_then(ldap3::SearchResult::success) + .inspect(|(entries, result)| trace!(?entries, ?result, "LDAP Admin Search")) + .map_err(|e| { + err!(Ldap(error!(?attr, ?user_filter, "Ldap admin search error: {e}"))) + })?; + + let mut admin_dns = admin_entries + .into_iter() + .filter_map(|entry| { + let search_entry = SearchEntry::construct(entry); + debug!(?search_entry, "LDAP search entry"); + search_entry + .attrs + .get(&config.uid_attribute) + .into_iter() + .chain(search_entry.attrs.get(&config.name_attribute)) + .any(|ids| ids.contains(&localpart) || ids.contains(&lowercased_localpart)) + .then_some((search_entry.dn, true)) + }) + .collect_vec(); + + dns.append(&mut admin_dns); + } + + ldap.unbind() + .await + .map_err(|e| err!(Ldap(error!("LDAP unbind error: {e}"))))?; + + driver.await.log_err().ok(); + + Ok(dns) +} + +#[implement(super::Service)] +pub async fn auth_ldap(&self, user_dn: &str, password: &str) -> Result { + let config = &self.services.server.config.ldap; + let uri = config + .uri + .as_ref() + .ok_or_else(|| err!(Ldap(error!("LDAP URI is not configured."))))?; + + debug!(?uri, "LDAP creating connection..."); + let (conn, mut ldap) = LdapConnAsync::new(uri.as_str()) + .await + .map_err(|e| err!(Ldap(error!(?user_dn, "LDAP connection setup error: {e}"))))?; + + let driver = self.services.server.runtime().spawn(async move { + match conn.drive().await { + | Err(e) => error!("LDAP connection error: {e}"), + | Ok(()) => debug!("LDAP connection completed."), + } + }); + + ldap.simple_bind(user_dn, password) + .await + .and_then(ldap3::LdapResult::success) + .map_err(|e| err!(Request(Forbidden(debug_error!("LDAP authentication error: {e}")))))?; + + ldap.unbind() + .await + .map_err(|e| err!(Ldap(error!("LDAP unbind error: {e}"))))?; + + driver.await.log_err().ok(); + + Ok(()) +} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index c8c90792..5992f823 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,23 +1,23 @@ -use std::{collections::BTreeMap, mem, sync::Arc}; +mod device; +mod keys; +mod ldap; +mod profile; + +use std::sync::Arc; use futures::{Stream, StreamExt, TryFutureExt}; use ruma::{ - DeviceId, KeyId, MilliSecondsSinceUnixEpoch, OneTimeKeyAlgorithm, OneTimeKeyId, - OneTimeKeyName, OwnedDeviceId, OwnedKeyId, OwnedMxcUri, OwnedUserId, RoomId, UInt, UserId, - api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{ - AnyToDeviceEvent, GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent, - }, - serde::Raw, + DeviceId, OwnedDeviceId, OwnedMxcUri, OwnedUserId, UserId, + api::client::filter::FilterDefinition, + events::{GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent}, }; -use serde_json::json; use tuwunel_core::{ - Err, Error, Result, Server, at, debug_warn, err, trace, - utils::{self, ReadyExt, stream::TryIgnore, string::Unquoted}, + Err, Result, Server, debug_warn, err, trace, + utils::{self, ReadyExt, stream::TryIgnore}, }; -use tuwunel_database::{Deserialized, Ignore, Interfix, Json, Map}; +use tuwunel_database::{Deserialized, Json, Map}; +pub use self::keys::parse_master_key; use crate::{Dep, account_data, admin, globals, rooms}; pub struct Service { @@ -133,7 +133,7 @@ impl Service { user_id: &UserId, password: Option<&str>, origin: Option<&str>, - ) -> Result<()> { + ) -> Result { origin.map_or_else( || self.db.userid_origin.insert(user_id, "password"), |origin| self.db.userid_origin.insert(user_id, origin), @@ -142,7 +142,7 @@ impl Service { } /// Deactivate account - pub async fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + pub async fn deactivate_account(&self, user_id: &UserId) -> Result { // Remove all associated devices self.all_device_ids(user_id) .for_each(|device_id| self.remove_device(user_id, device_id)) @@ -243,7 +243,7 @@ impl Service { } /// Hash and set the user's password to the Argon2 hash - pub async fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + pub async fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result { // Cannot change the password of a LDAP user. There are two special cases : // - a `None` password can be used to deactivate a LDAP user // - a "*" password is used as the default password of an active LDAP user @@ -337,79 +337,6 @@ impl Service { } } - /// Adds a new device to a user. - pub async fn create_device( - &self, - user_id: &UserId, - device_id: &DeviceId, - token: &str, - initial_device_display_name: Option, - client_ip: Option, - ) -> Result<()> { - if !self.exists(user_id).await { - return Err!(Request(InvalidParam(error!( - "Called create_device for non-existent user {user_id}" - )))); - } - - let key = (user_id, device_id); - let val = Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: client_ip, - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }; - - increment(&self.db.userid_devicelistversion, user_id.as_bytes()); - self.db.userdeviceid_metadata.put(key, Json(val)); - self.set_token(user_id, device_id, token).await - } - - /// Removes a device from a user. - pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) { - let userdeviceid = (user_id, device_id); - - // Remove tokens - if let Ok(old_token) = self - .db - .userdeviceid_token - .qry(&userdeviceid) - .await - { - self.db.userdeviceid_token.del(userdeviceid); - self.db.token_userdeviceid.remove(&old_token); - } - - // Remove todevice events - let prefix = (user_id, device_id, Interfix); - self.db - .todeviceid_events - .keys_prefix_raw(&prefix) - .ignore_err() - .ready_for_each(|key| self.db.todeviceid_events.remove(key)) - .await; - - // TODO: Remove onetimekeys - - increment(&self.db.userid_devicelistversion, user_id.as_bytes()); - - self.db.userdeviceid_metadata.del(userdeviceid); - self.mark_device_key_update(user_id).await; - } - - /// Returns an iterator over all device ids of this user. - pub fn all_device_ids<'a>( - &'a self, - user_id: &'a UserId, - ) -> impl Stream + Send + 'a { - let prefix = (user_id, Interfix); - self.db - .userdeviceid_metadata - .keys_prefix(&prefix) - .ignore_err() - .map(|(_, device_id): (Ignore, &DeviceId)| device_id) - } - pub async fn get_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result { let key = (user_id, device_id); self.db @@ -420,12 +347,7 @@ impl Service { } /// Replaces the access token of one device. - pub async fn set_token( - &self, - user_id: &UserId, - device_id: &DeviceId, - token: &str, - ) -> Result<()> { + pub async fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result { let key = (user_id, device_id); if self .db @@ -454,536 +376,6 @@ impl Service { Ok(()) } - pub async fn add_one_time_key( - &self, - user_id: &UserId, - device_id: &DeviceId, - one_time_key_key: &KeyId, - one_time_key_value: &Raw, - ) -> Result { - // All devices have metadata - // Only existing devices should be able to call this, but we shouldn't assert - // either... - let key = (user_id, device_id); - if self - .db - .userdeviceid_metadata - .qry(&key) - .await - .is_err() - { - return Err!(Database(error!( - ?user_id, - ?device_id, - "User does not exist or device has no metadata." - ))); - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xFF); - key.extend_from_slice(device_id.as_bytes()); - key.push(0xFF); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.db - .onetimekeyid_onetimekeys - .raw_put(key, Json(one_time_key_value)); - - let count = self.services.globals.next_count().unwrap(); - self.db - .userid_lastonetimekeyupdate - .raw_put(user_id, count); - - Ok(()) - } - - pub async fn last_one_time_keys_update(&self, user_id: &UserId) -> u64 { - self.db - .userid_lastonetimekeyupdate - .get(user_id) - .await - .deserialized() - .unwrap_or(0) - } - - pub async fn take_one_time_key( - &self, - user_id: &UserId, - device_id: &DeviceId, - key_algorithm: &OneTimeKeyAlgorithm, - ) -> Result<(OwnedKeyId, Raw)> { - let count = self.services.globals.next_count()?.to_be_bytes(); - self.db - .userid_lastonetimekeyupdate - .insert(user_id, count); - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xFF); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); - - let one_time_key = self - .db - .onetimekeyid_onetimekeys - .raw_stream_prefix(&prefix) - .ignore_err() - .map(|(key, val)| { - self.db.onetimekeyid_onetimekeys.remove(key); - - let key = key - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid."))) - .unwrap(); - - let key = serde_json::from_slice(key) - .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}"))) - .unwrap(); - - let val = serde_json::from_slice(val) - .map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}"))) - .unwrap(); - - (key, val) - }) - .next() - .await; - - one_time_key.ok_or_else(|| err!(Request(NotFound("No one-time-key found")))) - } - - pub async fn count_one_time_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> BTreeMap { - type KeyVal<'a> = ((Ignore, Ignore, &'a Unquoted), Ignore); - - let mut algorithm_counts = BTreeMap::::new(); - let query = (user_id, device_id); - self.db - .onetimekeyid_onetimekeys - .stream_prefix(&query) - .ignore_err() - .ready_for_each(|((Ignore, Ignore, device_key_id), Ignore): KeyVal<'_>| { - let one_time_key_id: &OneTimeKeyId = device_key_id - .as_str() - .try_into() - .expect("Invalid DeviceKeyID in database"); - - let count: &mut UInt = algorithm_counts - .entry(one_time_key_id.algorithm()) - .or_default(); - - *count = count.saturating_add(1_u32.into()); - }) - .await; - - algorithm_counts - } - - pub async fn add_device_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - device_keys: &Raw, - ) { - let key = (user_id, device_id); - - self.db.keyid_key.put(key, Json(device_keys)); - self.mark_device_key_update(user_id).await; - } - - pub async fn add_cross_signing_keys( - &self, - user_id: &UserId, - master_key: &Option>, - self_signing_key: &Option>, - user_signing_key: &Option>, - notify: bool, - ) -> Result<()> { - // TODO: Check signatures - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - if let Some(master_key) = master_key { - let (master_key_key, _) = parse_master_key(user_id, master_key)?; - - self.db - .keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes()); - - self.db - .userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key); - } - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|e| err!(Request(InvalidParam("Invalid self signing key: {e:?}"))))? - .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids - .next() - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained no key.", - ))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.db - .keyid_key - .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes()); - - self.db - .userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key); - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let user_signing_key_id = parse_user_signing_key(user_signing_key)?; - - let user_signing_key_key = (user_id, &user_signing_key_id); - self.db - .keyid_key - .put_raw(user_signing_key_key, user_signing_key.json().get().as_bytes()); - - self.db - .userid_usersigningkeyid - .raw_put(user_id, user_signing_key_key); - } - - if notify { - self.mark_device_key_update(user_id).await; - } - - Ok(()) - } - - pub async fn sign_key( - &self, - target_id: &UserId, - key_id: &str, - signature: (String, String), - sender_id: &UserId, - ) -> Result { - let key = (target_id, key_id); - - let mut cross_signing_key: serde_json::Value = self - .db - .keyid_key - .qry(&key) - .await - .map_err(|_| err!(Request(InvalidParam("Tried to sign nonexistent key"))))? - .deserialized() - .map_err(|e| err!(Database(debug_warn!("key in keyid_key is invalid: {e:?}"))))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| { - err!(Database(debug_warn!("key in keyid_key has no signatures field"))) - })? - .as_object_mut() - .ok_or_else(|| { - err!(Database(debug_warn!("key in keyid_key has invalid signatures field."))) - })? - .entry(sender_id.to_string()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| { - err!(Database(debug_warn!("signatures in keyid_key for a user is invalid."))) - })? - .insert(signature.0, signature.1.into()); - - let key = (target_id, key_id); - self.db - .keyid_key - .put(key, Json(cross_signing_key)); - - self.mark_device_key_update(target_id).await; - - Ok(()) - } - - #[inline] - pub fn keys_changed<'a>( - &'a self, - user_id: &'a UserId, - from: u64, - to: Option, - ) -> impl Stream + Send + 'a { - self.keys_changed_user_or_room(user_id.as_str(), from, to) - .map(|(user_id, ..)| user_id) - } - - #[inline] - pub fn room_keys_changed<'a>( - &'a self, - room_id: &'a RoomId, - from: u64, - to: Option, - ) -> impl Stream + Send + 'a { - self.keys_changed_user_or_room(room_id.as_str(), from, to) - } - - fn keys_changed_user_or_room<'a>( - &'a self, - user_or_room_id: &'a str, - from: u64, - to: Option, - ) -> impl Stream + Send + 'a { - type KeyVal<'a> = ((&'a str, u64), &'a UserId); - - let to = to.unwrap_or(u64::MAX); - let start = (user_or_room_id, from.saturating_add(1)); - self.db - .keychangeid_userid - .stream_from(&start) - .ignore_err() - .ready_take_while(move |((prefix, count), _): &KeyVal<'_>| { - *prefix == user_or_room_id && *count <= to - }) - .map(|((_, count), user_id): KeyVal<'_>| (user_id, count)) - } - - pub async fn mark_device_key_update(&self, user_id: &UserId) { - let count = self.services.globals.next_count().unwrap(); - - self.services - .state_cache - .rooms_joined(user_id) - // Don't send key updates to unencrypted rooms - .filter(|room_id| self.services.state_accessor.is_encrypted_room(room_id)) - .ready_for_each(|room_id| { - let key = (room_id, count); - self.db.keychangeid_userid.put_raw(key, user_id); - }) - .await; - - let key = (user_id, count); - self.db.keychangeid_userid.put_raw(key, user_id); - } - - pub async fn get_device_keys<'a>( - &'a self, - user_id: &'a UserId, - device_id: &DeviceId, - ) -> Result> { - let key_id = (user_id, device_id); - self.db - .keyid_key - .qry(&key_id) - .await - .deserialized() - } - - pub async fn get_key( - &self, - key_id: &[u8], - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &F, - ) -> Result> - where - F: Fn(&UserId) -> bool + Send + Sync, - { - let key: serde_json::Value = self - .db - .keyid_key - .get(key_id) - .await - .deserialized()?; - - let cleaned = clean_signatures(key, sender_user, user_id, allowed_signatures)?; - let raw_value = serde_json::value::to_raw_value(&cleaned)?; - Ok(Raw::from_json(raw_value)) - } - - pub async fn get_master_key( - &self, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &F, - ) -> Result> - where - F: Fn(&UserId) -> bool + Send + Sync, - { - let key_id = self.db.userid_masterkeyid.get(user_id).await?; - - self.get_key(&key_id, sender_user, user_id, allowed_signatures) - .await - } - - pub async fn get_self_signing_key( - &self, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &F, - ) -> Result> - where - F: Fn(&UserId) -> bool + Send + Sync, - { - let key_id = self - .db - .userid_selfsigningkeyid - .get(user_id) - .await?; - - self.get_key(&key_id, sender_user, user_id, allowed_signatures) - .await - } - - pub async fn get_user_signing_key(&self, user_id: &UserId) -> Result> { - self.db - .userid_usersigningkeyid - .get(user_id) - .and_then(|key_id| self.db.keyid_key.get(&*key_id)) - .await - .deserialized() - } - - pub async fn add_to_device_event( - &self, - sender: &UserId, - target_user_id: &UserId, - target_device_id: &DeviceId, - event_type: &str, - content: serde_json::Value, - ) { - let count = self.services.globals.next_count().unwrap(); - - let key = (target_user_id, target_device_id, count); - self.db.todeviceid_events.put( - key, - Json(json!({ - "type": event_type, - "sender": sender, - "content": content, - })), - ); - } - - pub fn get_to_device_events<'a>( - &'a self, - user_id: &'a UserId, - device_id: &'a DeviceId, - since: Option, - to: Option, - ) -> impl Stream> + Send + 'a { - type Key<'a> = (&'a UserId, &'a DeviceId, u64); - - let from = (user_id, device_id, since.map_or(0, |since| since.saturating_add(1))); - - self.db - .todeviceid_events - .stream_from(&from) - .ignore_err() - .ready_take_while(move |((user_id_, device_id_, count), _): &(Key<'_>, _)| { - user_id == *user_id_ - && device_id == *device_id_ - && to.is_none_or(|to| *count <= to) - }) - .map(at!(1)) - } - - pub async fn remove_to_device_events( - &self, - user_id: &UserId, - device_id: &DeviceId, - until: Until, - ) where - Until: Into> + Send, - { - type Key<'a> = (&'a UserId, &'a DeviceId, u64); - - let until = until.into().unwrap_or(u64::MAX); - let from = (user_id, device_id, until); - self.db - .todeviceid_events - .rev_keys_from(&from) - .ignore_err() - .ready_take_while(move |(user_id_, device_id_, _): &Key<'_>| { - user_id == *user_id_ && device_id == *device_id_ - }) - .ready_for_each(|key: Key<'_>| { - self.db.todeviceid_events.del(key); - }) - .await; - } - - pub async fn update_device_metadata( - &self, - user_id: &UserId, - device_id: &DeviceId, - device: &Device, - ) -> Result<()> { - increment(&self.db.userid_devicelistversion, user_id.as_bytes()); - - let key = (user_id, device_id); - self.db - .userdeviceid_metadata - .put(key, Json(device)); - - Ok(()) - } - - /// Get device metadata. - pub async fn get_device_metadata( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result { - self.db - .userdeviceid_metadata - .qry(&(user_id, device_id)) - .await - .deserialized() - } - - pub async fn get_devicelist_version(&self, user_id: &UserId) -> Result { - self.db - .userid_devicelistversion - .get(user_id) - .await - .deserialized() - } - - pub fn all_devices_metadata<'a>( - &'a self, - user_id: &'a UserId, - ) -> impl Stream + Send + 'a { - let key = (user_id, Interfix); - self.db - .userdeviceid_metadata - .stream_prefix(&key) - .ignore_err() - .map(|(_, val): (Ignore, Device)| val) - } - /// Creates a new sync filter. Returns the filter id. pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> String { let filter_id = utils::random_string(4); @@ -1100,330 +492,13 @@ impl Service { Ok(user_id) } - /// Gets a specific user profile key - pub async fn profile_key( - &self, - user_id: &UserId, - profile_key: &str, - ) -> Result { - let key = (user_id, profile_key); - self.db - .useridprofilekey_value - .qry(&key) - .await - .deserialized() - } - - /// Gets all the user's profile keys and values in an iterator - pub fn all_profile_keys<'a>( - &'a self, - user_id: &'a UserId, - ) -> impl Stream + 'a + Send { - type KeyVal = ((Ignore, String), serde_json::Value); - - let prefix = (user_id, Interfix); - self.db - .useridprofilekey_value - .stream_prefix(&prefix) - .ignore_err() - .map(|((_, key), val): KeyVal| (key, val)) - } - - /// Sets a new profile key value, removes the key if value is None - pub fn set_profile_key( - &self, - user_id: &UserId, - profile_key: &str, - profile_key_value: Option, - ) { - // TODO: insert to the stable MSC4175 key when it's stable - let key = (user_id, profile_key); - - if let Some(value) = profile_key_value { - self.db - .useridprofilekey_value - .put(key, Json(value)); - } else { - self.db.useridprofilekey_value.del(key); - } - } - - /// Get the timezone of a user. - pub async fn timezone(&self, user_id: &UserId) -> Result { - // TODO: transparently migrate unstable key usage to the stable key once MSC4133 - // and MSC4175 are stable, likely a remove/insert in this block. - - // first check the unstable prefix then check the stable prefix - let unstable_key = (user_id, "us.cloke.msc4175.tz"); - let stable_key = (user_id, "m.tz"); - self.db - .useridprofilekey_value - .qry(&unstable_key) - .or_else(|_| self.db.useridprofilekey_value.qry(&stable_key)) - .await - .deserialized() - } - - /// Sets a new timezone or removes it if timezone is None. - pub fn set_timezone(&self, user_id: &UserId, timezone: Option) { - // TODO: insert to the stable MSC4175 key when it's stable - let key = (user_id, "us.cloke.msc4175.tz"); - - if let Some(timezone) = timezone { - self.db - .useridprofilekey_value - .put_raw(key, &timezone); - } else { - self.db.useridprofilekey_value.del(key); - } - } - - /// Performs a LDAP search for the given user. - /// - /// Returns the list of matching users, with a boolean for each result set - /// to true if the user is an admin. - #[cfg(feature = "ldap")] - pub async fn search_ldap(&self, user_id: &UserId) -> Result> { - use itertools::Itertools; - use ldap3::{LdapConnAsync, Scope, SearchEntry}; - use tuwunel_core::{debug, error, result::LogErr}; - - let localpart = user_id.localpart().to_owned(); - let lowercased_localpart = localpart.to_lowercase(); - - let config = &self.services.server.config.ldap; - let uri = config - .uri - .as_ref() - .ok_or_else(|| err!(Ldap(error!("LDAP URI is not configured."))))?; - - debug!(?uri, "LDAP creating connection..."); - let (conn, mut ldap) = LdapConnAsync::new(uri.as_str()) - .await - .map_err(|e| err!(Ldap(error!(?user_id, "LDAP connection setup error: {e}"))))?; - - let driver = self.services.server.runtime().spawn(async move { - match conn.drive().await { - | Err(e) => error!("LDAP connection error: {e}"), - | Ok(()) => debug!("LDAP connection completed."), - } - }); - - match (&config.bind_dn, &config.bind_password_file) { - | (Some(bind_dn), Some(bind_password_file)) => { - let bind_pw = String::from_utf8(std::fs::read(bind_password_file)?)?; - ldap.simple_bind(bind_dn, bind_pw.trim()) - .await - .and_then(ldap3::LdapResult::success) - .map_err(|e| err!(Ldap(error!("LDAP bind error: {e}"))))?; - }, - | (..) => {}, - } - - let attr = [&config.uid_attribute, &config.name_attribute]; - - let user_filter = &config - .filter - .replace("{username}", &lowercased_localpart); - - let (entries, _result) = ldap - .search(&config.base_dn, Scope::Subtree, user_filter, &attr) - .await - .and_then(ldap3::SearchResult::success) - .inspect(|(entries, result)| trace!(?entries, ?result, "LDAP Search")) - .map_err(|e| err!(Ldap(error!(?attr, ?user_filter, "LDAP search error: {e}"))))?; - - let mut dns = entries - .into_iter() - .filter_map(|entry| { - let search_entry = SearchEntry::construct(entry); - debug!(?search_entry, "LDAP search entry"); - search_entry - .attrs - .get(&config.uid_attribute) - .into_iter() - .chain(search_entry.attrs.get(&config.name_attribute)) - .any(|ids| ids.contains(&localpart) || ids.contains(&lowercased_localpart)) - .then_some((search_entry.dn, false)) - }) - .collect_vec(); - - if !config.admin_base_dn.is_empty() { - let admin_base_dn = if config.admin_base_dn.is_empty() { - &config.base_dn - } else { - &config.admin_base_dn - }; - - let admin_filter = &config - .admin_filter - .replace("{username}", &lowercased_localpart); - - let (admin_entries, _result) = ldap - .search(admin_base_dn, Scope::Subtree, admin_filter, &attr) - .await - .and_then(ldap3::SearchResult::success) - .inspect(|(entries, result)| trace!(?entries, ?result, "LDAP Admin Search")) - .map_err(|e| { - err!(Ldap(error!(?attr, ?user_filter, "Ldap admin search error: {e}"))) - })?; - - let mut admin_dns = admin_entries - .into_iter() - .filter_map(|entry| { - let search_entry = SearchEntry::construct(entry); - debug!(?search_entry, "LDAP search entry"); - search_entry - .attrs - .get(&config.uid_attribute) - .into_iter() - .chain(search_entry.attrs.get(&config.name_attribute)) - .any(|ids| { - ids.contains(&localpart) || ids.contains(&lowercased_localpart) - }) - .then_some((search_entry.dn, true)) - }) - .collect_vec(); - - dns.append(&mut admin_dns); - } - - ldap.unbind() - .await - .map_err(|e| err!(Ldap(error!("LDAP unbind error: {e}"))))?; - - driver.await.log_err().ok(); - - Ok(dns) - } - #[cfg(not(feature = "ldap"))] pub async fn search_ldap(&self, _user_id: &UserId) -> Result> { Err!(FeatureDisabled("ldap")) } - #[cfg(feature = "ldap")] - pub async fn auth_ldap(&self, user_dn: &str, password: &str) -> Result { - use ldap3::LdapConnAsync; - use tuwunel_core::{debug, error, result::LogErr}; - - let config = &self.services.server.config.ldap; - let uri = config - .uri - .as_ref() - .ok_or_else(|| err!(Ldap(error!("LDAP URI is not configured."))))?; - - debug!(?uri, "LDAP creating connection..."); - let (conn, mut ldap) = LdapConnAsync::new(uri.as_str()) - .await - .map_err(|e| err!(Ldap(error!(?user_dn, "LDAP connection setup error: {e}"))))?; - - let driver = self.services.server.runtime().spawn(async move { - match conn.drive().await { - | Err(e) => error!("LDAP connection error: {e}"), - | Ok(()) => debug!("LDAP connection completed."), - } - }); - - ldap.simple_bind(user_dn, password) - .await - .and_then(ldap3::LdapResult::success) - .map_err(|e| { - err!(Request(Forbidden(debug_error!("LDAP authentication error: {e}")))) - })?; - - ldap.unbind() - .await - .map_err(|e| err!(Ldap(error!("LDAP unbind error: {e}"))))?; - - driver.await.log_err().ok(); - - Ok(()) - } - #[cfg(not(feature = "ldap"))] pub async fn auth_ldap(&self, _user_dn: &str, _password: &str) -> Result { Err!(FeatureDisabled("ldap")) } } - -pub fn parse_master_key( - user_id: &UserId, - master_key: &Raw, -) -> Result<(Vec, CrossSigningKey)> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xFF); - - let master_key = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; - let mut master_key_ids = master_key.keys.values(); - let master_key_id = master_key_ids - .next() - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - let mut master_key_key = prefix.clone(); - master_key_key.extend_from_slice(master_key_id.as_bytes()); - Ok((master_key_key, master_key)) -} - -pub fn parse_user_signing_key(user_signing_key: &Raw) -> Result { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| err!(Request(InvalidParam("Invalid user signing key"))))? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids - .next() - .ok_or(err!(Request(InvalidParam("User signing key contained no key."))))?; - - if user_signing_key_ids.next().is_some() { - return Err!(Request(InvalidParam("User signing key contained more than one key."))); - } - - Ok(user_signing_key_id) -} - -/// Ensure that a user only sees signatures from themselves and the target user -fn clean_signatures( - mut cross_signing_key: serde_json::Value, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &F, -) -> Result -where - F: Fn(&UserId) -> bool + Send + Sync, -{ - if let Some(signatures) = cross_signing_key - .get_mut("signatures") - .and_then(|v| v.as_object_mut()) - { - // Don't allocate for the full size of the current signatures, but require - // at most one resize if nothing is dropped - let new_capacity = signatures.len() / 2; - for (user, signature) in - mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) - { - let sid = <&UserId>::try_from(user.as_str()) - .map_err(|_| Error::bad_database("Invalid user ID in database."))?; - if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) { - signatures.insert(user, signature); - } - } - } - - Ok(cross_signing_key) -} - -//TODO: this is an ABA -fn increment(db: &Arc, key: &[u8]) { - let old = db.get_blocking(key); - let new = utils::increment(old.ok().as_deref()); - db.insert(key, new); -} diff --git a/src/service/users/profile.rs b/src/service/users/profile.rs new file mode 100644 index 00000000..e0057a85 --- /dev/null +++ b/src/service/users/profile.rs @@ -0,0 +1,87 @@ +use futures::{Stream, StreamExt, TryFutureExt}; +use ruma::UserId; +use tuwunel_core::{Result, implement, utils::stream::TryIgnore}; +use tuwunel_database::{Deserialized, Ignore, Interfix, Json}; + +/// Gets a specific user profile key +#[implement(super::Service)] +pub async fn profile_key( + &self, + user_id: &UserId, + profile_key: &str, +) -> Result { + let key = (user_id, profile_key); + self.db + .useridprofilekey_value + .qry(&key) + .await + .deserialized() +} + +/// Gets all the user's profile keys and values in an iterator +#[implement(super::Service)] +pub fn all_profile_keys<'a>( + &'a self, + user_id: &'a UserId, +) -> impl Stream + 'a + Send { + type KeyVal = ((Ignore, String), serde_json::Value); + + let prefix = (user_id, Interfix); + self.db + .useridprofilekey_value + .stream_prefix(&prefix) + .ignore_err() + .map(|((_, key), val): KeyVal| (key, val)) +} + +/// Sets a new profile key value, removes the key if value is None +#[implement(super::Service)] +pub fn set_profile_key( + &self, + user_id: &UserId, + profile_key: &str, + profile_key_value: Option, +) { + // TODO: insert to the stable MSC4175 key when it's stable + let key = (user_id, profile_key); + + if let Some(value) = profile_key_value { + self.db + .useridprofilekey_value + .put(key, Json(value)); + } else { + self.db.useridprofilekey_value.del(key); + } +} + +/// Get the timezone of a user. +#[implement(super::Service)] +pub async fn timezone(&self, user_id: &UserId) -> Result { + // TODO: transparently migrate unstable key usage to the stable key once MSC4133 + // and MSC4175 are stable, likely a remove/insert in this block. + + // first check the unstable prefix then check the stable prefix + let unstable_key = (user_id, "us.cloke.msc4175.tz"); + let stable_key = (user_id, "m.tz"); + self.db + .useridprofilekey_value + .qry(&unstable_key) + .or_else(|_| self.db.useridprofilekey_value.qry(&stable_key)) + .await + .deserialized() +} + +/// Sets a new timezone or removes it if timezone is None. +#[implement(super::Service)] +pub fn set_timezone(&self, user_id: &UserId, timezone: Option) { + // TODO: insert to the stable MSC4175 key when it's stable + let key = (user_id, "us.cloke.msc4175.tz"); + + if let Some(timezone) = timezone { + self.db + .useridprofilekey_value + .put_raw(key, &timezone); + } else { + self.db.useridprofilekey_value.del(key); + } +}