Implement associated multi-provider single-sign-on flow support. (#252)

Add experimental note for multi-provider flow. (#252)

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-23 03:10:41 +00:00
parent a3294fe1cf
commit 6db87a4027
12 changed files with 411 additions and 197 deletions

View File

@@ -5,6 +5,7 @@ pub mod user_info;
use std::sync::Arc;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::{
Method,
header::{ACCEPT, CONTENT_TYPE},
@@ -14,16 +15,16 @@ use serde::Serialize;
use serde_json::Value as JsonValue;
use tuwunel_core::{
Err, Result, err, implement,
utils::{hash::sha256, result::LogErr},
utils::{hash::sha256, result::LogErr, stream::ReadyExt},
};
use url::Url;
use self::{providers::Providers, sessions::Sessions};
pub use self::{
providers::Provider,
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session},
providers::{Provider, ProviderId},
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
user_info::UserInfo,
};
use self::{providers::Providers, sessions::Sessions};
use crate::SelfServices;
pub struct Service {
@@ -46,6 +47,48 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
/// Remove all session state for a user. For debug and developer use only;
/// deleting state can cause registration conflicts and unintended
/// re-registrations.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn delete_user_sessions(&self, user_id: &UserId) {
self.user_sessions(user_id)
.ready_filter_map(Result::ok)
.ready_filter_map(|(_, session)| session.sess_id)
.for_each(async |sess_id| {
self.sessions.delete(&sess_id).await;
})
.await;
}
/// Revoke all session tokens for a user.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn revoke_user_tokens(&self, user_id: &UserId) {
self.user_sessions(user_id)
.ready_filter_map(Result::ok)
.for_each(async |(provider, session)| {
self.revoke_token((&provider, &session))
.await
.log_err()
.ok();
})
.await;
}
/// Get user's authorizations. Lists pairs of `(Provider, Session)` for a user.
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub fn user_sessions(
&self,
user_id: &UserId,
) -> impl Stream<Item = Result<(Provider, Session)>> + Send {
self.sessions
.get_by_user(user_id)
.and_then(async |session| Ok((self.sessions.provider(&session).await?, session)))
}
/// Network request to a Provider returning userinfo for a Session. The session
/// must have a valid access token.
#[implement(Service)]
@@ -222,15 +265,6 @@ where
Ok(response)
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> {
let session = self.sessions.get_by_user(user_id).await?;
let provider = self.sessions.provider(&session).await?;
Ok((provider, session))
}
/// Generate a unique-id string determined by the combination of `Provider` and
/// `Session` instances.
#[inline]

View File

@@ -8,12 +8,16 @@ use url::Url;
use crate::SelfServices;
/// Discovered providers
#[derive(Default)]
pub struct Providers {
services: SelfServices,
providers: RwLock<BTreeMap<String, Provider>>,
providers: RwLock<BTreeMap<ProviderId, Provider>>,
}
/// Identity Provider ID
pub type ProviderId = String;
#[implement(Providers)]
pub(super) fn build(args: &crate::Args<'_>) -> Self {
Self {

View File

@@ -1,15 +1,19 @@
pub mod association;
use std::{
iter::once,
sync::{Arc, Mutex},
time::SystemTime,
};
use futures::{Stream, StreamExt, TryFutureExt};
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect};
use tuwunel_database::{Cbor, Deserialized, Map};
use tuwunel_core::{
Err, Result, at, implement,
utils::stream::{IterStream, ReadyExt, TryExpect},
};
use tuwunel_database::{Cbor, Deserialized, Ignore, Map};
use url::Url;
use super::{Provider, Providers, UserInfo, unique_id};
@@ -41,7 +45,7 @@ pub struct Session {
pub idp_id: Option<String>,
/// Session ID used as the index key for this session itself.
pub sess_id: Option<String>,
pub sess_id: Option<SessionId>,
/// Token type (bearer, mac, etc).
pub token_type: Option<String>,
@@ -89,6 +93,9 @@ pub struct Session {
pub user_info: Option<UserInfo>,
}
/// Session Identifier type.
pub type SessionId = String;
/// Number of characters generated for our code_verifier. The code_verifier is a
/// random string which must be between 43 and 128 characters.
pub const CODE_VERIFIER_LENGTH: usize = 64;
@@ -110,15 +117,6 @@ pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
}
}
#[implement(Sessions)]
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db
.userid_oauthid
.keys()
.expect_ok()
.map(UserId::to_owned)
}
/// Delete database state for the session.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self))]
@@ -127,13 +125,19 @@ pub async fn delete(&self, sess_id: &str) {
return;
};
// Check the user_id still points to this sess_id before deleting. If not, the
// association was updated to a newer session.
if let Some(user_id) = session.user_id.as_deref()
&& let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await
&& assoc_id == sess_id
{
self.db.userid_oauthid.remove(user_id);
if let Some(user_id) = session.user_id.as_deref() {
let sess_ids: Vec<_> = self
.get_sess_id_by_user(user_id)
.ready_filter_map(Result::ok)
.ready_filter(|assoc_id| assoc_id != sess_id)
.collect()
.await;
if !sess_ids.is_empty() {
self.db.userid_oauthid.raw_put(user_id, sess_ids);
} else {
self.db.userid_oauthid.remove(user_id);
}
}
// Check the unique identity still points to this sess_id before deleting. If
@@ -153,15 +157,16 @@ pub async fn delete(&self, sess_id: &str) {
/// Create or overwrite database state for the session.
#[implement(Sessions)]
#[tracing::instrument(level = "info", skip(self))]
pub async fn put(&self, sess_id: &str, session: &Session) {
pub async fn put(&self, session: &Session) {
let sess_id = session
.sess_id
.as_deref()
.expect("Missing session.sess_id required for sessions.put()");
self.db
.oauthid_session
.raw_put(sess_id, Cbor(session));
if let Some(user_id) = session.user_id.as_deref() {
self.db.userid_oauthid.insert(user_id, sess_id);
}
if let Some(idp_id) = session.idp_id.as_ref()
&& let Ok(provider) = self.providers.get(idp_id).await
&& let Ok(unique_id) = unique_id((&provider, session))
@@ -170,6 +175,22 @@ pub async fn put(&self, sess_id: &str, session: &Session) {
.oauthuniqid_oauthid
.insert(&unique_id, sess_id);
}
if let Some(user_id) = session.user_id.as_deref() {
let sess_ids = self
.get_sess_id_by_user(user_id)
.ready_filter_map(Result::ok)
.chain(once(sess_id.to_owned()).stream())
.collect::<Vec<_>>()
.map(|mut ids| {
ids.sort_unstable();
ids.dedup();
ids
})
.await;
self.db.userid_oauthid.raw_put(user_id, sess_ids);
}
}
/// Fetch database state for a session from its associated `(iss,sub)`, in case
@@ -182,14 +203,13 @@ pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
.await
}
/// Fetch database state for a session from its associated `user_id`, in case
/// `sess_id` is not known.
/// Fetch database state for one or more sessions from its associated `user_id`,
/// in case `sess_id` is not known.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_by_user(&self, user_id: &UserId) -> Result<Session> {
#[tracing::instrument(level = "debug", skip(self))]
pub fn get_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<Session>> + Send {
self.get_sess_id_by_user(user_id)
.and_then(async |sess_id| self.get(&sess_id).await)
.await
}
/// Fetch database state for a session from its `sess_id`.
@@ -204,15 +224,17 @@ pub async fn get(&self, sess_id: &str) -> Result<Session> {
.map(at!(0))
}
/// Resolve the `sess_id` from an associated `user_id`.
/// Resolve the `sess_id` associations with a `user_id`.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result<String> {
#[tracing::instrument(level = "debug", skip(self))]
pub fn get_sess_id_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<String>> + Send {
self.db
.userid_oauthid
.get(user_id)
.await
.deserialized()
.map(Deserialized::deserialized)
.map_ok(Vec::into_iter)
.map_ok(IterStream::try_stream)
.try_flatten_stream()
}
/// Resolve the `sess_id` from an associated provider issuer and subject hash.
@@ -226,6 +248,24 @@ pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String>
.deserialized()
}
#[implement(Sessions)]
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db
.userid_oauthid
.keys()
.expect_ok()
.map(UserId::to_owned)
}
#[implement(Sessions)]
pub fn stream(&self) -> impl Stream<Item = Session> + Send {
self.db
.oauthid_session
.stream()
.expect_ok()
.map(|(_, session): (Ignore, Cbor<_>)| session.0)
}
#[implement(Sessions)]
pub async fn provider(&self, session: &Session) -> Result<Provider> {
let Some(idp_id) = session.idp_id.as_deref() else {

View File

@@ -17,7 +17,7 @@ use tuwunel_core::{
Err, Result, debug_warn, err, is_equal_to,
pdu::PduBuilder,
trace,
utils::{self, ReadyExt, result::LogErr, stream::TryIgnore},
utils::{self, ReadyExt, stream::TryIgnore},
warn,
};
use tuwunel_database::{Deserialized, Json, Map};
@@ -133,14 +133,10 @@ impl Service {
/// Deactivate account
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
// Revoke any SSO authorizations
if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await {
self.services
.oauth
.revoke_token((&provider, &session))
.await
.log_err()
.ok();
}
self.services
.oauth
.revoke_user_tokens(user_id)
.await;
// Remove all associated devices
self.all_device_ids(user_id)