Implement SSO/OIDC support. (closes #7)

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-12-23 14:55:29 +00:00
parent d665a34f30
commit 11309062a2
23 changed files with 1959 additions and 27 deletions

View File

@@ -110,6 +110,7 @@ ruma.workspace = true
rustls.workspace = true
rustyline-async.workspace = true
rustyline-async.optional = true
serde_html_form.workspace = true
serde_json.workspace = true
serde.workspace = true
serde_yaml.workspace = true

View File

@@ -21,6 +21,7 @@ pub struct Service {
pub sender: ClientLazylock,
pub appservice: ClientLazylock,
pub pusher: ClientLazylock,
pub oauth: ClientLazylock,
pub cidr_range_denylist: Vec<IPAddress>,
}
@@ -112,6 +113,11 @@ impl crate::Service for Service {
.pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout))
.redirect(redirect::Policy::limited(2))),
oauth: create_client!(config, services; base(config)?
.dns_resolver2(Arc::clone(&services.resolver.resolver))
.redirect(redirect::Policy::limited(0))
.pool_max_idle_per_host(1)),
cidr_range_denylist: config
.ip_range_denylist
.iter()

View File

@@ -19,6 +19,7 @@ pub mod globals;
pub mod key_backups;
pub mod media;
pub mod membership;
pub mod oauth;
pub mod presence;
pub mod pusher;
pub mod resolver;

265
src/service/oauth/mod.rs Normal file
View File

@@ -0,0 +1,265 @@
pub mod providers;
pub mod sessions;
pub mod user_info;
use std::sync::Arc;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
use reqwest::{Method, header::ACCEPT};
use ruma::UserId;
use serde::Serialize;
use serde_json::Value as JsonValue;
use tuwunel_core::{
Err, Result, err, implement,
utils::{hash::sha256, result::LogErr},
};
use url::Url;
pub use self::{
providers::Provider,
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session},
user_info::UserInfo,
};
use self::{providers::Providers, sessions::Sessions};
use crate::SelfServices;
pub struct Service {
services: SelfServices,
pub providers: Arc<Providers>,
pub sessions: Arc<Sessions>,
}
impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
let providers = Arc::new(Providers::build(args));
let sessions = Arc::new(Sessions::build(args, providers.clone()));
Ok(Arc::new(Self {
services: args.services.clone(),
sessions,
providers,
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip_all, ret)]
pub async fn request_userinfo(
&self,
(provider, session): (&Provider, &Session),
) -> Result<UserInfo> {
let url = provider
.userinfo_url
.clone()
.ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
self.request((Some(provider), Some(session)), Method::GET, url)
.await
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err()
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip_all, ret)]
pub async fn request_tokeninfo(
&self,
(provider, session): (&Provider, &Session),
) -> Result<UserInfo> {
let url = provider
.introspection_url
.clone()
.ok_or_else(|| {
err!(Config("introspection_url", "Missing introspection URL in config"))
})?;
self.request((Some(provider), Some(session)), Method::GET, url)
.await
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err()
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip_all, ret)]
pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
#[derive(Debug, Serialize)]
struct RevokeQuery<'a> {
client_id: &'a str,
client_secret: &'a str,
}
let query = RevokeQuery {
client_id: &provider.client_id,
client_secret: &provider.client_secret,
};
let query = serde_html_form::to_string(&query)?;
let url = provider
.revocation_url
.clone()
.map(|mut url| {
url.set_query(Some(&query));
url
})
.ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
self.request((Some(provider), Some(session)), Method::POST, url)
.await
.log_err()
.map(|_| ())
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip_all, ret)]
pub async fn request_token(
&self,
(provider, session): (&Provider, &Session),
code: &str,
) -> Result<Session> {
#[derive(Debug, Serialize)]
struct TokenQuery<'a> {
client_id: &'a str,
client_secret: &'a str,
grant_type: &'a str,
code: &'a str,
code_verifier: Option<&'a str>,
redirect_uri: Option<&'a str>,
}
let query = TokenQuery {
client_id: &provider.client_id,
client_secret: &provider.client_secret,
grant_type: "authorization_code",
code,
code_verifier: session.code_verifier.as_deref(),
redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
};
let query = serde_html_form::to_string(&query)?;
let url = provider
.token_url
.clone()
.map(|mut url| {
url.set_query(Some(&query));
url
})
.ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
self.request((Some(provider), Some(session)), Method::POST, url)
.await
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err()
}
/// Send a request to a provider; this is somewhat abstract since URL's are
/// formed prior to this call and could point at anything, however this function
/// uses the oauth-specific http client and is configured for JSON with special
/// casing for an `error` property in the response.
#[implement(Service)]
#[tracing::instrument(
name = "request",
level = "debug",
ret(level = "trace"),
skip(self)
)]
pub async fn request(
&self,
(provider, session): (Option<&Provider>, Option<&Session>),
method: Method,
url: Url,
) -> Result<JsonValue> {
let mut request = self
.services
.client
.oauth
.request(method, url)
.header(ACCEPT, "application/json");
if let Some(session) = session {
if let Some(access_token) = session.access_token.clone() {
request = request.bearer_auth(access_token);
}
}
let response: JsonValue = request
.send()
.await?
.error_for_status()?
.json()
.await?;
if let Some(response) = response.as_object().as_ref()
&& let Some(error) = response.get("error").and_then(JsonValue::as_str)
{
let description = response
.get("error_description")
.and_then(JsonValue::as_str)
.unwrap_or("(no description)");
return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
}
Ok(response)
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> {
let session = self.sessions.get_by_user(user_id).await?;
let provider = self.sessions.provider(&session).await?;
Ok((provider, session))
}
#[inline]
pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
}
#[inline]
pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
}
#[inline]
pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
}
pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
let hash = sha256::delimited([iss, sub].iter());
let b64 = b64encode.encode(hash);
Ok(b64)
}
fn unique_id_parts<'a>(
(provider, session): (&'a Provider, &'a Session),
) -> Result<(&'a str, &'a str)> {
provider
.issuer_url
.as_ref()
.map(Url::as_str)
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
.and_then(|iss| unique_id_iss_parts((iss, session)))
}
fn unique_id_sub_parts<'a>(
(provider, sub): (&'a Provider, &'a str),
) -> Result<(&'a str, &'a str)> {
provider
.issuer_url
.as_ref()
.map(Url::as_str)
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
.map(|iss| (iss, sub))
}
fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
session
.user_info
.as_ref()
.map(|user_info| user_info.sub.as_str())
.ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
.map(|sub| (iss, sub))
}

View File

@@ -0,0 +1,246 @@
use std::collections::BTreeMap;
use serde_json::{Map as JsonObject, Value as JsonValue};
use tokio::sync::RwLock;
pub use tuwunel_core::config::IdentityProvider as Provider;
use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
use url::Url;
use crate::SelfServices;
#[derive(Default)]
pub struct Providers {
services: SelfServices,
providers: RwLock<BTreeMap<String, Provider>>,
}
#[implement(Providers)]
pub(super) fn build(args: &crate::Args<'_>) -> Self {
Self {
services: args.services.clone(),
..Default::default()
}
}
/// Get the Provider configuration after any discovery and adjustments
/// made on top of the admin's configuration. This incurs network-based
/// discovery on the first call but responds from cache on subsequent calls.
#[implement(Providers)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn get(&self, id: &str) -> Result<Provider> {
if let Some(provider) = self.providers.read().await.get(id).cloned() {
return Ok(provider);
}
let config = self.get_config(id)?;
let mut map = self.providers.write().await;
let config = self.configure(config).await?;
debug!(?id, ?config);
_ = map.insert(id.into(), config.clone());
Ok(config)
}
/// Get the admin-configured Provider which exists prior to any
/// reconciliation with the well-known discovery (the server's config is
/// immutable); though it is important to note the server config can be
/// reloaded. This will Err NotFound for a non-existent idp.
#[implement(Providers)]
pub fn get_config(&self, id: &str) -> Result<Provider> {
self.services
.config
.identity_provider
.iter()
.find(|config| config.id() == id)
.cloned()
.ok_or_else(|| err!(Request(NotFound("Unrecognized Identity Provider"))))
}
/// Configure an identity provider; takes the admin-configured instance from the
/// server's config, queries the provider for discovery, and then returns an
/// updated config based on the proper reconciliation. This final config is then
/// cached in memory to avoid repeating this process.
#[implement(Providers)]
#[tracing::instrument(
level = INFO_SPAN_LEVEL,
ret(level = "debug"),
skip(self),
)]
async fn configure(&self, mut provider: Provider) -> Result<Provider> {
_ = provider
.name
.get_or_insert_with(|| provider.brand.clone());
if provider.issuer_url.is_none() {
_ = provider
.issuer_url
.replace(match provider.brand.as_str() {
| "github" => "https://github.com".try_into()?,
| "gitlab" => "https://gitlab.com".try_into()?,
| "google" => "https://accounts.google.com".try_into()?,
| _ => return Err!(Config("issuer_url", "Required for this provider.")),
});
}
if provider.base_path.is_none() {
provider.base_path = match provider.brand.as_str() {
| "github" => Some("/login/oauth".to_owned()),
| _ => None,
};
}
let response = self
.discover(&provider)
.await
.and_then(|response| {
response.as_object().cloned().ok_or_else(|| {
err!(Request(NotJson("Expecting JSON object for discovery response")))
})
})
.and_then(|response| check_issuer(response, &provider))?;
if provider.authorization_url.is_none() {
response
.get("authorization_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| make_url(&provider, "/authorize").ok())
.map(|url| provider.authorization_url.replace(url));
}
if provider.revocation_url.is_none() {
response
.get("revocation_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| make_url(&provider, "/revocation").ok())
.map(|url| provider.revocation_url.replace(url));
}
if provider.introspection_url.is_none() {
response
.get("introspection_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| make_url(&provider, "/introspection").ok())
.map(|url| provider.introspection_url.replace(url));
}
if provider.userinfo_url.is_none() {
response
.get("userinfo_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| match provider.brand.as_str() {
| "github" => "https://api.github.com/user".try_into().ok(),
| _ => make_url(&provider, "/userinfo").ok(),
})
.map(|url| provider.userinfo_url.replace(url));
}
if provider.token_url.is_none() {
response
.get("token_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| {
let path = if provider.brand == "github" {
"/access_token"
} else {
"/token"
};
make_url(&provider, path).ok()
})
.map(|url| provider.token_url.replace(url));
}
Ok(provider)
}
#[implement(Providers)]
#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
self.services
.client
.oauth
.get(discovery_url(provider)?)
.send()
.await?
.error_for_status()?
.json()
.await
.map_err(Into::into)
}
fn discovery_url(provider: &Provider) -> Result<Url> {
let default_url = provider
.discovery
.then(|| make_url(provider, "/.well-known/openid-configuration"))
.transpose()?;
let Some(url) = provider
.discovery_url
.clone()
.filter(|_| provider.discovery)
.or(default_url)
else {
return Err!(Config(
"discovery_url",
"Failed to determine URL for discovery of provider {}",
provider.id()
));
};
Ok(url)
}
/// Validate that the locally configured `issuer_url` matches the issuer claimed
/// in any response. todo: cryptographic validation is not yet implemented here.
fn check_issuer(
response: JsonObject<String, JsonValue>,
provider: &Provider,
) -> Result<JsonObject<String, JsonValue>> {
let expected = provider
.issuer_url
.as_ref()
.map(Url::as_str)
.map(|url| url.trim_end_matches('/'));
let responded = response
.get("issuer")
.and_then(JsonValue::as_str)
.map(|url| url.trim_end_matches('/'));
if expected != responded {
return Err!(Request(Unauthorized(
"Configured issuer_url {expected:?} does not match discovered {responded:?}",
)));
}
Ok(response)
}
/// Generate a full URL for a request to the idp based on the idp's derived
/// configuration.
fn make_url(provider: &Provider, path: &str) -> Result<Url> {
let mut suffix = provider.base_path.clone().unwrap_or_default();
suffix.push_str(path);
let url = provider
.issuer_url
.as_ref()
.ok_or_else(|| {
let id = &provider.client_id;
err!(Config("issuer_url", "Provider {id:?} required field"))
})?
.join(&suffix)?;
Ok(url)
}

View File

@@ -0,0 +1,234 @@
use std::{sync::Arc, time::SystemTime};
use futures::{Stream, StreamExt, TryFutureExt};
use ruma::{OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect};
use tuwunel_database::{Cbor, Deserialized, Map};
use url::Url;
use super::{Provider, Providers, UserInfo, unique_id};
use crate::SelfServices;
pub struct Sessions {
_services: SelfServices,
providers: Arc<Providers>,
db: Data,
}
struct Data {
oauthid_session: Arc<Map>,
oauthuniqid_oauthid: Arc<Map>,
userid_oauthid: Arc<Map>,
}
/// Session ultimately represents an OAuth authorization session yielding an
/// associated matrix user registration.
///
/// Mixed-use structure capable of deserializing response values, maintaining
/// the state between authorization steps, and maintaining the association to
/// the matrix user until deactivation or revocation.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Session {
/// Identity Provider ID (the `client_id` in the configuration) associated
/// with this session.
pub idp_id: Option<String>,
/// Session ID used as the index key for this session itself.
pub sess_id: Option<String>,
/// Token type (bearer, mac, etc).
pub token_type: Option<String>,
/// Access token to the provider.
pub access_token: Option<String>,
/// Duration in seconds the access_token is valid for.
pub expires_in: Option<u64>,
/// Point in time that the access_token expires.
pub expires_at: Option<SystemTime>,
/// Token used to refresh the access_token.
pub refresh_token: Option<String>,
/// Duration in seconds the refresh_token is valid for
pub refresh_token_expires_in: Option<u64>,
/// Point in time that the refresh_token expires.
pub refresh_token_expires_at: Option<SystemTime>,
/// Access scope actually granted (if supported).
pub scope: Option<String>,
/// Redirect URL
pub redirect_url: Option<Url>,
/// Challenge preimage
pub code_verifier: Option<String>,
/// Random string passed exclusively in the grant session cookie.
pub cookie_nonce: Option<String>,
/// Random single-use string passed in the provider redirect.
pub query_nonce: Option<String>,
/// Point in time the authorization grant session expires.
pub authorize_expires_at: Option<SystemTime>,
/// Associated User Id registration.
pub user_id: Option<OwnedUserId>,
/// Last userinfo response persisted here.
pub user_info: Option<UserInfo>,
}
/// Number of characters generated for our code_verifier. The code_verifier is a
/// random string which must be between 43 and 128 characters.
pub const CODE_VERIFIER_LENGTH: usize = 64;
/// Number of characters we will generate for the Session ID.
pub const SESSION_ID_LENGTH: usize = 32;
#[implement(Sessions)]
pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
Self {
_services: args.services.clone(),
providers,
db: Data {
oauthid_session: args.db["oauthid_session"].clone(),
oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
userid_oauthid: args.db["userid_oauthid"].clone(),
},
}
}
#[implement(Sessions)]
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db
.userid_oauthid
.keys()
.expect_ok()
.map(UserId::to_owned)
}
/// Delete database state for the session.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn delete(&self, sess_id: &str) {
let Ok(session) = self.get(sess_id).await else {
return;
};
// Check the user_id still points to this sess_id before deleting. If not, the
// association was updated to a newer session.
if let Some(user_id) = session.user_id.as_deref() {
if let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await {
if assoc_id == sess_id {
self.db.userid_oauthid.remove(user_id);
}
}
}
// Check the unique identity still points to this sess_id before deleting. If
// not, the association was updated to a newer session.
if let Some(idp_id) = session.idp_id.as_ref() {
if let Ok(provider) = self.providers.get(idp_id).await {
if let Ok(unique_id) = unique_id((&provider, &session)) {
if let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await {
if assoc_id == sess_id {
self.db.oauthuniqid_oauthid.remove(&unique_id);
}
}
}
}
}
self.db.oauthid_session.remove(sess_id);
}
/// Create or overwrite database state for the session.
#[implement(Sessions)]
#[tracing::instrument(level = "info", skip(self))]
pub async fn put(&self, sess_id: &str, session: &Session) {
self.db
.oauthid_session
.raw_put(sess_id, Cbor(session));
if let Some(user_id) = session.user_id.as_deref() {
self.db.userid_oauthid.insert(user_id, sess_id);
}
if let Some(idp_id) = session.idp_id.as_ref() {
if let Ok(provider) = self.providers.get(idp_id).await {
if let Ok(unique_id) = unique_id((&provider, session)) {
self.db
.oauthuniqid_oauthid
.insert(&unique_id, sess_id);
}
}
}
}
/// Fetch database state for a session from its associated `sub`, in case
/// `sess_id` is not known.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
self.get_sess_id_by_unique_id(unique_id)
.and_then(async |sess_id| self.get(&sess_id).await)
.await
}
/// Fetch database state for a session from its associated `user_id`, in case
/// `sess_id` is not known.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_by_user(&self, user_id: &UserId) -> Result<Session> {
self.get_sess_id_by_user(user_id)
.and_then(async |sess_id| self.get(&sess_id).await)
.await
}
/// Fetch database state for a session from its `sess_id`.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get(&self, sess_id: &str) -> Result<Session> {
self.db
.oauthid_session
.get(sess_id)
.await
.deserialized::<Cbor<_>>()
.map(at!(0))
}
/// Resolve the `sess_id` from an associated `user_id`.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result<String> {
self.db
.userid_oauthid
.get(user_id)
.await
.deserialized()
}
/// Resolve the `sess_id` from an associated provider subject id.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
self.db
.oauthuniqid_oauthid
.get(unique_id)
.await
.deserialized()
}
#[implement(Sessions)]
pub async fn provider(&self, session: &Session) -> Result<Provider> {
let Some(idp_id) = session.idp_id.as_deref() else {
return Err!(Request(NotFound("No provider for this session")));
};
self.providers.get(idp_id).await
}

View File

@@ -0,0 +1,39 @@
use serde::{Deserialize, Serialize};
/// Selection of userinfo response claims.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct UserInfo {
/// Unique identifier number or login username. Usually a number on most
/// services. We consider a concatenation of the `iss` and `sub` to be a
/// universally unique identifier for some user/identity; we index that in
/// `oauthidpsub_oauthid`.
///
/// Considered for user mxid only if none of the better fields are defined.
/// `login` alias intended for github.
#[serde(alias = "login")]
pub sub: String,
/// The login username we first consider when defined.
pub preferred_username: Option<String>,
/// The login username considered if none preferred.
pub nickname: Option<String>,
/// Full name.
pub name: Option<String>,
/// First name.
pub given_name: Option<String>,
/// Last name.
pub family_name: Option<String>,
/// Email address (`email` scope).
pub email: Option<String>,
/// URL to pfp (github/gitlab)
pub avatar_url: Option<String>,
/// URL to pfp (google)
pub picture: Option<String>,
}

View File

@@ -12,7 +12,7 @@ use crate::{
account_data, admin, appservice, client, config, deactivate, emergency, federation, globals,
key_backups,
manager::Manager,
media, membership, presence, pusher, resolver, rooms, sending, server_keys,
media, membership, oauth, presence, pusher, resolver, rooms, sending, server_keys,
service::{Args, Service},
sync, transaction_ids, uiaa, users,
};
@@ -58,6 +58,7 @@ pub struct Services {
pub users: Arc<users::Service>,
pub membership: Arc<membership::Service>,
pub deactivate: Arc<deactivate::Service>,
pub oauth: Arc<oauth::Service>,
manager: Mutex<Option<Arc<Manager>>>,
pub server: Arc<Server>,
@@ -115,6 +116,7 @@ pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
users: users::Service::build(&args)?,
membership: membership::Service::build(&args)?,
deactivate: deactivate::Service::build(&args)?,
oauth: oauth::Service::build(&args)?,
manager: Mutex::new(None),
server,
@@ -173,6 +175,7 @@ pub(crate) fn services(&self) -> impl Iterator<Item = Arc<dyn Service>> + Send {
cast!(self.users),
cast!(self.membership),
cast!(self.deactivate),
cast!(self.oauth),
]
.into_iter()
}

View File

@@ -14,10 +14,10 @@ use ruma::{
events::{GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent},
};
use tuwunel_core::{
Err, Result, debug_warn, err, is_equal_to,
Err, Result, debug_warn, err,
pdu::PduBuilder,
trace,
utils::{self, ReadyExt, stream::TryIgnore},
utils::{self, ReadyExt, result::LogErr, stream::TryIgnore},
warn,
};
use tuwunel_database::{Deserialized, Json, Map};
@@ -132,6 +132,16 @@ impl Service {
/// Deactivate account
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
// Revoke any SSO authorizations
if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await {
self.services
.oauth
.revoke_token((&provider, &session))
.await
.log_err()
.ok();
}
// Remove all associated devices
self.all_device_ids(user_id)
.for_each(|device_id| self.remove_device(user_id, device_id))
@@ -227,18 +237,15 @@ impl Service {
// Cannot change the password of a LDAP user. There are two special cases :
// - a `None` password can be used to deactivate a LDAP user
// - a "*" password is used as the default password of an active LDAP user
if cfg!(feature = "ldap")
&& password.is_some()
//
// The above now applies to all non-password origin users including SSO
// users. Note that users with no origin are also password-origin users.
if password.is_some()
&& password != Some("*")
&& self
.db
.userid_origin
.get(user_id)
.await
.deserialized::<String>()
.is_ok_and(is_equal_to!("ldap"))
&& let Ok(origin) = self.origin(user_id).await
&& origin != "password"
{
return Err!(Request(InvalidParam("Cannot change password of a LDAP user")));
return Err!(Request(InvalidParam("Cannot change password of an {origin:?} user.")));
}
password