Add privileged support for SSO account associations. (#252)
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
use clap::Subcommand;
|
use clap::Subcommand;
|
||||||
use futures::{StreamExt, TryStreamExt};
|
use futures::{StreamExt, TryStreamExt};
|
||||||
use ruma::OwnedUserId;
|
use ruma::OwnedUserId;
|
||||||
use tuwunel_core::{Err, Result, utils::stream::IterStream};
|
use tuwunel_core::{Err, Result, apply, err, itertools::Itertools, utils::stream::IterStream};
|
||||||
use tuwunel_service::{
|
use tuwunel_service::{
|
||||||
Services,
|
Services,
|
||||||
oauth::{Provider, Session},
|
oauth::{Provider, Session},
|
||||||
@@ -13,6 +13,19 @@ use crate::{admin_command, admin_command_dispatch};
|
|||||||
#[derive(Debug, Subcommand)]
|
#[derive(Debug, Subcommand)]
|
||||||
/// Query OAuth service state
|
/// Query OAuth service state
|
||||||
pub(crate) enum OauthCommand {
|
pub(crate) enum OauthCommand {
|
||||||
|
/// Associate existing user with future authorization claims.
|
||||||
|
Associate {
|
||||||
|
/// ID of configured provider to listen on.
|
||||||
|
provider: String,
|
||||||
|
|
||||||
|
/// MXID of local user to associate.
|
||||||
|
user_id: OwnedUserId,
|
||||||
|
|
||||||
|
/// List of claims to match in key=value format.
|
||||||
|
#[arg(long, required = true)]
|
||||||
|
claim: Vec<String>,
|
||||||
|
},
|
||||||
|
|
||||||
/// List configured OAuth providers.
|
/// List configured OAuth providers.
|
||||||
ListProviders,
|
ListProviders,
|
||||||
|
|
||||||
@@ -51,6 +64,53 @@ pub(crate) enum OauthCommand {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[admin_command]
|
||||||
|
pub(super) async fn oauth_associate(
|
||||||
|
&self,
|
||||||
|
provider: String,
|
||||||
|
user_id: OwnedUserId,
|
||||||
|
claim: Vec<String>,
|
||||||
|
) -> Result {
|
||||||
|
if !self.services.globals.user_is_local(&user_id) {
|
||||||
|
return Err!(Request(NotFound("User {user_id:?} does not belong to this server.")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.services.users.exists(&user_id).await {
|
||||||
|
return Err!(Request(NotFound("User {user_id:?} is not registered")));
|
||||||
|
}
|
||||||
|
|
||||||
|
let provider = self
|
||||||
|
.services
|
||||||
|
.oauth
|
||||||
|
.providers
|
||||||
|
.get(&provider)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let claim = claim
|
||||||
|
.iter()
|
||||||
|
.map(|kv| {
|
||||||
|
let (key, val) = kv
|
||||||
|
.split_once('=')
|
||||||
|
.ok_or_else(|| err!("Missing '=' in --claim {kv}=???"))?;
|
||||||
|
|
||||||
|
if !key.is_empty() && !val.is_empty() {
|
||||||
|
Ok((key, val))
|
||||||
|
} else {
|
||||||
|
Err!("Missing key or value in --claim=key=value argument")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.map_ok(apply!(2, ToOwned::to_owned))
|
||||||
|
.collect::<Result<_>>()?;
|
||||||
|
|
||||||
|
let _replaced = self
|
||||||
|
.services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.set_user_association_pending(provider.id(), &user_id, claim);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[admin_command]
|
#[admin_command]
|
||||||
pub(super) async fn oauth_list_providers(&self) -> Result {
|
pub(super) async fn oauth_list_providers(&self) -> Result {
|
||||||
self.services
|
self.services
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ use axum_client_ip::InsecureClientIp;
|
|||||||
use axum_extra::extract::cookie::{Cookie, SameSite};
|
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
|
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
|
||||||
use futures::{StreamExt, TryFutureExt, future::try_join};
|
use futures::{StreamExt, TryFutureExt, future::try_join};
|
||||||
use itertools::Itertools;
|
|
||||||
use reqwest::header::{CONTENT_TYPE, HeaderValue};
|
use reqwest::header::{CONTENT_TYPE, HeaderValue};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId,
|
Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId,
|
||||||
@@ -16,7 +15,9 @@ use tuwunel_core::{
|
|||||||
Err, Result, at,
|
Err, Result, at,
|
||||||
config::IdentityProvider,
|
config::IdentityProvider,
|
||||||
debug::INFO_SPAN_LEVEL,
|
debug::INFO_SPAN_LEVEL,
|
||||||
debug_info, debug_warn, err, info, utils,
|
debug_info, debug_warn, err, info,
|
||||||
|
itertools::Itertools,
|
||||||
|
utils,
|
||||||
utils::{
|
utils::{
|
||||||
content_disposition::make_content_disposition,
|
content_disposition::make_content_disposition,
|
||||||
hash::sha256,
|
hash::sha256,
|
||||||
@@ -30,8 +31,7 @@ use tuwunel_service::{
|
|||||||
Services,
|
Services,
|
||||||
media::MXC_LENGTH,
|
media::MXC_LENGTH,
|
||||||
oauth::{
|
oauth::{
|
||||||
CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id,
|
CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id_sub,
|
||||||
unique_id_sub,
|
|
||||||
},
|
},
|
||||||
users::Register,
|
users::Register,
|
||||||
};
|
};
|
||||||
@@ -341,13 +341,15 @@ pub(crate) async fn sso_callback_route(
|
|||||||
.request_userinfo((&provider, &session))
|
.request_userinfo((&provider, &session))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let unique_id = unique_id_sub((&provider, &userinfo.sub))?;
|
||||||
|
|
||||||
// Check for an existing session from this identity. We want to maintain one
|
// Check for an existing session from this identity. We want to maintain one
|
||||||
// session for each identity and keep the newer one which has up-to-date state
|
// session for each identity and keep the newer one which has up-to-date state
|
||||||
// and access.
|
// and access.
|
||||||
let (user_id, old_sess_id) = match services
|
let (user_id, old_sess_id) = match services
|
||||||
.oauth
|
.oauth
|
||||||
.sessions
|
.sessions
|
||||||
.get_by_unique_id(&unique_id_sub((&provider, &userinfo.sub))?)
|
.get_by_unique_id(&unique_id)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
| Ok(session) => (session.user_id, session.sess_id),
|
| Ok(session) => (session.user_id, session.sess_id),
|
||||||
@@ -364,7 +366,7 @@ pub(crate) async fn sso_callback_route(
|
|||||||
// Keep the user_id from the old session as best as possible.
|
// Keep the user_id from the old session as best as possible.
|
||||||
let user_id = match user_id {
|
let user_id = match user_id {
|
||||||
| Some(user_id) => user_id,
|
| Some(user_id) => user_id,
|
||||||
| None => decide_user_id(&services, &provider, &session, &userinfo).await?,
|
| None => decide_user_id(&services, &provider, &userinfo, &unique_id).await?,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Update the session with user_id
|
// Update the session with user_id
|
||||||
@@ -538,9 +540,24 @@ async fn set_avatar(
|
|||||||
async fn decide_user_id(
|
async fn decide_user_id(
|
||||||
services: &Services,
|
services: &Services,
|
||||||
provider: &Provider,
|
provider: &Provider,
|
||||||
session: &Session,
|
|
||||||
userinfo: &UserInfo,
|
userinfo: &UserInfo,
|
||||||
|
unique_id: &str,
|
||||||
) -> Result<OwnedUserId> {
|
) -> Result<OwnedUserId> {
|
||||||
|
if let Some(user_id) = services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.find_user_association_pending(provider.id(), userinfo)
|
||||||
|
{
|
||||||
|
debug_info!(
|
||||||
|
provider = ?provider.id(),
|
||||||
|
?user_id,
|
||||||
|
?userinfo,
|
||||||
|
"Matched pending association"
|
||||||
|
);
|
||||||
|
|
||||||
|
return Ok(user_id);
|
||||||
|
}
|
||||||
|
|
||||||
let allowed =
|
let allowed =
|
||||||
|claim: &str| provider.userid_claims.is_empty() || provider.userid_claims.contains(claim);
|
|claim: &str| provider.userid_claims.is_empty() || provider.userid_claims.contains(claim);
|
||||||
|
|
||||||
@@ -576,13 +593,10 @@ async fn decide_user_id(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(infallible) = unique_id((provider, session))
|
let length = Some(15..23);
|
||||||
.map(|h| truncate_deterministic(&h, Some(15..23)).to_lowercase())
|
let infallible = truncate_deterministic(unique_id, length).to_lowercase();
|
||||||
.log_err()
|
if let Some(user_id) = try_user_id(services, &infallible, true).await {
|
||||||
{
|
return Ok(user_id);
|
||||||
if let Some(user_id) = try_user_id(services, &infallible, true).await {
|
|
||||||
return Ok(user_id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Err!(Request(UserInUse("User ID is not available.")))
|
Err!(Request(UserInUse("User ID is not available.")))
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
use std::{sync::Arc, time::SystemTime};
|
pub mod association;
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
time::SystemTime,
|
||||||
|
};
|
||||||
|
|
||||||
use futures::{Stream, StreamExt, TryFutureExt};
|
use futures::{Stream, StreamExt, TryFutureExt};
|
||||||
use ruma::{OwnedUserId, UserId};
|
use ruma::{OwnedUserId, UserId};
|
||||||
@@ -12,6 +17,7 @@ use crate::SelfServices;
|
|||||||
|
|
||||||
pub struct Sessions {
|
pub struct Sessions {
|
||||||
_services: SelfServices,
|
_services: SelfServices,
|
||||||
|
association_pending: Mutex<association::Pending>,
|
||||||
providers: Arc<Providers>,
|
providers: Arc<Providers>,
|
||||||
db: Data,
|
db: Data,
|
||||||
}
|
}
|
||||||
@@ -94,6 +100,7 @@ pub const SESSION_ID_LENGTH: usize = 32;
|
|||||||
pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
|
pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
_services: args.services.clone(),
|
_services: args.services.clone(),
|
||||||
|
association_pending: Default::default(),
|
||||||
providers,
|
providers,
|
||||||
db: Data {
|
db: Data {
|
||||||
oauthid_session: args.db["oauthid_session"].clone(),
|
oauthid_session: args.db["oauthid_session"].clone(),
|
||||||
|
|||||||
98
src/service/oauth/sessions/association.rs
Normal file
98
src/service/oauth/sessions/association.rs
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use ruma::{OwnedUserId, UserId};
|
||||||
|
use serde_json::Value;
|
||||||
|
use tuwunel_core::{debug, implement, trace};
|
||||||
|
|
||||||
|
use super::{Sessions, UserInfo};
|
||||||
|
|
||||||
|
pub(super) type Pending = BTreeMap<String, Claimants>;
|
||||||
|
type Claimants = BTreeMap<OwnedUserId, Claims>;
|
||||||
|
pub type Claims = BTreeMap<String, String>;
|
||||||
|
|
||||||
|
#[implement(Sessions)]
|
||||||
|
pub fn set_user_association_pending(
|
||||||
|
&self,
|
||||||
|
idp_id: &str,
|
||||||
|
user_id: &UserId,
|
||||||
|
claims: Claims,
|
||||||
|
) -> Option<Claims> {
|
||||||
|
self.association_pending
|
||||||
|
.lock()
|
||||||
|
.expect("locked")
|
||||||
|
.entry(idp_id.into())
|
||||||
|
.or_default()
|
||||||
|
.insert(user_id.into(), claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Sessions)]
|
||||||
|
pub fn find_user_association_pending(
|
||||||
|
&self,
|
||||||
|
idp_id: &str,
|
||||||
|
userinfo: &UserInfo,
|
||||||
|
) -> Option<OwnedUserId> {
|
||||||
|
let claiming = serde_json::to_value(userinfo)
|
||||||
|
.expect("Failed to transform user_info into serde_json::Value");
|
||||||
|
|
||||||
|
let claiming = claiming
|
||||||
|
.as_object()
|
||||||
|
.expect("Failed to interpret user_info as object");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!claiming.is_empty(),
|
||||||
|
"Expecting at least one claim from user_info such as `sub`"
|
||||||
|
);
|
||||||
|
|
||||||
|
debug!(?idp_id, ?claiming, "finding pending association",);
|
||||||
|
self.association_pending
|
||||||
|
.lock()
|
||||||
|
.expect("locked")
|
||||||
|
.get(idp_id)
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(Claimants::iter)
|
||||||
|
.find_map(|(user_id, claimant)| {
|
||||||
|
trace!(?user_id, ?claimant, "checking against pending association");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!claimant.is_empty(),
|
||||||
|
"Must not match empty set of claims; should not exist in association_pending."
|
||||||
|
);
|
||||||
|
|
||||||
|
for (claim, value) in claimant {
|
||||||
|
if claiming.get(claim).and_then(Value::as_str) != Some(value) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(user_id.clone())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Sessions)]
|
||||||
|
pub fn remove_provider_associations_pending(&self, idp_id: &str) {
|
||||||
|
self.association_pending
|
||||||
|
.lock()
|
||||||
|
.expect("locked")
|
||||||
|
.remove(idp_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Sessions)]
|
||||||
|
pub fn remove_user_association_pending(&self, user_id: &UserId, idp_id: Option<&str>) {
|
||||||
|
self.association_pending
|
||||||
|
.lock()
|
||||||
|
.expect("locked")
|
||||||
|
.iter_mut()
|
||||||
|
.filter(|(provider, _)| idp_id == Some(provider))
|
||||||
|
.for_each(|(_, claiming)| {
|
||||||
|
claiming.remove(user_id);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Sessions)]
|
||||||
|
pub fn is_user_association_pending(&self, user_id: &UserId) -> bool {
|
||||||
|
self.association_pending
|
||||||
|
.lock()
|
||||||
|
.expect("locked")
|
||||||
|
.values()
|
||||||
|
.any(|claiming| claiming.contains_key(user_id))
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user