Implement associated multi-provider single-sign-on flow support. (#252)

Add experimental note for multi-provider flow. (#252)

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-23 03:10:41 +00:00
parent a3294fe1cf
commit 6db87a4027
12 changed files with 411 additions and 197 deletions

View File

@@ -22,10 +22,7 @@ use ruma::api::client::session::{
v3::{DiscoveryInfo, HomeserverInfo, LoginInfo},
},
};
use tuwunel_core::{
Err, Result, info,
utils::{BoolExt, stream::ReadyExt},
};
use tuwunel_core::{Err, Result, info, utils::stream::ReadyExt};
use tuwunel_service::users::device::generate_refresh_token;
use self::{ldap::ldap_login, password::password_login};
@@ -50,13 +47,13 @@ pub(crate) async fn get_login_types_route(
) -> Result<get_login_types::v3::Response> {
let get_login_token = services.config.login_via_existing_session;
let identity_providers = services
let list_idps = !services.config.sso_custom_providers_page && !services.config.single_sso;
let identity_providers: Vec<_> = services
.config
.sso_custom_providers_page
.is_false()
.then(|| services.config.identity_provider.iter())
.into_iter()
.flatten()
.identity_provider
.iter()
.filter(|_| list_idps)
.cloned()
.map(|config| IdentityProvider {
id: config.id().to_owned(),
@@ -70,8 +67,8 @@ pub(crate) async fn get_login_types_route(
LoginType::ApplicationService(ApplicationServiceLoginType::default()),
LoginType::Jwt(JwtLoginType::default()),
LoginType::Password(PasswordLoginType::default()),
LoginType::Sso(SsoLoginType { identity_providers }),
LoginType::Token(TokenLoginType { get_login_token }),
LoginType::Sso(SsoLoginType { identity_providers }),
];
Ok(get_login_types::v3::Response {
@@ -79,8 +76,7 @@ pub(crate) async fn get_login_types_route(
.into_iter()
.filter(|login_type| match login_type {
| LoginType::Sso(SsoLoginType { identity_providers })
if !services.config.sso_custom_providers_page
&& identity_providers.is_empty() =>
if list_idps && identity_providers.is_empty() =>
false,
| _ => true,

View File

@@ -4,7 +4,7 @@ use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use axum_extra::extract::cookie::{Cookie, SameSite};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use futures::{StreamExt, TryFutureExt, future::try_join};
use futures::{FutureExt, StreamExt, TryFutureExt, future::try_join};
use reqwest::header::{CONTENT_TYPE, HeaderValue};
use ruma::{
Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId,
@@ -19,6 +19,7 @@ use tuwunel_core::{
itertools::Itertools,
utils,
utils::{
OptionExt,
content_disposition::make_content_disposition,
hash::sha256,
result::{FlatOk, LogErr},
@@ -98,7 +99,7 @@ pub(crate) async fn sso_login_route(
let redirect_url = body.body.redirect_url;
handle_sso_login(&services, &client, default_idp_id, redirect_url)
handle_sso_login(&services, &client, default_idp_id, redirect_url, None)
.map_ok(|response| sso_login::v3::Response {
location: response.location,
cookie: response.cookie,
@@ -128,8 +129,9 @@ pub(crate) async fn sso_login_with_provider_route(
) -> Result<sso_login_with_provider::v3::Response> {
let idp_id = body.body.idp_id;
let redirect_url = body.body.redirect_url;
let login_token = body.body.login_token;
handle_sso_login(&services, &client, idp_id, redirect_url).await
handle_sso_login(&services, &client, idp_id, redirect_url, login_token).await
}
async fn handle_sso_login(
@@ -137,6 +139,7 @@ async fn handle_sso_login(
_client: &IpAddr,
idp_id: String,
redirect_url: String,
login_token: Option<String>,
) -> Result<sso_login_with_provider::v3::Response> {
let Ok(redirect_url) = redirect_url.parse::<Url>() else {
return Err!(Request(InvalidParam("Invalid redirect_url")));
@@ -221,14 +224,16 @@ async fn handle_sso_login(
.map(timepoint_from_now)
.transpose()?,
user_id: login_token
.as_deref()
.map_async(|token| services.users.find_from_login_token(token))
.map(FlatOk::flat_ok)
.await,
..Default::default()
};
services
.oauth
.sessions
.put(&sess_id, &session)
.await;
services.oauth.sessions.put(&session).await;
Ok(sso_login_with_provider::v3::Response {
location: location.into(),
@@ -269,13 +274,23 @@ pub(crate) async fn sso_callback_route(
.get(sess_id)
.map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));
let idp_id = body.body.idp_id.as_str();
let provider = services.oauth.providers.get(idp_id);
let provider = services
.oauth
.providers
.get(body.body.idp_id.as_str());
let (provider, session) = try_join(provider, session).await.log_err()?;
let client_id = &provider.client_id;
let idp_id = provider.id();
if session.sess_id.as_deref() != Some(sess_id) {
return Err!(Request(Unauthorized("Session ID {sess_id:?} not recognized.")));
}
if session.idp_id.as_deref() != Some(idp_id) {
return Err!(Request(Unauthorized("Identity Provider {idp_id} session not recognized.")));
return Err!(Request(Unauthorized(
"Identity Provider {idp_id:?} session not recognized."
)));
}
if session
@@ -346,7 +361,7 @@ pub(crate) async fn sso_callback_route(
// Check for an existing session from this identity. We want to maintain one
// session for each identity and keep the newer one which has up-to-date state
// and access.
let (user_id, old_sess_id) = match services
let (old_user_id, old_sess_id) = match services
.oauth
.sessions
.get_by_unique_id(&unique_id)
@@ -364,9 +379,9 @@ pub(crate) async fn sso_callback_route(
};
// Keep the user_id from the old session as best as possible.
let user_id = match user_id {
| Some(user_id) => user_id,
| None => decide_user_id(&services, &provider, &userinfo, &unique_id).await?,
let user_id = match (session.user_id, old_user_id) {
| (Some(user_id), ..) | (None, Some(user_id)) => user_id,
| (None, None) => decide_user_id(&services, &provider, &userinfo, &unique_id).await?,
};
// Update the session with user_id
@@ -381,11 +396,7 @@ pub(crate) async fn sso_callback_route(
}
// Commit the updated session.
services
.oauth
.sessions
.put(sess_id, &session)
.await;
services.oauth.sessions.put(&session).await;
// Delete any old session.
if let Some(old_sess_id) = old_sess_id
@@ -398,14 +409,43 @@ pub(crate) async fn sso_callback_route(
return Err!(Request(UserDeactivated("This user has been deactivated.")));
}
let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
.removal()
.build()
.to_string()
.into();
// Determine the next provider to chain after this one.
let next_idp_url = services
.config
.identity_provider
.iter()
.filter(|idp| idp.default || services.config.single_sso)
.skip_while(|idp| idp.id() != idp_id)
.nth(1)
.map(IdentityProvider::id)
.and_then(|next_idp| {
provider.callback_url.clone().map(|mut url| {
let path = format!("/_matrix/client/v3/login/sso/redirect/{next_idp}");
url.set_path(&path);
if let Some(redirect_url) = session.redirect_url.as_ref() {
url.query_pairs_mut()
.append_pair("redirectUrl", redirect_url.as_str());
}
url
})
});
// Allow the user to login to Matrix.
let login_token = utils::random_string(TOKEN_LENGTH);
let _login_token_expires_in = services
.users
.create_login_token(&user_id, &login_token);
let location = session
.redirect_url
let location = next_idp_url
.or(session.redirect_url)
.as_ref()
.ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
.clone()
@@ -414,12 +454,6 @@ pub(crate) async fn sso_callback_route(
.finish()
.to_string();
let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
.removal()
.build()
.to_string()
.into();
Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
}
@@ -594,8 +628,8 @@ async fn decide_user_id(
}
let length = Some(15..23);
let infallible = truncate_deterministic(unique_id, length).to_lowercase();
if let Some(user_id) = try_user_id(services, &infallible, true).await {
let unique_id = truncate_deterministic(unique_id, length).to_lowercase();
if let Some(user_id) = try_user_id(services, &unique_id, true).await {
return Ok(user_id);
}