From fb102f0e0aa6b2e06a64d92dbf8cf5e687b7dec7 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 18 Jan 2026 12:49:14 +0000 Subject: [PATCH] Add privileged support for SSO account associations. (#252) Signed-off-by: Jason Volk --- src/admin/query/oauth.rs | 62 +++++++++++++- src/api/client/session/sso.rs | 42 ++++++---- src/service/oauth/sessions.rs | 9 ++- src/service/oauth/sessions/association.rs | 98 +++++++++++++++++++++++ 4 files changed, 195 insertions(+), 16 deletions(-) create mode 100644 src/service/oauth/sessions/association.rs diff --git a/src/admin/query/oauth.rs b/src/admin/query/oauth.rs index e3cecaa4..9f8fb09b 100644 --- a/src/admin/query/oauth.rs +++ b/src/admin/query/oauth.rs @@ -1,7 +1,7 @@ use clap::Subcommand; use futures::{StreamExt, TryStreamExt}; use ruma::OwnedUserId; -use tuwunel_core::{Err, Result, utils::stream::IterStream}; +use tuwunel_core::{Err, Result, apply, err, itertools::Itertools, utils::stream::IterStream}; use tuwunel_service::{ Services, oauth::{Provider, Session}, @@ -13,6 +13,19 @@ use crate::{admin_command, admin_command_dispatch}; #[derive(Debug, Subcommand)] /// Query OAuth service state pub(crate) enum OauthCommand { + /// Associate existing user with future authorization claims. + Associate { + /// ID of configured provider to listen on. + provider: String, + + /// MXID of local user to associate. + user_id: OwnedUserId, + + /// List of claims to match in key=value format. + #[arg(long, required = true)] + claim: Vec, + }, + /// List configured OAuth providers. ListProviders, @@ -51,6 +64,53 @@ pub(crate) enum OauthCommand { }, } +#[admin_command] +pub(super) async fn oauth_associate( + &self, + provider: String, + user_id: OwnedUserId, + claim: Vec, +) -> Result { + if !self.services.globals.user_is_local(&user_id) { + return Err!(Request(NotFound("User {user_id:?} does not belong to this server."))); + } + + if !self.services.users.exists(&user_id).await { + return Err!(Request(NotFound("User {user_id:?} is not registered"))); + } + + let provider = self + .services + .oauth + .providers + .get(&provider) + .await?; + + let claim = claim + .iter() + .map(|kv| { + let (key, val) = kv + .split_once('=') + .ok_or_else(|| err!("Missing '=' in --claim {kv}=???"))?; + + if !key.is_empty() && !val.is_empty() { + Ok((key, val)) + } else { + Err!("Missing key or value in --claim=key=value argument") + } + }) + .map_ok(apply!(2, ToOwned::to_owned)) + .collect::>()?; + + let _replaced = self + .services + .oauth + .sessions + .set_user_association_pending(provider.id(), &user_id, claim); + + Ok(()) +} + #[admin_command] pub(super) async fn oauth_list_providers(&self) -> Result { self.services diff --git a/src/api/client/session/sso.rs b/src/api/client/session/sso.rs index 769cd627..19d95e94 100644 --- a/src/api/client/session/sso.rs +++ b/src/api/client/session/sso.rs @@ -5,7 +5,6 @@ use axum_client_ip::InsecureClientIp; use axum_extra::extract::cookie::{Cookie, SameSite}; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64}; use futures::{StreamExt, TryFutureExt, future::try_join}; -use itertools::Itertools; use reqwest::header::{CONTENT_TYPE, HeaderValue}; use ruma::{ Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId, @@ -16,7 +15,9 @@ use tuwunel_core::{ Err, Result, at, config::IdentityProvider, debug::INFO_SPAN_LEVEL, - debug_info, debug_warn, err, info, utils, + debug_info, debug_warn, err, info, + itertools::Itertools, + utils, utils::{ content_disposition::make_content_disposition, hash::sha256, @@ -30,8 +31,7 @@ use tuwunel_service::{ Services, media::MXC_LENGTH, oauth::{ - CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id, - unique_id_sub, + CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id_sub, }, users::Register, }; @@ -341,13 +341,15 @@ pub(crate) async fn sso_callback_route( .request_userinfo((&provider, &session)) .await?; + let unique_id = unique_id_sub((&provider, &userinfo.sub))?; + // Check for an existing session from this identity. We want to maintain one // session for each identity and keep the newer one which has up-to-date state // and access. let (user_id, old_sess_id) = match services .oauth .sessions - .get_by_unique_id(&unique_id_sub((&provider, &userinfo.sub))?) + .get_by_unique_id(&unique_id) .await { | Ok(session) => (session.user_id, session.sess_id), @@ -364,7 +366,7 @@ pub(crate) async fn sso_callback_route( // Keep the user_id from the old session as best as possible. let user_id = match user_id { | Some(user_id) => user_id, - | None => decide_user_id(&services, &provider, &session, &userinfo).await?, + | None => decide_user_id(&services, &provider, &userinfo, &unique_id).await?, }; // Update the session with user_id @@ -538,9 +540,24 @@ async fn set_avatar( async fn decide_user_id( services: &Services, provider: &Provider, - session: &Session, userinfo: &UserInfo, + unique_id: &str, ) -> Result { + if let Some(user_id) = services + .oauth + .sessions + .find_user_association_pending(provider.id(), userinfo) + { + debug_info!( + provider = ?provider.id(), + ?user_id, + ?userinfo, + "Matched pending association" + ); + + return Ok(user_id); + } + let allowed = |claim: &str| provider.userid_claims.is_empty() || provider.userid_claims.contains(claim); @@ -576,13 +593,10 @@ async fn decide_user_id( } } - if let Ok(infallible) = unique_id((provider, session)) - .map(|h| truncate_deterministic(&h, Some(15..23)).to_lowercase()) - .log_err() - { - if let Some(user_id) = try_user_id(services, &infallible, true).await { - return Ok(user_id); - } + let length = Some(15..23); + let infallible = truncate_deterministic(unique_id, length).to_lowercase(); + if let Some(user_id) = try_user_id(services, &infallible, true).await { + return Ok(user_id); } Err!(Request(UserInUse("User ID is not available."))) diff --git a/src/service/oauth/sessions.rs b/src/service/oauth/sessions.rs index 947b4879..878feece 100644 --- a/src/service/oauth/sessions.rs +++ b/src/service/oauth/sessions.rs @@ -1,4 +1,9 @@ -use std::{sync::Arc, time::SystemTime}; +pub mod association; + +use std::{ + sync::{Arc, Mutex}, + time::SystemTime, +}; use futures::{Stream, StreamExt, TryFutureExt}; use ruma::{OwnedUserId, UserId}; @@ -12,6 +17,7 @@ use crate::SelfServices; pub struct Sessions { _services: SelfServices, + association_pending: Mutex, providers: Arc, db: Data, } @@ -94,6 +100,7 @@ pub const SESSION_ID_LENGTH: usize = 32; pub(super) fn build(args: &crate::Args<'_>, providers: Arc) -> Self { Self { _services: args.services.clone(), + association_pending: Default::default(), providers, db: Data { oauthid_session: args.db["oauthid_session"].clone(), diff --git a/src/service/oauth/sessions/association.rs b/src/service/oauth/sessions/association.rs new file mode 100644 index 00000000..64035a64 --- /dev/null +++ b/src/service/oauth/sessions/association.rs @@ -0,0 +1,98 @@ +use std::collections::BTreeMap; + +use ruma::{OwnedUserId, UserId}; +use serde_json::Value; +use tuwunel_core::{debug, implement, trace}; + +use super::{Sessions, UserInfo}; + +pub(super) type Pending = BTreeMap; +type Claimants = BTreeMap; +pub type Claims = BTreeMap; + +#[implement(Sessions)] +pub fn set_user_association_pending( + &self, + idp_id: &str, + user_id: &UserId, + claims: Claims, +) -> Option { + self.association_pending + .lock() + .expect("locked") + .entry(idp_id.into()) + .or_default() + .insert(user_id.into(), claims) +} + +#[implement(Sessions)] +pub fn find_user_association_pending( + &self, + idp_id: &str, + userinfo: &UserInfo, +) -> Option { + let claiming = serde_json::to_value(userinfo) + .expect("Failed to transform user_info into serde_json::Value"); + + let claiming = claiming + .as_object() + .expect("Failed to interpret user_info as object"); + + assert!( + !claiming.is_empty(), + "Expecting at least one claim from user_info such as `sub`" + ); + + debug!(?idp_id, ?claiming, "finding pending association",); + self.association_pending + .lock() + .expect("locked") + .get(idp_id) + .into_iter() + .flat_map(Claimants::iter) + .find_map(|(user_id, claimant)| { + trace!(?user_id, ?claimant, "checking against pending association"); + + assert!( + !claimant.is_empty(), + "Must not match empty set of claims; should not exist in association_pending." + ); + + for (claim, value) in claimant { + if claiming.get(claim).and_then(Value::as_str) != Some(value) { + return None; + } + } + + Some(user_id.clone()) + }) +} + +#[implement(Sessions)] +pub fn remove_provider_associations_pending(&self, idp_id: &str) { + self.association_pending + .lock() + .expect("locked") + .remove(idp_id); +} + +#[implement(Sessions)] +pub fn remove_user_association_pending(&self, user_id: &UserId, idp_id: Option<&str>) { + self.association_pending + .lock() + .expect("locked") + .iter_mut() + .filter(|(provider, _)| idp_id == Some(provider)) + .for_each(|(_, claiming)| { + claiming.remove(user_id); + }); +} + +#[implement(Sessions)] +pub fn is_user_association_pending(&self, user_id: &UserId) -> bool { + self.association_pending + .lock() + .expect("locked") + .values() + .any(|claiming| claiming.contains_key(user_id)) +}