Implement associated multi-provider single-sign-on flow support. (#252)
Add experimental note for multi-provider flow. (#252) Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -5,6 +5,7 @@ pub mod user_info;
|
||||
use std::sync::Arc;
|
||||
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
|
||||
use futures::{Stream, StreamExt, TryStreamExt};
|
||||
use reqwest::{
|
||||
Method,
|
||||
header::{ACCEPT, CONTENT_TYPE},
|
||||
@@ -14,16 +15,16 @@ use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tuwunel_core::{
|
||||
Err, Result, err, implement,
|
||||
utils::{hash::sha256, result::LogErr},
|
||||
utils::{hash::sha256, result::LogErr, stream::ReadyExt},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use self::{providers::Providers, sessions::Sessions};
|
||||
pub use self::{
|
||||
providers::Provider,
|
||||
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session},
|
||||
providers::{Provider, ProviderId},
|
||||
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
|
||||
user_info::UserInfo,
|
||||
};
|
||||
use self::{providers::Providers, sessions::Sessions};
|
||||
use crate::SelfServices;
|
||||
|
||||
pub struct Service {
|
||||
@@ -46,6 +47,48 @@ impl crate::Service for Service {
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
||||
/// Remove all session state for a user. For debug and developer use only;
|
||||
/// deleting state can cause registration conflicts and unintended
|
||||
/// re-registrations.
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn delete_user_sessions(&self, user_id: &UserId) {
|
||||
self.user_sessions(user_id)
|
||||
.ready_filter_map(Result::ok)
|
||||
.ready_filter_map(|(_, session)| session.sess_id)
|
||||
.for_each(async |sess_id| {
|
||||
self.sessions.delete(&sess_id).await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Revoke all session tokens for a user.
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn revoke_user_tokens(&self, user_id: &UserId) {
|
||||
self.user_sessions(user_id)
|
||||
.ready_filter_map(Result::ok)
|
||||
.for_each(async |(provider, session)| {
|
||||
self.revoke_token((&provider, &session))
|
||||
.await
|
||||
.log_err()
|
||||
.ok();
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Get user's authorizations. Lists pairs of `(Provider, Session)` for a user.
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub fn user_sessions(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> impl Stream<Item = Result<(Provider, Session)>> + Send {
|
||||
self.sessions
|
||||
.get_by_user(user_id)
|
||||
.and_then(async |session| Ok((self.sessions.provider(&session).await?, session)))
|
||||
}
|
||||
|
||||
/// Network request to a Provider returning userinfo for a Session. The session
|
||||
/// must have a valid access token.
|
||||
#[implement(Service)]
|
||||
@@ -222,15 +265,6 @@ where
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> {
|
||||
let session = self.sessions.get_by_user(user_id).await?;
|
||||
let provider = self.sessions.provider(&session).await?;
|
||||
|
||||
Ok((provider, session))
|
||||
}
|
||||
|
||||
/// Generate a unique-id string determined by the combination of `Provider` and
|
||||
/// `Session` instances.
|
||||
#[inline]
|
||||
|
||||
@@ -8,12 +8,16 @@ use url::Url;
|
||||
|
||||
use crate::SelfServices;
|
||||
|
||||
/// Discovered providers
|
||||
#[derive(Default)]
|
||||
pub struct Providers {
|
||||
services: SelfServices,
|
||||
providers: RwLock<BTreeMap<String, Provider>>,
|
||||
providers: RwLock<BTreeMap<ProviderId, Provider>>,
|
||||
}
|
||||
|
||||
/// Identity Provider ID
|
||||
pub type ProviderId = String;
|
||||
|
||||
#[implement(Providers)]
|
||||
pub(super) fn build(args: &crate::Args<'_>) -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
pub mod association;
|
||||
|
||||
use std::{
|
||||
iter::once,
|
||||
sync::{Arc, Mutex},
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
use futures::{Stream, StreamExt, TryFutureExt};
|
||||
use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use ruma::{OwnedUserId, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect};
|
||||
use tuwunel_database::{Cbor, Deserialized, Map};
|
||||
use tuwunel_core::{
|
||||
Err, Result, at, implement,
|
||||
utils::stream::{IterStream, ReadyExt, TryExpect},
|
||||
};
|
||||
use tuwunel_database::{Cbor, Deserialized, Ignore, Map};
|
||||
use url::Url;
|
||||
|
||||
use super::{Provider, Providers, UserInfo, unique_id};
|
||||
@@ -41,7 +45,7 @@ pub struct Session {
|
||||
pub idp_id: Option<String>,
|
||||
|
||||
/// Session ID used as the index key for this session itself.
|
||||
pub sess_id: Option<String>,
|
||||
pub sess_id: Option<SessionId>,
|
||||
|
||||
/// Token type (bearer, mac, etc).
|
||||
pub token_type: Option<String>,
|
||||
@@ -89,6 +93,9 @@ pub struct Session {
|
||||
pub user_info: Option<UserInfo>,
|
||||
}
|
||||
|
||||
/// Session Identifier type.
|
||||
pub type SessionId = String;
|
||||
|
||||
/// Number of characters generated for our code_verifier. The code_verifier is a
|
||||
/// random string which must be between 43 and 128 characters.
|
||||
pub const CODE_VERIFIER_LENGTH: usize = 64;
|
||||
@@ -110,15 +117,6 @@ pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
|
||||
self.db
|
||||
.userid_oauthid
|
||||
.keys()
|
||||
.expect_ok()
|
||||
.map(UserId::to_owned)
|
||||
}
|
||||
|
||||
/// Delete database state for the session.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
@@ -127,13 +125,19 @@ pub async fn delete(&self, sess_id: &str) {
|
||||
return;
|
||||
};
|
||||
|
||||
// Check the user_id still points to this sess_id before deleting. If not, the
|
||||
// association was updated to a newer session.
|
||||
if let Some(user_id) = session.user_id.as_deref()
|
||||
&& let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await
|
||||
&& assoc_id == sess_id
|
||||
{
|
||||
self.db.userid_oauthid.remove(user_id);
|
||||
if let Some(user_id) = session.user_id.as_deref() {
|
||||
let sess_ids: Vec<_> = self
|
||||
.get_sess_id_by_user(user_id)
|
||||
.ready_filter_map(Result::ok)
|
||||
.ready_filter(|assoc_id| assoc_id != sess_id)
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
if !sess_ids.is_empty() {
|
||||
self.db.userid_oauthid.raw_put(user_id, sess_ids);
|
||||
} else {
|
||||
self.db.userid_oauthid.remove(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Check the unique identity still points to this sess_id before deleting. If
|
||||
@@ -153,15 +157,16 @@ pub async fn delete(&self, sess_id: &str) {
|
||||
/// Create or overwrite database state for the session.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "info", skip(self))]
|
||||
pub async fn put(&self, sess_id: &str, session: &Session) {
|
||||
pub async fn put(&self, session: &Session) {
|
||||
let sess_id = session
|
||||
.sess_id
|
||||
.as_deref()
|
||||
.expect("Missing session.sess_id required for sessions.put()");
|
||||
|
||||
self.db
|
||||
.oauthid_session
|
||||
.raw_put(sess_id, Cbor(session));
|
||||
|
||||
if let Some(user_id) = session.user_id.as_deref() {
|
||||
self.db.userid_oauthid.insert(user_id, sess_id);
|
||||
}
|
||||
|
||||
if let Some(idp_id) = session.idp_id.as_ref()
|
||||
&& let Ok(provider) = self.providers.get(idp_id).await
|
||||
&& let Ok(unique_id) = unique_id((&provider, session))
|
||||
@@ -170,6 +175,22 @@ pub async fn put(&self, sess_id: &str, session: &Session) {
|
||||
.oauthuniqid_oauthid
|
||||
.insert(&unique_id, sess_id);
|
||||
}
|
||||
|
||||
if let Some(user_id) = session.user_id.as_deref() {
|
||||
let sess_ids = self
|
||||
.get_sess_id_by_user(user_id)
|
||||
.ready_filter_map(Result::ok)
|
||||
.chain(once(sess_id.to_owned()).stream())
|
||||
.collect::<Vec<_>>()
|
||||
.map(|mut ids| {
|
||||
ids.sort_unstable();
|
||||
ids.dedup();
|
||||
ids
|
||||
})
|
||||
.await;
|
||||
|
||||
self.db.userid_oauthid.raw_put(user_id, sess_ids);
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its associated `(iss,sub)`, in case
|
||||
@@ -182,14 +203,13 @@ pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its associated `user_id`, in case
|
||||
/// `sess_id` is not known.
|
||||
/// Fetch database state for one or more sessions from its associated `user_id`,
|
||||
/// in case `sess_id` is not known.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_by_user(&self, user_id: &UserId) -> Result<Session> {
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub fn get_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<Session>> + Send {
|
||||
self.get_sess_id_by_user(user_id)
|
||||
.and_then(async |sess_id| self.get(&sess_id).await)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its `sess_id`.
|
||||
@@ -204,15 +224,17 @@ pub async fn get(&self, sess_id: &str) -> Result<Session> {
|
||||
.map(at!(0))
|
||||
}
|
||||
|
||||
/// Resolve the `sess_id` from an associated `user_id`.
|
||||
/// Resolve the `sess_id` associations with a `user_id`.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result<String> {
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub fn get_sess_id_by_user(&self, user_id: &UserId) -> impl Stream<Item = Result<String>> + Send {
|
||||
self.db
|
||||
.userid_oauthid
|
||||
.get(user_id)
|
||||
.await
|
||||
.deserialized()
|
||||
.map(Deserialized::deserialized)
|
||||
.map_ok(Vec::into_iter)
|
||||
.map_ok(IterStream::try_stream)
|
||||
.try_flatten_stream()
|
||||
}
|
||||
|
||||
/// Resolve the `sess_id` from an associated provider issuer and subject hash.
|
||||
@@ -226,6 +248,24 @@ pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String>
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
|
||||
self.db
|
||||
.userid_oauthid
|
||||
.keys()
|
||||
.expect_ok()
|
||||
.map(UserId::to_owned)
|
||||
}
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub fn stream(&self) -> impl Stream<Item = Session> + Send {
|
||||
self.db
|
||||
.oauthid_session
|
||||
.stream()
|
||||
.expect_ok()
|
||||
.map(|(_, session): (Ignore, Cbor<_>)| session.0)
|
||||
}
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub async fn provider(&self, session: &Session) -> Result<Provider> {
|
||||
let Some(idp_id) = session.idp_id.as_deref() else {
|
||||
|
||||
@@ -17,7 +17,7 @@ use tuwunel_core::{
|
||||
Err, Result, debug_warn, err, is_equal_to,
|
||||
pdu::PduBuilder,
|
||||
trace,
|
||||
utils::{self, ReadyExt, result::LogErr, stream::TryIgnore},
|
||||
utils::{self, ReadyExt, stream::TryIgnore},
|
||||
warn,
|
||||
};
|
||||
use tuwunel_database::{Deserialized, Json, Map};
|
||||
@@ -133,14 +133,10 @@ impl Service {
|
||||
/// Deactivate account
|
||||
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
|
||||
// Revoke any SSO authorizations
|
||||
if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await {
|
||||
self.services
|
||||
.oauth
|
||||
.revoke_token((&provider, &session))
|
||||
.await
|
||||
.log_err()
|
||||
.ok();
|
||||
}
|
||||
self.services
|
||||
.oauth
|
||||
.revoke_user_tokens(user_id)
|
||||
.await;
|
||||
|
||||
// Remove all associated devices
|
||||
self.all_device_ids(user_id)
|
||||
|
||||
Reference in New Issue
Block a user