Implement associated multi-provider single-sign-on flow support. (#252)

Add experimental note for multi-provider flow. (#252)

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-23 03:10:41 +00:00
parent a3294fe1cf
commit 6db87a4027
12 changed files with 411 additions and 197 deletions

20
Cargo.lock generated
View File

@@ -3680,7 +3680,7 @@ dependencies = [
[[package]]
name = "ruma"
version = "0.13.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"assign",
"js_int",
@@ -3699,7 +3699,7 @@ dependencies = [
[[package]]
name = "ruma-appservice-api"
version = "0.13.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"js_int",
"ruma-common",
@@ -3711,7 +3711,7 @@ dependencies = [
[[package]]
name = "ruma-client-api"
version = "0.21.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"as_variant",
"assign",
@@ -3736,7 +3736,7 @@ dependencies = [
[[package]]
name = "ruma-common"
version = "0.16.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"as_variant",
"base64",
@@ -3770,7 +3770,7 @@ dependencies = [
[[package]]
name = "ruma-events"
version = "0.31.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"as_variant",
"indexmap",
@@ -3797,7 +3797,7 @@ dependencies = [
[[package]]
name = "ruma-federation-api"
version = "0.12.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"bytes",
"headers",
@@ -3820,7 +3820,7 @@ dependencies = [
[[package]]
name = "ruma-identifiers-validation"
version = "0.11.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"js_int",
"thiserror 2.0.18",
@@ -3829,7 +3829,7 @@ dependencies = [
[[package]]
name = "ruma-macros"
version = "0.16.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"cfg-if",
"proc-macro-crate",
@@ -3844,7 +3844,7 @@ dependencies = [
[[package]]
name = "ruma-push-gateway-api"
version = "0.12.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"js_int",
"ruma-common",
@@ -3856,7 +3856,7 @@ dependencies = [
[[package]]
name = "ruma-signatures"
version = "0.18.0"
source = "git+https://github.com/matrix-construct/ruma?rev=54f56d7df454f7b4036386ff2ade6e11ab8af7de#54f56d7df454f7b4036386ff2ade6e11ab8af7de"
source = "git+https://github.com/matrix-construct/ruma?rev=1311829cb73986ff8da4837e98a588860712630f#1311829cb73986ff8da4837e98a588860712630f"
dependencies = [
"base64",
"ed25519-dalek",

View File

@@ -327,7 +327,7 @@ default-features = false
[workspace.dependencies.ruma]
git = "https://github.com/matrix-construct/ruma"
rev = "54f56d7df454f7b4036386ff2ade6e11ab8af7de"
rev = "1311829cb73986ff8da4837e98a588860712630f"
features = [
"__compat",
"appservice-api-c",

View File

@@ -1,11 +1,14 @@
use clap::Subcommand;
use futures::{StreamExt, TryStreamExt};
use ruma::OwnedUserId;
use tuwunel_core::{Err, Result, apply, err, itertools::Itertools, utils::stream::IterStream};
use tuwunel_service::{
Services,
oauth::{Provider, Session},
use tuwunel_core::{
Err, Result, apply,
either::{Either, Left, Right},
err,
itertools::Itertools,
utils::stream::{IterStream, ReadyExt},
};
use tuwunel_service::oauth::{Provider, ProviderId, SessionId};
use crate::{admin_command, admin_command_dispatch};
@@ -29,12 +32,18 @@ pub(crate) enum OauthCommand {
/// List configured OAuth providers.
ListProviders,
/// List users associated with an OAuth provider
/// List users associated with any OAuth session
ListUsers,
/// List session ID's
ListSessions {
#[arg(long)]
user: Option<OwnedUserId>,
},
/// Show active configuration of a provider.
ShowProvider {
id: String,
id: ProviderId,
#[arg(long)]
config: bool,
@@ -42,28 +51,43 @@ pub(crate) enum OauthCommand {
/// Show session state
ShowSession {
id: String,
id: SessionId,
},
/// Show user sessions
ShowUser {
user_id: OwnedUserId,
},
/// Token introspection request to provider.
TokenInfo {
id: String,
id: SessionId,
},
/// Revoke token for user_id or sess_id.
Revoke {
id: String,
#[arg(value_parser = session_or_user_id)]
id: Either<SessionId, OwnedUserId>,
},
/// Remove oauth state (DANGER!)
Remove {
id: String,
Delete {
#[arg(value_parser = session_or_user_id)]
id: Either<SessionId, OwnedUserId>,
#[arg(long)]
force: bool,
},
}
type SessionOrUserId = Either<SessionId, OwnedUserId>;
fn session_or_user_id(input: &str) -> Result<SessionOrUserId> {
OwnedUserId::parse(input)
.map(Right)
.or_else(|_| Ok(Left(input.to_owned())))
}
#[admin_command]
pub(super) async fn oauth_associate(
&self,
@@ -137,7 +161,34 @@ pub(super) async fn oauth_list_users(&self) -> Result {
}
#[admin_command]
pub(super) async fn oauth_show_provider(&self, id: String, config: bool) -> Result {
pub(super) async fn oauth_list_sessions(&self, user_id: Option<OwnedUserId>) -> Result {
if let Some(user_id) = user_id.as_deref() {
return self
.services
.oauth
.sessions
.get_sess_id_by_user(user_id)
.map_ok(|id| format!("{id}\n"))
.try_for_each(async |id: String| self.write_str(&id).await)
.await;
}
self.services
.oauth
.sessions
.stream()
.ready_filter_map(|sess| sess.sess_id)
.map(|sess_id| format!("{sess_id:?}\n"))
.for_each(async |id: String| {
self.write_str(&id).await.ok();
})
.await;
Ok(())
}
#[admin_command]
pub(super) async fn oauth_show_provider(&self, id: ProviderId, config: bool) -> Result {
if config {
let config = self.services.oauth.providers.get_config(&id)?;
@@ -151,15 +202,29 @@ pub(super) async fn oauth_show_provider(&self, id: String, config: bool) -> Resu
}
#[admin_command]
pub(super) async fn oauth_show_session(&self, id: String) -> Result {
let session = find_session(self.services, &id).await?;
pub(super) async fn oauth_show_session(&self, id: SessionId) -> Result {
let session = self.services.oauth.sessions.get(&id).await?;
self.write_str(&format!("{session:#?}\n")).await
}
#[admin_command]
pub(super) async fn oauth_token_info(&self, id: String) -> Result {
let session = find_session(self.services, &id).await?;
pub(super) async fn oauth_show_user(&self, user_id: OwnedUserId) -> Result {
self.services
.oauth
.sessions
.get_sess_id_by_user(&user_id)
.try_for_each(async |id| {
let session = self.services.oauth.sessions.get(&id).await?;
self.write_str(&format!("{session:#?}\n")).await
})
.await
}
#[admin_command]
pub(super) async fn oauth_token_info(&self, id: SessionId) -> Result {
let session = self.services.oauth.sessions.get(&id).await?;
let provider = self
.services
@@ -172,61 +237,64 @@ pub(super) async fn oauth_token_info(&self, id: String) -> Result {
.services
.oauth
.request_tokeninfo((&provider, &session))
.await;
.await?;
self.write_str(&format!("{tokeninfo:#?}\n")).await
}
#[admin_command]
pub(super) async fn oauth_revoke(&self, id: String) -> Result {
let session = find_session(self.services, &id).await?;
pub(super) async fn oauth_revoke(&self, id: SessionOrUserId) -> Result {
match id {
| Left(sess_id) => {
let session = self.services.oauth.sessions.get(&sess_id).await?;
let provider = self
.services
.oauth
.sessions
.provider(&session)
.await?;
let provider = self
.services
.oauth
.sessions
.provider(&session)
.await?;
self.services
.oauth
.revoke_token((&provider, &session))
.await?;
self.services
.oauth
.revoke_token((&provider, &session))
.await
.ok();
},
| Right(user_id) =>
self.services
.oauth
.revoke_user_tokens(&user_id)
.await,
}
self.write_str("done").await
self.write_str("revoked").await
}
#[admin_command]
pub(super) async fn oauth_remove(&self, id: String, force: bool) -> Result {
let session = find_session(self.services, &id).await?;
let Some(sess_id) = session.sess_id else {
return Err!("Missing sess_id in oauth Session state");
};
pub(super) async fn oauth_delete(&self, id: SessionOrUserId, force: bool) -> Result {
if !force {
return Err!(
"Deleting these records can cause registration conflicts. Use --force to be sure."
);
}
self.services
.oauth
.sessions
.delete(&sess_id)
.await;
self.write_str("done").await
}
async fn find_session(services: &Services, id: &str) -> Result<Session> {
if let Ok(user_id) = OwnedUserId::parse(id) {
services
.oauth
.sessions
.get_by_user(&user_id)
.await
} else {
services.oauth.sessions.get(id).await
match id {
| Left(sess_id) => {
self.services
.oauth
.sessions
.delete(&sess_id)
.await;
},
| Right(user_id) => {
self.services
.oauth
.delete_user_sessions(&user_id)
.await;
},
}
self.write_str("deleted any oauth state for {id}")
.await
}

View File

@@ -22,10 +22,7 @@ use ruma::api::client::session::{
v3::{DiscoveryInfo, HomeserverInfo, LoginInfo},
},
};
use tuwunel_core::{
Err, Result, info,
utils::{BoolExt, stream::ReadyExt},
};
use tuwunel_core::{Err, Result, info, utils::stream::ReadyExt};
use tuwunel_service::users::device::generate_refresh_token;
use self::{ldap::ldap_login, password::password_login};
@@ -50,13 +47,13 @@ pub(crate) async fn get_login_types_route(
) -> Result<get_login_types::v3::Response> {
let get_login_token = services.config.login_via_existing_session;
let identity_providers = services
let list_idps = !services.config.sso_custom_providers_page && !services.config.single_sso;
let identity_providers: Vec<_> = services
.config
.sso_custom_providers_page
.is_false()
.then(|| services.config.identity_provider.iter())
.into_iter()
.flatten()
.identity_provider
.iter()
.filter(|_| list_idps)
.cloned()
.map(|config| IdentityProvider {
id: config.id().to_owned(),
@@ -70,8 +67,8 @@ pub(crate) async fn get_login_types_route(
LoginType::ApplicationService(ApplicationServiceLoginType::default()),
LoginType::Jwt(JwtLoginType::default()),
LoginType::Password(PasswordLoginType::default()),
LoginType::Sso(SsoLoginType { identity_providers }),
LoginType::Token(TokenLoginType { get_login_token }),
LoginType::Sso(SsoLoginType { identity_providers }),
];
Ok(get_login_types::v3::Response {
@@ -79,8 +76,7 @@ pub(crate) async fn get_login_types_route(
.into_iter()
.filter(|login_type| match login_type {
| LoginType::Sso(SsoLoginType { identity_providers })
if !services.config.sso_custom_providers_page
&& identity_providers.is_empty() =>
if list_idps && identity_providers.is_empty() =>
false,
| _ => true,

View File

@@ -4,7 +4,7 @@ use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use axum_extra::extract::cookie::{Cookie, SameSite};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use futures::{StreamExt, TryFutureExt, future::try_join};
use futures::{FutureExt, StreamExt, TryFutureExt, future::try_join};
use reqwest::header::{CONTENT_TYPE, HeaderValue};
use ruma::{
Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId,
@@ -19,6 +19,7 @@ use tuwunel_core::{
itertools::Itertools,
utils,
utils::{
OptionExt,
content_disposition::make_content_disposition,
hash::sha256,
result::{FlatOk, LogErr},
@@ -98,7 +99,7 @@ pub(crate) async fn sso_login_route(
let redirect_url = body.body.redirect_url;
handle_sso_login(&services, &client, default_idp_id, redirect_url)
handle_sso_login(&services, &client, default_idp_id, redirect_url, None)
.map_ok(|response| sso_login::v3::Response {
location: response.location,
cookie: response.cookie,
@@ -128,8 +129,9 @@ pub(crate) async fn sso_login_with_provider_route(
) -> Result<sso_login_with_provider::v3::Response> {
let idp_id = body.body.idp_id;
let redirect_url = body.body.redirect_url;
let login_token = body.body.login_token;
handle_sso_login(&services, &client, idp_id, redirect_url).await
handle_sso_login(&services, &client, idp_id, redirect_url, login_token).await
}
async fn handle_sso_login(
@@ -137,6 +139,7 @@ async fn handle_sso_login(
_client: &IpAddr,
idp_id: String,
redirect_url: String,
login_token: Option<String>,
) -> Result<sso_login_with_provider::v3::Response> {
let Ok(redirect_url) = redirect_url.parse::<Url>() else {
return Err!(Request(InvalidParam("Invalid redirect_url")));
@@ -221,14 +224,16 @@ async fn handle_sso_login(
.map(timepoint_from_now)
.transpose()?,
user_id: login_token
.as_deref()
.map_async(|token| services.users.find_from_login_token(token))
.map(FlatOk::flat_ok)
.await,
..Default::default()
};
services
.oauth
.sessions
.put(&sess_id, &session)
.await;
services.oauth.sessions.put(&session).await;
Ok(sso_login_with_provider::v3::Response {
location: location.into(),
@@ -269,13 +274,23 @@ pub(crate) async fn sso_callback_route(
.get(sess_id)
.map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));
let idp_id = body.body.idp_id.as_str();
let provider = services.oauth.providers.get(idp_id);
let provider = services
.oauth
.providers
.get(body.body.idp_id.as_str());
let (provider, session) = try_join(provider, session).await.log_err()?;
let client_id = &provider.client_id;
let idp_id = provider.id();
if session.sess_id.as_deref() != Some(sess_id) {
return Err!(Request(Unauthorized("Session ID {sess_id:?} not recognized.")));
}
if session.idp_id.as_deref() != Some(idp_id) {
return Err!(Request(Unauthorized("Identity Provider {idp_id} session not recognized.")));
return Err!(Request(Unauthorized(
"Identity Provider {idp_id:?} session not recognized."
)));
}
if session
@@ -346,7 +361,7 @@ pub(crate) async fn sso_callback_route(
// Check for an existing session from this identity. We want to maintain one
// session for each identity and keep the newer one which has up-to-date state
// and access.
let (user_id, old_sess_id) = match services
let (old_user_id, old_sess_id) = match services
.oauth
.sessions
.get_by_unique_id(&unique_id)
@@ -364,9 +379,9 @@ pub(crate) async fn sso_callback_route(
};
// Keep the user_id from the old session as best as possible.
let user_id = match user_id {
| Some(user_id) => user_id,
| None => decide_user_id(&services, &provider, &userinfo, &unique_id).await?,
let user_id = match (session.user_id, old_user_id) {
| (Some(user_id), ..) | (None, Some(user_id)) => user_id,
| (None, None) => decide_user_id(&services, &provider, &userinfo, &unique_id).await?,
};
// Update the session with user_id
@@ -381,11 +396,7 @@ pub(crate) async fn sso_callback_route(
}
// Commit the updated session.
services
.oauth
.sessions
.put(sess_id, &session)
.await;
services.oauth.sessions.put(&session).await;
// Delete any old session.
if let Some(old_sess_id) = old_sess_id
@@ -398,14 +409,43 @@ pub(crate) async fn sso_callback_route(
return Err!(Request(UserDeactivated("This user has been deactivated.")));
}
let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
.removal()
.build()
.to_string()
.into();
// Determine the next provider to chain after this one.
let next_idp_url = services
.config
.identity_provider
.iter()
.filter(|idp| idp.default || services.config.single_sso)
.skip_while(|idp| idp.id() != idp_id)
.nth(1)
.map(IdentityProvider::id)
.and_then(|next_idp| {
provider.callback_url.clone().map(|mut url| {
let path = format!("/_matrix/client/v3/login/sso/redirect/{next_idp}");
url.set_path(&path);
if let Some(redirect_url) = session.redirect_url.as_ref() {
url.query_pairs_mut()
.append_pair("redirectUrl", redirect_url.as_str());
}
url
})
});
// Allow the user to login to Matrix.
let login_token = utils::random_string(TOKEN_LENGTH);
let _login_token_expires_in = services
.users
.create_login_token(&user_id, &login_token);
let location = session
.redirect_url
let location = next_idp_url
.or(session.redirect_url)
.as_ref()
.ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
.clone()
@@ -414,12 +454,6 @@ pub(crate) async fn sso_callback_route(
.finish()
.to_string();
let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
.removal()
.build()
.to_string()
.into();
Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
}
@@ -594,8 +628,8 @@ async fn decide_user_id(
}
let length = Some(15..23);
let infallible = truncate_deterministic(unique_id, length).to_lowercase();
if let Some(user_id) = try_user_id(services, &infallible, true).await {
let unique_id = truncate_deterministic(unique_id, length).to_lowercase();
if let Some(user_id) = try_user_id(services, &unique_id, true).await {
return Ok(user_id);
}

View File

@@ -317,20 +317,6 @@ pub fn check(config: &Config) -> Result {
}
}
if config
.identity_provider
.iter()
.filter(|idp| idp.default)
.count()
.gt(&1)
{
return Err!(Config(
"identity_provider.default",
"More than one identity_provider is configured with `default = true`. Only one can \
be set to default.",
));
}
if !config.sso_custom_providers_page
&& config.identity_provider.len() > 1
&& config

View File

@@ -2173,6 +2173,24 @@ pub struct Config {
#[serde(default = "default_one_time_key_limit")]
pub one_time_key_limit: usize,
/// (EXPERIMENTAL) Setting this option to true replaces the list of identity
/// providers displayed on a client's login page with a single button "Sign
/// in with single sign-on" linking to the URL
/// `/_matrix/client/v3/login/sso/redirect`. All configured providers are
/// attempted for authorization. All authorizations associate with the same
/// Matrix user. NOTE: All authorizations must succeed, as there is no
/// reliable way to skip a provider.
///
/// This option is disabled by default, allowing the client to list
/// configured providers and permitting privacy-conscious users to authorize
/// only their choice.
///
/// Note that fluffychat always displays a single button anyway. You do not
/// need to enable this to use fluffychat; instead we offer a
/// default-provider option, see `default` in the provider config section.
#[serde(default)]
pub single_sso: bool,
/// Setting this option to true replaces the list of identity providers on
/// the client's login screen with a single button "Sign in with single
/// sign-on" linking to the URL `/_matrix/client/v3/login/sso/redirect`. The
@@ -2590,16 +2608,25 @@ pub struct IdentityProvider {
pub callback_url: Option<Url>,
/// When more than one identity_provider has been configured and
/// `sso_custom_providers_page` is false this will determine the results
/// for the `/_matrix/client/v3/login/sso/redirect` endpoint (note the url
/// lacks a trailing `client_id`).
/// `single_sso` is false and `sso_custom_providers_page` is false this will
/// determine the behavior of the `/_matrix/client/v3/login/sso/redirect`
/// endpoint (note the url lacks a trailing `client_id`).
///
/// When only one identity_provider is configured it will be interpreted
/// as the default and this does not have to be set. Otherwise a default
/// as the default and this does not need to be set. Otherwise a default
/// *must* be selected for some clients (e.g. fluffychat) to work properly
/// when the above conditions require it. For compatibility if not set a
/// warning will be logged on startup and the first provider listed will be
/// considered the default.
/// when the above conditions require it. To operate out-of-the-box we
/// default to one configured provider if none are explicitly default; a
/// warning will be logged on startup for this condition.
///
/// (EXPERIMENTAL) Multiple providers can be set to default. All providers
/// configured with this option set to `true` will associate with the same
/// Matrix account when a client flows through
/// `/_matrix/client/v3/login/sso/redirect`.
///
/// When a user authorizes any provider configured default, the flow will
/// include all other providers configured default as well for association.
/// NOTE: authorization must succeed for ALL default providers.
#[serde(default)]
pub default: bool,
@@ -2676,6 +2703,8 @@ pub struct IdentityProvider {
pub discovery: bool,
/// The duration in seconds before a grant authorization session expires.
///
/// default: 300
#[serde(default = "default_sso_grant_session_duration")]
pub grant_session_duration: Option<u64>,
}
@@ -3251,4 +3280,4 @@ fn default_max_make_join_attempts_per_join_attempt() -> usize { 48 }
fn default_max_join_attempts_per_join_request() -> usize { 3 }
fn default_sso_grant_session_duration() -> Option<u64> { Some(180) }
fn default_sso_grant_session_duration() -> Option<u64> { Some(300) }

View File

@@ -5,6 +5,7 @@ pub mod user_info;
use std::sync::Arc;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::{
Method,
header::{ACCEPT, CONTENT_TYPE},
@@ -14,16 +15,16 @@ use serde::Serialize;
use serde_json::Value as JsonValue;
use tuwunel_core::{
Err, Result, err, implement,
utils::{hash::sha256, result::LogErr},
utils::{hash::sha256, result::LogErr, stream::ReadyExt},
};
use url::Url;
use self::{providers::Providers, sessions::Sessions};
pub use self::{
providers::Provider,
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session},
providers::{Provider, ProviderId},
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
user_info::UserInfo,
};
use self::{providers::Providers, sessions::Sessions};
use crate::SelfServices;
pub struct Service {
@@ -46,6 +47,48 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
/// Remove all session state for a user. For debug and developer use only;
/// deleting state can cause registration conflicts and unintended
/// re-registrations.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn delete_user_sessions(&self, user_id: &UserId) {
self.user_sessions(user_id)
.ready_filter_map(Result::ok)
.ready_filter_map(|(_, session)| session.sess_id)
.for_each(async |sess_id| {
self.sessions.delete(&sess_id).await;
})
.await;
}
/// Revoke all session tokens for a user.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn revoke_user_tokens(&self, user_id: &UserId) {
self.user_sessions(user_id)
.ready_filter_map(Result::ok)
.for_each(async |(provider, session)| {
self.revoke_token((&provider, &session))
.await
.log_err()
.ok();
})
.await;
}
/// Get user's authorizations. Lists pairs of `(Provider, Session)` for a user.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub fn user_sessions(
&self,
user_id: &UserId,
) -> impl Stream<Item = Result<(Provider, Session)>> + Send {
self.sessions
.get_by_user(user_id)
.and_then(async |session| Ok((self.sessions.provider(&session).await?, session)))
}
/// Network request to a Provider returning userinfo for a Session. The session
/// must have a valid access token.
#[implement(Service)]
@@ -222,15 +265,6 @@ where
Ok(response)
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> {
let session = self.sessions.get_by_user(user_id).await?;
let provider = self.sessions.provider(&session).await?;
Ok((provider, session))
}
/// Generate a unique-id string determined by the combination of `Provider` and
/// `Session` instances.
#[inline]

View File

@@ -8,12 +8,16 @@ use url::Url;
use crate::SelfServices;
/// Discovered providers
#[derive(Default)]
pub struct Providers {
services: SelfServices,
providers: RwLock<BTreeMap<String, Provider>>,
providers: RwLock<BTreeMap<ProviderId, Provider>>,
}
/// Identity Provider ID
pub type ProviderId = String;
#[implement(Providers)]
pub(super) fn build(args: &crate::Args<'_>) -> Self {
Self {

View File

@@ -1,15 +1,19 @@
pub mod association;
use std::{
iter::once,
sync::{Arc, Mutex},
time::SystemTime,
};
use futures::{Stream, StreamExt, TryFutureExt};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect};
use tuwunel_database::{Cbor, Deserialized, Map};
use tuwunel_core::{
Err, Result, at, implement,
utils::stream::{IterStream, ReadyExt, TryExpect},
};
use tuwunel_database::{Cbor, Deserialized, Ignore, Map};
use url::Url;
use super::{Provider, Providers, UserInfo, unique_id};
@@ -41,7 +45,7 @@ pub struct Session {
pub idp_id: Option<String>,
/// Session ID used as the index key for this session itself.
pub sess_id: Option<String>,
pub sess_id: Option<SessionId>,
/// Token type (bearer, mac, etc).
pub token_type: Option<String>,
@@ -89,6 +93,9 @@ pub struct Session {
pub user_info: Option<UserInfo>,
}
/// Session Identifier type.
pub type SessionId = String;
/// Number of characters generated for our code_verifier. The code_verifier is a
/// random string which must be between 43 and 128 characters.
pub const CODE_VERIFIER_LENGTH: usize = 64;
@@ -110,15 +117,6 @@ pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
}
}
#[implement(Sessions)]
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db
.userid_oauthid
.keys()
.expect_ok()
.map(UserId::to_owned)
}
/// Delete database state for the session.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self))]
@@ -127,13 +125,19 @@ pub async fn delete(&self, sess_id: &str) {
return;
};
// Check the user_id still points to this sess_id before deleting. If not, the
// association was updated to a newer session.
if let Some(user_id) = session.user_id.as_deref()
&& let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await
&& assoc_id == sess_id
{
self.db.userid_oauthid.remove(user_id);
if let Some(user_id) = session.user_id.as_deref() {
let sess_ids: Vec<_> = self
.get_sess_id_by_user(user_id)
.ready_filter_map(Result::ok)
.ready_filter(|assoc_id| assoc_id != sess_id)
.collect()
.await;
if !sess_ids.is_empty() {
self.db.userid_oauthid.raw_put(user_id, sess_ids);
} else {
self.db.userid_oauthid.remove(user_id);
}
}
// Check the unique identity still points to this sess_id before deleting. If
@@ -153,15 +157,16 @@ pub async fn delete(&self, sess_id: &str) {
/// Create or overwrite database state for the session.
#[implement(Sessions)]
#[tracing::instrument(level = "info", skip(self))]
pub async fn put(&self, sess_id: &str, session: &Session) {
pub async fn put(&self, session: &Session) {
let sess_id = session
.sess_id
.as_deref()
.expect("Missing session.sess_id required for sessions.put()");
self.db
.oauthid_session
.raw_put(sess_id, Cbor(session));
if let Some(user_id) = session.user_id.as_deref() {
self.db.userid_oauthid.insert(user_id, sess_id);
}
if let Some(idp_id) = session.idp_id.as_ref()
&& let Ok(provider) = self.providers.get(idp_id).await
&& let Ok(unique_id) = unique_id((&provider, session))
@@ -170,6 +175,22 @@ pub async fn put(&self, sess_id: &str, session: &Session) {
.oauthuniqid_oauthid
.insert(&unique_id, sess_id);
}
if let Some(user_id) = session.user_id.as_deref() {
let sess_ids = self
.get_sess_id_by_user(user_id)
.ready_filter_map(Result::ok)
.chain(once(sess_id.to_owned()).stream())
.collect::<Vec<_>>()
.map(|mut ids| {
ids.sort_unstable();
ids.dedup();
ids
})
.await;
self.db.userid_oauthid.raw_put(user_id, sess_ids);
}
}
/// Fetch database state for a session from its associated `(iss,sub)`, in case
@@ -182,14 +203,13 @@ pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
.await
}
/// Fetch database state for a session from its associated `user_id`, in case
/// `sess_id` is not known.
/// Fetch database state for one or more sessions from its associated `user_id`,
/// in case `sess_id` is not known.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_by_user(&self, user_id: &UserId) -> Result<Session> {
#[tracing::instrument(level = "debug", skip(self))]
pub fn get_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<Session>> + Send {
self.get_sess_id_by_user(user_id)
.and_then(async |sess_id| self.get(&sess_id).await)
.await
}
/// Fetch database state for a session from its `sess_id`.
@@ -204,15 +224,17 @@ pub async fn get(&self, sess_id: &str) -> Result<Session> {
.map(at!(0))
}
/// Resolve the `sess_id` from an associated `user_id`.
/// Resolve the `sess_id` associations with a `user_id`.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result<String> {
#[tracing::instrument(level = "debug", skip(self))]
pub fn get_sess_id_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<String>> + Send {
self.db
.userid_oauthid
.get(user_id)
.await
.deserialized()
.map(Deserialized::deserialized)
.map_ok(Vec::into_iter)
.map_ok(IterStream::try_stream)
.try_flatten_stream()
}
/// Resolve the `sess_id` from an associated provider issuer and subject hash.
@@ -226,6 +248,24 @@ pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String>
.deserialized()
}
#[implement(Sessions)]
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db
.userid_oauthid
.keys()
.expect_ok()
.map(UserId::to_owned)
}
#[implement(Sessions)]
pub fn stream(&self) -> impl Stream<Item = Session> + Send {
self.db
.oauthid_session
.stream()
.expect_ok()
.map(|(_, session): (Ignore, Cbor<_>)| session.0)
}
#[implement(Sessions)]
pub async fn provider(&self, session: &Session) -> Result<Provider> {
let Some(idp_id) = session.idp_id.as_deref() else {

View File

@@ -17,7 +17,7 @@ use tuwunel_core::{
Err, Result, debug_warn, err, is_equal_to,
pdu::PduBuilder,
trace,
utils::{self, ReadyExt, result::LogErr, stream::TryIgnore},
utils::{self, ReadyExt, stream::TryIgnore},
warn,
};
use tuwunel_database::{Deserialized, Json, Map};
@@ -133,14 +133,10 @@ impl Service {
/// Deactivate account
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
// Revoke any SSO authorizations
if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await {
self.services
.oauth
.revoke_token((&provider, &session))
.await
.log_err()
.ok();
}
self.services
.oauth
.revoke_user_tokens(user_id)
.await;
// Remove all associated devices
self.all_device_ids(user_id)

View File

@@ -1861,6 +1861,24 @@
#
#one_time_key_limit = 256
# (EXPERIMENTAL) Setting this option to true replaces the list of identity
# providers displayed on a client's login page with a single button "Sign
# in with single sign-on" linking to the URL
# `/_matrix/client/v3/login/sso/redirect`. All configured providers are
# attempted for authorization. All authorizations associate with the same
# Matrix user. NOTE: All authorizations must succeed, as there is no
# reliable way to skip a provider.
#
# This option is disabled by default, allowing the client to list
# configured providers and permitting privacy-conscious users to authorize
# only their choice.
#
# Note that fluffychat always displays a single button anyway. You do not
# need to enable this to use fluffychat; instead we offer a
# default-provider option, see `default` in the provider config section.
#
#single_sso = false
# Setting this option to true replaces the list of identity providers on
# the client's login screen with a single button "Sign in with single
# sign-on" linking to the URL `/_matrix/client/v3/login/sso/redirect`. The
@@ -2207,16 +2225,25 @@
#callback_url =
# When more than one identity_provider has been configured and
# `sso_custom_providers_page` is false this will determine the results
# for the `/_matrix/client/v3/login/sso/redirect` endpoint (note the url
# lacks a trailing `client_id`).
# `single_sso` is false and `sso_custom_providers_page` is false this will
# determine the behavior of the `/_matrix/client/v3/login/sso/redirect`
# endpoint (note the url lacks a trailing `client_id`).
#
# When only one identity_provider is configured it will be interpreted
# as the default and this does not have to be set. Otherwise a default
# as the default and this does not need to be set. Otherwise a default
# *must* be selected for some clients (e.g. fluffychat) to work properly
# when the above conditions require it. For compatibility if not set a
# warning will be logged on startup and the first provider listed will be
# considered the default.
# when the above conditions require it. To operate out-of-the-box we
# default to one configured provider if none are explicitly default; a
# warning will be logged on startup for this condition.
#
# (EXPERIMENTAL) Multiple providers can be set to default. All providers
# configured with this option set to `true` will associate with the same
# Matrix account when a client flows through
# `/_matrix/client/v3/login/sso/redirect`.
#
# When a user authorizes any provider configured default, the flow will
# include all other providers configured default as well for association.
# NOTE: authorization must succeed for ALL default providers.
#
#default = false
@@ -2299,7 +2326,7 @@
# The duration in seconds before a grant authorization session expires.
#
#grant_session_duration =
#grant_session_duration = 300