From 11309062a2e447a187875965730bbffe7c933edb Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 23 Dec 2025 14:55:29 +0000 Subject: [PATCH] Implement SSO/OIDC support. (closes #7) Signed-off-by: Jason Volk --- Cargo.lock | 27 ++ Cargo.toml | 7 +- src/admin/query/mod.rs | 13 +- src/admin/query/oauth.rs | 172 ++++++++++ src/api/Cargo.toml | 1 + src/api/client/session/mod.rs | 24 +- src/api/client/session/sso.rs | 601 +++++++++++++++++++++++++++++++++ src/api/router.rs | 3 + src/api/router/args.rs | 5 + src/api/router/auth/uiaa.rs | 14 +- src/api/router/request.rs | 5 +- src/core/config/mod.rs | 145 +++++++- src/database/maps.rs | 12 + src/service/Cargo.toml | 1 + src/service/client/mod.rs | 6 + src/service/mod.rs | 1 + src/service/oauth/mod.rs | 265 +++++++++++++++ src/service/oauth/providers.rs | 246 ++++++++++++++ src/service/oauth/sessions.rs | 234 +++++++++++++ src/service/oauth/user_info.rs | 39 +++ src/service/services.rs | 5 +- src/service/users/mod.rs | 31 +- tuwunel-example.toml | 129 +++++++ 23 files changed, 1959 insertions(+), 27 deletions(-) create mode 100644 src/admin/query/oauth.rs create mode 100644 src/api/client/session/sso.rs create mode 100644 src/service/oauth/mod.rs create mode 100644 src/service/oauth/providers.rs create mode 100644 src/service/oauth/sessions.rs create mode 100644 src/service/oauth/user_info.rs diff --git a/Cargo.lock b/Cargo.lock index c7e5bf56..0919ae18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -383,6 +383,7 @@ dependencies = [ "axum", "axum-core", "bytes", + "cookie", "futures-util", "headers", "http", @@ -836,6 +837,17 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + [[package]] name = "coolor" version = "1.1.0" @@ -1291,6 +1303,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "enum-as-inner" version = "0.6.1" @@ -3582,6 +3603,7 @@ checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f" dependencies = [ "base64", "bytes", + "encoding_rs", "futures-channel", "futures-core", "futures-util", @@ -3595,6 +3617,7 @@ dependencies = [ "hyper-util", "js-sys", "log", + "mime", "once_cell", "percent-encoding", "pin-project-lite", @@ -5141,6 +5164,7 @@ dependencies = [ "tracing", "tuwunel_core", "tuwunel_service", + "url", ] [[package]] @@ -5185,6 +5209,8 @@ dependencies = [ "ruma", "sanitize-filename", "serde", + "serde_core", + "serde_html_form", "serde_json", "serde_regex", "serde_yaml", @@ -5295,6 +5321,7 @@ dependencies = [ "rustls", "rustyline-async", "serde", + "serde_html_form", "serde_json", "serde_yaml", "sha2", diff --git a/Cargo.toml b/Cargo.toml index b513f295..999f019e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,6 +76,7 @@ version = "0.7" version = "0.10" default-features = false features = [ + "cookie", "typed-header", "tracing", ] @@ -312,10 +313,12 @@ version = "1.12" version = "0.12" default-features = false features = [ - "rustls-tls-native-roots", - "socks", + "charset", "hickory-dns", "http2", + "json", + "rustls-tls-native-roots", + "socks", ] [workspace.dependencies.ring] diff --git a/src/admin/query/mod.rs b/src/admin/query/mod.rs index 00a3b993..0ac379f5 100644 --- a/src/admin/query/mod.rs +++ b/src/admin/query/mod.rs @@ -1,6 +1,7 @@ mod account_data; mod appservice; mod globals; +mod oauth; mod presence; mod pusher; mod raw; @@ -18,10 +19,10 @@ use tuwunel_core::Result; use self::{ account_data::AccountDataCommand, appservice::AppserviceCommand, globals::GlobalsCommand, - presence::PresenceCommand, pusher::PusherCommand, raw::RawCommand, resolver::ResolverCommand, - room_alias::RoomAliasCommand, room_state_cache::RoomStateCacheCommand, - room_timeline::RoomTimelineCommand, sending::SendingCommand, short::ShortCommand, - sync::SyncCommand, users::UsersCommand, + oauth::OauthCommand, presence::PresenceCommand, pusher::PusherCommand, raw::RawCommand, + resolver::ResolverCommand, room_alias::RoomAliasCommand, + room_state_cache::RoomStateCacheCommand, room_timeline::RoomTimelineCommand, + sending::SendingCommand, short::ShortCommand, sync::SyncCommand, users::UsersCommand, }; use crate::admin_command_dispatch; @@ -81,6 +82,10 @@ pub(super) enum QueryCommand { #[command(subcommand)] Sync(SyncCommand), + /// - oauth service + #[command(subcommand)] + Oauth(OauthCommand), + /// - raw service #[command(subcommand)] Raw(RawCommand), diff --git a/src/admin/query/oauth.rs b/src/admin/query/oauth.rs new file mode 100644 index 00000000..e3cecaa4 --- /dev/null +++ b/src/admin/query/oauth.rs @@ -0,0 +1,172 @@ +use clap::Subcommand; +use futures::{StreamExt, TryStreamExt}; +use ruma::OwnedUserId; +use tuwunel_core::{Err, Result, utils::stream::IterStream}; +use tuwunel_service::{ + Services, + oauth::{Provider, Session}, +}; + +use crate::{admin_command, admin_command_dispatch}; + +#[admin_command_dispatch(handler_prefix = "oauth")] +#[derive(Debug, Subcommand)] +/// Query OAuth service state +pub(crate) enum OauthCommand { + /// List configured OAuth providers. + ListProviders, + + /// List users associated with an OAuth provider + ListUsers, + + /// Show active configuration of a provider. + ShowProvider { + id: String, + + #[arg(long)] + config: bool, + }, + + /// Show session state + ShowSession { + id: String, + }, + + /// Token introspection request to provider. + TokenInfo { + id: String, + }, + + /// Revoke token for user_id or sess_id. + Revoke { + id: String, + }, + + /// Remove oauth state (DANGER!) + Remove { + id: String, + + #[arg(long)] + force: bool, + }, +} + +#[admin_command] +pub(super) async fn oauth_list_providers(&self) -> Result { + self.services + .config + .identity_provider + .iter() + .try_stream() + .map_ok(Provider::id) + .map_ok(|id| format!("{id}\n")) + .try_for_each(async |id| self.write_str(&id).await) + .await +} + +#[admin_command] +pub(super) async fn oauth_list_users(&self) -> Result { + self.services + .oauth + .sessions + .users() + .map(|id| format!("{id}\n")) + .map(Ok) + .try_for_each(async |id: String| self.write_str(&id).await) + .await +} + +#[admin_command] +pub(super) async fn oauth_show_provider(&self, id: String, config: bool) -> Result { + if config { + let config = self.services.oauth.providers.get_config(&id)?; + + self.write_str(&format!("{config:#?}\n")).await?; + return Ok(()); + } + + let provider = self.services.oauth.providers.get(&id).await?; + + self.write_str(&format!("{provider:#?}\n")).await +} + +#[admin_command] +pub(super) async fn oauth_show_session(&self, id: String) -> Result { + let session = find_session(self.services, &id).await?; + + self.write_str(&format!("{session:#?}\n")).await +} + +#[admin_command] +pub(super) async fn oauth_token_info(&self, id: String) -> Result { + let session = find_session(self.services, &id).await?; + + let provider = self + .services + .oauth + .sessions + .provider(&session) + .await?; + + let tokeninfo = self + .services + .oauth + .request_tokeninfo((&provider, &session)) + .await; + + self.write_str(&format!("{tokeninfo:#?}\n")).await +} + +#[admin_command] +pub(super) async fn oauth_revoke(&self, id: String) -> Result { + let session = find_session(self.services, &id).await?; + + let provider = self + .services + .oauth + .sessions + .provider(&session) + .await?; + + self.services + .oauth + .revoke_token((&provider, &session)) + .await?; + + self.write_str("done").await +} + +#[admin_command] +pub(super) async fn oauth_remove(&self, id: String, force: bool) -> Result { + let session = find_session(self.services, &id).await?; + + let Some(sess_id) = session.sess_id else { + return Err!("Missing sess_id in oauth Session state"); + }; + + if !force { + return Err!( + "Deleting these records can cause registration conflicts. Use --force to be sure." + ); + } + + self.services + .oauth + .sessions + .delete(&sess_id) + .await; + + self.write_str("done").await +} + +async fn find_session(services: &Services, id: &str) -> Result { + if let Ok(user_id) = OwnedUserId::parse(id) { + services + .oauth + .sessions + .get_by_user(&user_id) + .await + } else { + services.oauth.sessions.get(id).await + } +} diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index b47850fe..4914d8c3 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -101,6 +101,7 @@ tokio.workspace = true tracing.workspace = true tuwunel-core.workspace = true tuwunel-service.workspace = true +url.workspace = true [lints] workspace = true diff --git a/src/api/client/session/mod.rs b/src/api/client/session/mod.rs index 9bf0aad2..c3609dcc 100644 --- a/src/api/client/session/mod.rs +++ b/src/api/client/session/mod.rs @@ -4,6 +4,7 @@ mod ldap; mod logout; mod password; mod refresh; +mod sso; mod token; use axum::extract::State; @@ -12,8 +13,8 @@ use ruma::api::client::session::{ get_login_types::{ self, v3::{ - ApplicationServiceLoginType, JwtLoginType, LoginType, PasswordLoginType, - TokenLoginType, + ApplicationServiceLoginType, IdentityProvider, JwtLoginType, LoginType, + PasswordLoginType, SsoLoginType, TokenLoginType, }, }, login::{ @@ -28,6 +29,7 @@ use self::{ldap::ldap_login, password::password_login}; pub(crate) use self::{ logout::{logout_all_route, logout_route}, refresh::refresh_token_route, + sso::{sso_callback_route, sso_login_route, sso_login_with_provider_route}, token::login_token_route, }; use super::TOKEN_LENGTH; @@ -50,6 +52,20 @@ pub(crate) async fn get_login_types_route( LoginType::Token(TokenLoginType { get_login_token: services.config.login_via_existing_session, }), + LoginType::Sso(SsoLoginType { + identity_providers: services + .config + .identity_provider + .iter() + .cloned() + .map(|config| IdentityProvider { + id: config.id().to_owned(), + brand: Some(config.brand.clone().into()), + icon: config.icon, + name: config.name.unwrap_or(config.brand), + }) + .collect(), + }), ])) } @@ -89,6 +105,10 @@ pub(crate) async fn login_route( }, }; + if !services.users.is_active_local(&user_id).await { + return Err!(Request(UserDeactivated("This user has been deactivated."))); + } + // Generate a new token for the device let (access_token, expires_in) = services .users diff --git a/src/api/client/session/sso.rs b/src/api/client/session/sso.rs new file mode 100644 index 00000000..679f3775 --- /dev/null +++ b/src/api/client/session/sso.rs @@ -0,0 +1,601 @@ +use std::time::Duration; + +use axum::extract::State; +use axum_client_ip::InsecureClientIp; +use axum_extra::extract::cookie::{Cookie, SameSite}; +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64}; +use futures::{StreamExt, TryFutureExt, future::try_join}; +use itertools::Itertools; +use reqwest::header::{CONTENT_TYPE, HeaderValue}; +use ruma::{ + Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId, + api::client::session::{sso_callback, sso_login, sso_login_with_provider}, +}; +use serde::{Deserialize, Serialize}; +use tuwunel_core::{ + Err, Result, at, + debug::INFO_SPAN_LEVEL, + debug_info, debug_warn, err, info, utils, + utils::{ + content_disposition::make_content_disposition, + hash::sha256, + result::{FlatOk, LogErr}, + string::{EMPTY, truncate_deterministic}, + timepoint_from_now, timepoint_has_passed, + }, + warn, +}; +use tuwunel_service::{ + Services, + media::MXC_LENGTH, + oauth::{ + CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id, + unique_id_sub, + }, + users::Register, +}; +use url::Url; + +use super::TOKEN_LENGTH; +use crate::Ruma; + +/// Grant phase query string. +#[derive(Debug, Serialize)] +struct GrantQuery<'a> { + client_id: &'a str, + state: &'a str, + nonce: &'a str, + scope: &'a str, + response_type: &'a str, + access_type: &'a str, + code_challenge_method: &'a str, + code_challenge: &'a str, + redirect_uri: Option<&'a str>, +} + +#[derive(Debug, Deserialize, Serialize)] +struct GrantCookie<'a> { + client_id: &'a str, + state: &'a str, + nonce: &'a str, + redirect_uri: &'a str, +} + +static GRANT_SESSION_COOKIE: &str = "tuwunel_grant_session"; + +#[tracing::instrument( + name = "sso_login", + level = "debug", + skip_all, + fields(%client), +)] +pub(crate) async fn sso_login_route( + State(_services): State, + InsecureClientIp(client): InsecureClientIp, + _body: Ruma, +) -> Result { + Err!(Request(NotImplemented( + "SSO login without specific provider has not been implemented." + ))) +} + +#[tracing::instrument( + name = "sso_login_with_provider", + level = "info", + skip_all, + ret(level = "debug") + fields( + %client, + idp_id = body.body.idp_id, + ), +)] +pub(crate) async fn sso_login_with_provider_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let sso_login_with_provider::v3::Request { idp_id, redirect_url } = body.body; + let Ok(redirect_url) = redirect_url.parse::() else { + return Err!(Request(InvalidParam("Invalid redirect_url"))); + }; + + let provider = services.oauth.providers.get(&idp_id).await?; + let sess_id = utils::random_string(SESSION_ID_LENGTH); + let query_nonce = utils::random_string(CODE_VERIFIER_LENGTH); + let cookie_nonce = utils::random_string(CODE_VERIFIER_LENGTH); + let code_verifier = utils::random_string(CODE_VERIFIER_LENGTH); + let code_challenge = b64.encode(sha256::hash(code_verifier.as_bytes())); + let callback_uri = provider.callback_url.as_ref().map(Url::as_str); + let scope = provider.scope.iter().join(" "); + + let query = GrantQuery { + client_id: &provider.client_id, + state: &sess_id, + nonce: &query_nonce, + access_type: "online", + response_type: "code", + code_challenge_method: "S256", + code_challenge: &code_challenge, + redirect_uri: callback_uri, + scope: scope + .is_empty() + .then_some("openid email profile") + .unwrap_or(scope.as_str()), + }; + + let location = provider + .authorization_url + .clone() + .map(|mut location| { + let query = serde_html_form::to_string(&query).ok(); + location.set_query(query.as_deref()); + location + }) + .ok_or_else(|| { + err!(Config("authorization_url", "Missing required IdentityProvider config")) + })?; + + let cookie_val = GrantCookie { + client_id: query.client_id, + state: query.state, + nonce: &cookie_nonce, + redirect_uri: redirect_url.as_str(), + }; + + let cookie_path = provider + .callback_url + .as_ref() + .map(Url::path) + .unwrap_or("/"); + + let cookie_max_age = provider + .grant_session_duration + .map(Duration::from_secs) + .expect("Defaulted to Some value during configure_idp()") + .try_into() + .expect("std::time::Duration to time::Duration conversion failure"); + + let cookie = Cookie::build((GRANT_SESSION_COOKIE, serde_html_form::to_string(&cookie_val)?)) + .path(cookie_path) + .max_age(cookie_max_age) + .same_site(SameSite::None) + .secure(true) + .http_only(true) + .build() + .to_string() + .into(); + + let session = Session { + idp_id: Some(idp_id), + sess_id: Some(sess_id.clone()), + redirect_url: Some(redirect_url), + code_verifier: Some(code_verifier), + query_nonce: Some(query_nonce), + cookie_nonce: Some(cookie_nonce), + authorize_expires_at: provider + .grant_session_duration + .map(Duration::from_secs) + .map(timepoint_from_now) + .transpose()?, + + ..Default::default() + }; + + services + .oauth + .sessions + .put(&sess_id, &session) + .await; + + Ok(sso_login_with_provider::v3::Response { + location: location.into(), + cookie: Some(cookie), + }) +} + +#[tracing::instrument( + name = "sso_callback" + level = "debug", + skip_all, + fields( + %client, + cookie = ?body.cookie, + body = ?body.body, + ), +)] +pub(crate) async fn sso_callback_route( + State(services): State, + InsecureClientIp(client): InsecureClientIp, + body: Ruma, +) -> Result { + let sess_id = body + .body + .state + .as_deref() + .ok_or_else(|| err!(Request(Forbidden("Missing sess_id in callback."))))?; + + let code = body + .body + .code + .as_deref() + .ok_or_else(|| err!(Request(Forbidden("Missing code in callback."))))?; + + let session = services + .oauth + .sessions + .get(sess_id) + .map_err(|_| err!(Request(Forbidden("Invalid state in callback")))); + + let idp_id = body.body.idp_id.as_str(); + let provider = services.oauth.providers.get(idp_id); + let (provider, session) = try_join(provider, session).await.log_err()?; + let client_id = &provider.client_id; + + if session.idp_id.as_deref() != Some(idp_id) { + return Err!(Request(Unauthorized("Identity Provider {idp_id} session not recognized."))); + } + + if session + .authorize_expires_at + .map(timepoint_has_passed) + .unwrap_or(false) + { + return Err!(Request(Unauthorized("Authorization grant session has expired."))); + } + + let cookie = body + .cookie + .get(GRANT_SESSION_COOKIE) + .map(Cookie::value) + .map(serde_html_form::from_str::>) + .transpose()? + .ok_or_else(|| err!(Request(Unauthorized("Missing cookie {GRANT_SESSION_COOKIE:?}"))))?; + + if cookie.client_id != client_id { + return Err!(Request(Unauthorized("Client ID {client_id:?} cookie mismatch."))); + } + + if Some(cookie.nonce) != session.cookie_nonce.as_deref() { + return Err!(Request(Unauthorized("Cookie nonce does not match session state."))); + } + + if cookie.state != sess_id { + return Err!(Request(Unauthorized("Session ID {sess_id:?} cookie mismatch."))); + } + + // Request access token. + let token_response = services + .oauth + .request_token((&provider, &session), code) + .await?; + + let token_expires_at = token_response + .expires_in + .map(Duration::from_secs) + .map(timepoint_from_now) + .transpose()?; + + let refresh_token_expires_at = token_response + .refresh_token_expires_in + .map(Duration::from_secs) + .map(timepoint_from_now) + .transpose()?; + + // Update the session with access token results + let session = Session { + scope: token_response.scope, + token_type: token_response.token_type, + access_token: token_response.access_token, + expires_at: token_expires_at, + refresh_token: token_response.refresh_token, + refresh_token_expires_at, + ..session + }; + + // Request userinfo claims. + let userinfo = services + .oauth + .request_userinfo((&provider, &session)) + .await?; + + // Check for an existing session from this identity. We want to maintain one + // session for each identity and keep the newer one which has up-to-date state + // and access. + let (user_id, old_sess_id) = match services + .oauth + .sessions + .get_by_unique_id(&unique_id_sub((&provider, &userinfo.sub))?) + .await + { + | Ok(session) => (session.user_id, session.sess_id), + | Err(error) if !error.is_not_found() => return Err(error), + | Err(_) => (None, None), + }; + + // Update the session with userinfo + let session = Session { + user_info: Some(userinfo.clone()), + ..session + }; + + // Keep the user_id from the old session as best as possible. + let user_id = match user_id { + | Some(user_id) => user_id, + | None => decide_user_id(&services, &provider, &session, &userinfo).await?, + }; + + // Update the session with user_id + let session = Session { + user_id: Some(user_id.clone()), + ..session + }; + + // Attempt to register a non-existing user. + if !services.users.exists(&user_id).await { + register_user(&services, &provider, &session, &userinfo, &user_id).await?; + } + + // Commit the updated session. + services + .oauth + .sessions + .put(sess_id, &session) + .await; + + // Delete any old session. + if let Some(old_sess_id) = old_sess_id + && sess_id != old_sess_id + { + services.oauth.sessions.delete(&old_sess_id).await; + } + + if !services.users.is_active_local(&user_id).await { + return Err!(Request(UserDeactivated("This user has been deactivated."))); + } + + // Allow the user to login to Matrix. + let login_token = utils::random_string(TOKEN_LENGTH); + let _login_token_expires_in = services + .users + .create_login_token(&user_id, &login_token); + + let location = session + .redirect_url + .as_ref() + .ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))? + .clone() + .query_pairs_mut() + .append_pair("loginToken", &login_token) + .finish() + .to_string(); + + let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY)) + .removal() + .build() + .to_string() + .into(); + + Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) }) +} + +#[tracing::instrument( + name = "register", + level = INFO_SPAN_LEVEL, + skip_all, + fields(user_id, userinfo) +)] +async fn register_user( + services: &Services, + provider: &Provider, + session: &Session, + userinfo: &UserInfo, + user_id: &UserId, +) -> Result { + debug_info!(%user_id, "Creating new user account..."); + + services + .users + .full_register(Register { + user_id: Some(user_id), + password: Some("*"), + origin: Some("sso"), + displayname: userinfo.name.as_deref(), + ..Default::default() + }) + .await?; + + if let Some(avatar_url) = userinfo + .avatar_url + .as_deref() + .or(userinfo.picture.as_deref()) + { + set_avatar(services, provider, session, userinfo, user_id, avatar_url) + .await + .ok(); + } + + let idp_id = provider.id(); + let idp_name = provider + .name + .as_deref() + .unwrap_or(provider.brand.as_str()); + + // log in conduit admin channel if a non-guest user registered + let notice = + format!("New user \"{user_id}\" registered on this server via {idp_name} ({idp_id})",); + + info!("{notice}"); + if services.server.config.admin_room_notices { + services.admin.notice(¬ice).await; + } + + Ok(()) +} + +#[tracing::instrument(level = "debug", skip_all, fields(user_id, avatar_url))] +async fn set_avatar( + services: &Services, + _provider: &Provider, + _session: &Session, + _userinfo: &UserInfo, + user_id: &UserId, + avatar_url: &str, +) -> Result { + use reqwest::Response; + + let response = services + .client + .default + .get(avatar_url) + .send() + .await + .and_then(Response::error_for_status)?; + + let content_type = response + .headers() + .get(CONTENT_TYPE) + .map(HeaderValue::to_str) + .flat_ok() + .map(ToOwned::to_owned); + + let mxc = Mxc { + server_name: services.globals.server_name(), + media_id: &utils::random_string(MXC_LENGTH), + }; + + let content_disposition = make_content_disposition(None, content_type.as_deref(), None); + let bytes = response.bytes().await?; + services + .media + .create(&mxc, Some(user_id), Some(&content_disposition), content_type.as_deref(), &bytes) + .await?; + + let all_joined_rooms: Vec = services + .state_cache + .rooms_joined(user_id) + .map(ToOwned::to_owned) + .collect() + .await; + + let mxc_uri = mxc.to_string().into(); + services + .users + .update_avatar_url(user_id, Some(mxc_uri), None, &all_joined_rooms) + .await; + + Ok(()) +} + +#[tracing::instrument( + level = "debug", + ret(level = "debug") + skip_all, + fields(user), +)] +async fn decide_user_id( + services: &Services, + provider: &Provider, + session: &Session, + userinfo: &UserInfo, +) -> Result { + let allowed = + |claim: &str| provider.userid_claims.is_empty() || provider.userid_claims.contains(claim); + + let choices = [ + userinfo + .preferred_username + .as_deref() + .map(str::to_lowercase) + .filter(|_| allowed("preferred_username")), + userinfo + .nickname + .as_deref() + .map(str::to_lowercase) + .filter(|_| allowed("nickname")), + provider + .brand + .eq(&"github") + .then_some(userinfo.sub.as_str()) + .map(str::to_lowercase) + .filter(|_| allowed("login")), + userinfo + .email + .as_deref() + .and_then(|email| email.split_once('@')) + .map(at!(0)) + .map(str::to_lowercase) + .filter(|_| allowed("email")), + ]; + + for choice in choices.into_iter().flatten() { + if let Some(user_id) = try_user_id(services, &choice, false).await { + return Ok(user_id); + } + } + + if let Ok(infallible) = unique_id((provider, session)) + .map(|h| truncate_deterministic(&h, Some(15..23)).to_lowercase()) + .log_err() + { + if let Some(user_id) = try_user_id(services, &infallible, true).await { + return Ok(user_id); + } + } + + Err!(Request(UserInUse("User ID is not available."))) +} + +#[tracing::instrument(level = "debug", skip_all, fields(username))] +async fn try_user_id( + services: &Services, + username: &str, + may_exist: bool, +) -> Option { + let server_name = services.globals.server_name(); + let user_id = parse_user_id(server_name, username) + .inspect_err(|e| warn!(?username, "Username invalid: {e}")) + .ok()?; + + if services + .config + .forbidden_usernames + .is_match(username) + { + warn!(?username, "Username forbidden."); + return None; + } + + if services.users.exists(&user_id).await { + debug_warn!(?username, "Username exists."); + + if services + .users + .origin(&user_id) + .await + .ok() + .is_none_or(|origin| origin != "sso") + { + debug_warn!(?username, "Username has non-sso origin."); + return None; + } + + if !may_exist { + return None; + } + } + + Some(user_id) +} + +fn parse_user_id(server_name: &ServerName, username: &str) -> Result { + match UserId::parse_with_server_name(username, server_name) { + | Err(e) => + Err!(Request(InvalidUsername(debug_error!("Username {username} is not valid: {e}")))), + | Ok(user_id) => match user_id.validate_strict() { + | Ok(()) => Ok(user_id), + | Err(e) => Err!(Request(InvalidUsername(debug_error!( + "Username {username} contains disallowed characters or spaces: {e}" + )))), + }, + } +} diff --git a/src/api/router.rs b/src/api/router.rs index c3d11bc4..48072f03 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -38,6 +38,9 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(&client::login_route) .ruma_route(&client::login_token_route) .ruma_route(&client::refresh_token_route) + .ruma_route(&client::sso_login_route) + .ruma_route(&client::sso_login_with_provider_route) + .ruma_route(&client::sso_callback_route) .ruma_route(&client::whoami_route) .ruma_route(&client::logout_route) .ruma_route(&client::logout_all_route) diff --git a/src/api/router/args.rs b/src/api/router/args.rs index c5b0c670..f5f7b052 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -1,6 +1,7 @@ use std::{fmt::Debug, mem, ops::Deref}; use axum::{body::Body, extract::FromRequest}; +use axum_extra::extract::cookie::CookieJar; use bytes::{BufMut, Bytes, BytesMut}; use ruma::{ CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName, @@ -18,6 +19,9 @@ pub(crate) struct Args { /// Request struct body pub(crate) body: T, + /// Cookies received from the useragent. + pub(crate) cookie: CookieJar, + /// Federation server authentication: X-Matrix origin /// None when not a federation server. pub(crate) origin: Option, @@ -110,6 +114,7 @@ where let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?; Ok(Self { body: make_body::(services, &mut request, json_body.as_mut(), &auth)?, + cookie: request.cookie, origin: auth.origin, sender_user: auth.sender_user, sender_device: auth.sender_device, diff --git a/src/api/router/auth/uiaa.rs b/src/api/router/auth/uiaa.rs index 4410a7f5..4f7306fa 100644 --- a/src/api/router/auth/uiaa.rs +++ b/src/api/router/auth/uiaa.rs @@ -5,7 +5,7 @@ use ruma::{ client::uiaa::{AuthData, AuthFlow, AuthType, Jwt, UiaaInfo}, }, }; -use tuwunel_core::{Err, Error, Result, err, utils}; +use tuwunel_core::{Err, Error, Result, err, is_equal_to, utils}; use tuwunel_service::{Services, uiaa::SESSION_ID_LENGTH}; use crate::{Ruma, client::jwt}; @@ -67,9 +67,19 @@ where | Some(ref json) => { let sender_user = body .sender_user - .as_ref() + .as_deref() .ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?; + // Skip UIAA for SSO/OIDC users. + if services + .users + .origin(sender_user) + .await + .is_ok_and(is_equal_to!("sso")) + { + return Ok(sender_user.to_owned()); + } + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); services .uiaa diff --git a/src/api/router/request.rs b/src/api/router/request.rs index 1e5b783b..bb64955b 100644 --- a/src/api/router/request.rs +++ b/src/api/router/request.rs @@ -1,6 +1,7 @@ use std::str; use axum::{RequestExt, RequestPartsExt, extract::Path}; +use axum_extra::extract::cookie::CookieJar; use bytes::Bytes; use http::request::Parts; use serde::Deserialize; @@ -17,6 +18,7 @@ pub(super) type UserId = SmallString<[u8; 48]>; #[derive(Debug)] pub(super) struct Request { + pub(super) cookie: CookieJar, pub(super) path: Path, pub(super) query: QueryParams, pub(super) body: Bytes, @@ -33,6 +35,7 @@ pub(super) async fn from( let limited = request.with_limited_body(); let (mut parts, body) = limited.into_parts(); + let cookie: CookieJar = parts.extract().await?; let path: Path = parts.extract().await?; let query = parts.uri.query().unwrap_or_default(); let query = serde_html_form::from_str(query) @@ -44,5 +47,5 @@ pub(super) async fn from( .await .map_err(|e| err!(Request(TooLarge("Request body too large: {e}"))))?; - Ok(Request { path, query, body, parts }) + Ok(Request { cookie, path, query, body, parts }) } diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 36a7ab6e..2c107a5e 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -3,7 +3,8 @@ pub mod manager; pub mod proxy; use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashSet}, + hash::{Hash, Hasher}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::{Path, PathBuf}, }; @@ -16,7 +17,7 @@ use figment::providers::{Env, Format, Toml}; pub use figment::{Figment, value::Value as FigmentValue}; use regex::RegexSet; use ruma::{ - OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId, + OwnedMxcUri, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId, api::client::discovery::discover_support::ContactRole, }; use serde::{Deserialize, de::IgnoredAny}; @@ -28,6 +29,7 @@ pub use self::{check::check, manager::Manager}; use crate::{ Result, err, error::Error, + utils, utils::{string::EMPTY, sys}, }; @@ -57,7 +59,7 @@ use crate::{ ### https://tuwunel.chat/configuration.html "#, ignore = "catchall well_known tls blurhashing allow_invalid_tls_certificates ldap jwt \ - appservice" + appservice identity_provider" )] pub struct Config { /// The server_name is the pretty name of this server. It is used as a @@ -2153,6 +2155,13 @@ pub struct Config { #[serde(default)] pub appservice: BTreeMap, + // external structure; separate sections + #[serde(default)] + pub identity_provider: HashSet, + + #[serde(default)] + pub sso_aware_preferred: bool, + #[serde(flatten)] #[allow(clippy::zero_sized_map_values)] // this is a catchall, the map shouldn't be zero at runtime @@ -2468,6 +2477,134 @@ pub struct JwtConfig { pub validate_signature: bool, } +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[config_example_generator( + filename = "tuwunel-example.toml", + section = "[global.identity_provider]" +)] +pub struct IdentityProvider { + /// The brand-name of the service (e.g. Apple, Facebook, GitHub, GitLab, + /// Google) or the software (e.g. keycloak, MAS) providing the identity. + /// When a brand is recognized we apply certain defaults to this config + /// for your convenience. For certain brands we apply essential internal + /// workarounds specific to that provider; it is important to configure this + /// field properly when a provider needs to be recognized (like GitHub for + /// example). Several configured providers can share the same brand name. It + /// is not case-sensitive. + #[serde(deserialize_with = "utils::string::de::to_lowercase")] + pub brand: String, + + /// The ID of your OAuth application which the provider generates upon + /// registration. This ID then uniquely identifies this configuration + /// instance itself, becoming the identity provider's ID and must be unique + /// and remain unchanged. + pub client_id: String, + + /// Secret key the provider generated for you along with the `client_id` + /// above. Unlike the `client_id`, the `client_secret` can be changed here + /// whenever the provider regenerates one for you. + pub client_secret: String, + + /// The callback URL configured when registering the OAuth application with + /// the provider. Tuwunel's callback URL must be strictly formatted exactly + /// as instructed. The URL host must point directly at the matrix server and + /// use the following path: + /// `/_matrix/client/unstable/login/sso/callback/` where + /// `` is the same one configured for this provider above. + pub callback_url: Option, + + /// Optional display-name for this provider instance seen on the login page + /// by users. It defaults to `brand`. When configuring multiple providers + /// using the same `brand` this can be set to distinguish them. + pub name: Option, + + /// Optional icon for the provider. The canonical providers have a default + /// icon based on the `brand` supplied above when this is not supplied. Note + /// that it uses an MXC url which is curious in the auth-media era and may + /// not be reliable. + pub icon: Option, + + /// Optional list of scopes to authorize. An empty array does not impose any + /// restrictions from here, effectively defaulting to all scopes you + /// configured for the OAuth application at the provider. This setting + /// allows for restricting to a subset of those scopes for this instance. + /// Note the user can further restrict scopes during their authorization. + /// + /// default: [] + #[serde(default)] + pub scope: BTreeSet, + + /// List of userinfo claims which shape and restrict the way we compute a + /// Matrix UserId for new registrations. Reviewing Tuwunel's documentation + /// will be necessary for a complete description in detail. An empty array + /// imposes no restriction here, avoiding generated fallbacks as much as + /// possible. For simplicity we reserve a claim called "unique" which can be + /// listed alone to ensure *only* generated ID's are used for registrations. + /// + /// default: [] + #[serde(default)] + pub userid_claims: BTreeSet, + + /// Issuer URL the provider publishes for you. We have pre-supplied default + /// values for some of the canonical providers, making this field optional + /// based on the `brand` set above. Otherwise it is required for OIDC + /// discovery to acquire additional provider configuration, and it must be + /// correct to pass validations during various interactions. + pub issuer_url: Option, + + /// Extra path components after the issuer_url leading to the location of + /// the `.well-known` directory used for discovery. This will be empty for + /// specification-compliant providers. We have supplied any known values + /// based on `brand` (e.g. `/login/oauth` for GitHub). + pub base_path: Option, + + /// Overrides the `.well-known` location where the provider's OIDC + /// configuration is found. It is very unlikely you will need to set this; + /// available for developers or special purposes only. + pub discovery_url: Option, + + /// Overrides the authorize URL requested during the grant phase. This is + /// generally discovered or derived automatically, but may be required as a + /// workaround for any non-standard or undiscoverable provider. + pub authorization_url: Option, + + /// Overrides the access token URL; the same caveats apply as with the other + /// URL overrides. + pub token_url: Option, + + /// Overrides the revocation URL; the same caveats apply as with the other + /// URL overrides. + pub revocation_url: Option, + + /// Overrides the introspection URL; the same caveats apply as with the + /// other URL overrides. + pub introspection_url: Option, + + /// Overrides the userinfo URL; the same caveats apply as with the other URL + /// overrides. + pub userinfo_url: Option, + + /// Whether to perform discovery and adjust this provider's configuration + /// accordingly. This defaults to true. When true, it is an error when + /// discovery fails and authorizations will not be attempted to the + /// provider. + #[serde(default = "true_fn")] + pub discovery: bool, + + /// The duration in seconds before a grant authorization session expires. + #[serde(default = "default_sso_grant_session_duration")] + pub grant_session_duration: Option, +} + +impl IdentityProvider { + #[must_use] + pub fn id(&self) -> &str { self.client_id.as_str() } +} + +impl Hash for IdentityProvider { + fn hash(&self, state: &mut H) { self.id().hash(state) } +} + #[derive(Clone, Debug, Default, Deserialize)] #[config_example_generator( filename = "tuwunel-example.toml", @@ -3016,3 +3153,5 @@ fn default_one_time_key_limit() -> usize { 256 } fn default_max_make_join_attempts_per_join_attempt() -> usize { 48 } fn default_max_join_attempts_per_join_request() -> usize { 3 } + +fn default_sso_grant_session_duration() -> Option { Some(180) } diff --git a/src/database/maps.rs b/src/database/maps.rs index f8bee69d..abf2fdc0 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -113,6 +113,14 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "mediaid_user", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "oauthid_session", + ..descriptor::RANDOM_SMALL + }, + Descriptor { + name: "oauthuniqid_oauthid", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "onetimekeyid_onetimekeys", ..descriptor::RANDOM_SMALL @@ -403,6 +411,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "userid_masterkeyid", ..descriptor::RANDOM_SMALL }, + Descriptor { + name: "userid_oauthid", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "userid_origin", ..descriptor::RANDOM diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index e7a90af8..b44af77c 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -110,6 +110,7 @@ ruma.workspace = true rustls.workspace = true rustyline-async.workspace = true rustyline-async.optional = true +serde_html_form.workspace = true serde_json.workspace = true serde.workspace = true serde_yaml.workspace = true diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs index f7de6542..64f1aa39 100644 --- a/src/service/client/mod.rs +++ b/src/service/client/mod.rs @@ -21,6 +21,7 @@ pub struct Service { pub sender: ClientLazylock, pub appservice: ClientLazylock, pub pusher: ClientLazylock, + pub oauth: ClientLazylock, pub cidr_range_denylist: Vec, } @@ -112,6 +113,11 @@ impl crate::Service for Service { .pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout)) .redirect(redirect::Policy::limited(2))), + oauth: create_client!(config, services; base(config)? + .dns_resolver2(Arc::clone(&services.resolver.resolver)) + .redirect(redirect::Policy::limited(0)) + .pool_max_idle_per_host(1)), + cidr_range_denylist: config .ip_range_denylist .iter() diff --git a/src/service/mod.rs b/src/service/mod.rs index 1a85da77..3964809d 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -19,6 +19,7 @@ pub mod globals; pub mod key_backups; pub mod media; pub mod membership; +pub mod oauth; pub mod presence; pub mod pusher; pub mod resolver; diff --git a/src/service/oauth/mod.rs b/src/service/oauth/mod.rs new file mode 100644 index 00000000..4a5ed471 --- /dev/null +++ b/src/service/oauth/mod.rs @@ -0,0 +1,265 @@ +pub mod providers; +pub mod sessions; +pub mod user_info; + +use std::sync::Arc; + +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode}; +use reqwest::{Method, header::ACCEPT}; +use ruma::UserId; +use serde::Serialize; +use serde_json::Value as JsonValue; +use tuwunel_core::{ + Err, Result, err, implement, + utils::{hash::sha256, result::LogErr}, +}; +use url::Url; + +pub use self::{ + providers::Provider, + sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session}, + user_info::UserInfo, +}; +use self::{providers::Providers, sessions::Sessions}; +use crate::SelfServices; + +pub struct Service { + services: SelfServices, + pub providers: Arc, + pub sessions: Arc, +} + +impl crate::Service for Service { + fn build(args: &crate::Args<'_>) -> Result> { + let providers = Arc::new(Providers::build(args)); + let sessions = Arc::new(Sessions::build(args, providers.clone())); + Ok(Arc::new(Self { + services: args.services.clone(), + sessions, + providers, + })) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip_all, ret)] +pub async fn request_userinfo( + &self, + (provider, session): (&Provider, &Session), +) -> Result { + let url = provider + .userinfo_url + .clone() + .ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?; + + self.request((Some(provider), Some(session)), Method::GET, url) + .await + .and_then(|value| serde_json::from_value(value).map_err(Into::into)) + .log_err() +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip_all, ret)] +pub async fn request_tokeninfo( + &self, + (provider, session): (&Provider, &Session), +) -> Result { + let url = provider + .introspection_url + .clone() + .ok_or_else(|| { + err!(Config("introspection_url", "Missing introspection URL in config")) + })?; + + self.request((Some(provider), Some(session)), Method::GET, url) + .await + .and_then(|value| serde_json::from_value(value).map_err(Into::into)) + .log_err() +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip_all, ret)] +pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result { + #[derive(Debug, Serialize)] + struct RevokeQuery<'a> { + client_id: &'a str, + client_secret: &'a str, + } + + let query = RevokeQuery { + client_id: &provider.client_id, + client_secret: &provider.client_secret, + }; + + let query = serde_html_form::to_string(&query)?; + let url = provider + .revocation_url + .clone() + .map(|mut url| { + url.set_query(Some(&query)); + url + }) + .ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?; + + self.request((Some(provider), Some(session)), Method::POST, url) + .await + .log_err() + .map(|_| ()) +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip_all, ret)] +pub async fn request_token( + &self, + (provider, session): (&Provider, &Session), + code: &str, +) -> Result { + #[derive(Debug, Serialize)] + struct TokenQuery<'a> { + client_id: &'a str, + client_secret: &'a str, + grant_type: &'a str, + code: &'a str, + code_verifier: Option<&'a str>, + redirect_uri: Option<&'a str>, + } + + let query = TokenQuery { + client_id: &provider.client_id, + client_secret: &provider.client_secret, + grant_type: "authorization_code", + code, + code_verifier: session.code_verifier.as_deref(), + redirect_uri: provider.callback_url.as_ref().map(Url::as_str), + }; + + let query = serde_html_form::to_string(&query)?; + let url = provider + .token_url + .clone() + .map(|mut url| { + url.set_query(Some(&query)); + url + }) + .ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?; + + self.request((Some(provider), Some(session)), Method::POST, url) + .await + .and_then(|value| serde_json::from_value(value).map_err(Into::into)) + .log_err() +} + +/// Send a request to a provider; this is somewhat abstract since URL's are +/// formed prior to this call and could point at anything, however this function +/// uses the oauth-specific http client and is configured for JSON with special +/// casing for an `error` property in the response. +#[implement(Service)] +#[tracing::instrument( + name = "request", + level = "debug", + ret(level = "trace"), + skip(self) +)] +pub async fn request( + &self, + (provider, session): (Option<&Provider>, Option<&Session>), + method: Method, + url: Url, +) -> Result { + let mut request = self + .services + .client + .oauth + .request(method, url) + .header(ACCEPT, "application/json"); + + if let Some(session) = session { + if let Some(access_token) = session.access_token.clone() { + request = request.bearer_auth(access_token); + } + } + + let response: JsonValue = request + .send() + .await? + .error_for_status()? + .json() + .await?; + + if let Some(response) = response.as_object().as_ref() + && let Some(error) = response.get("error").and_then(JsonValue::as_str) + { + let description = response + .get("error_description") + .and_then(JsonValue::as_str) + .unwrap_or("(no description)"); + + return Err!(Request(Forbidden("Error from provider: {error}: {description}",))); + } + + Ok(response) +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip(self))] +pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> { + let session = self.sessions.get_by_user(user_id).await?; + let provider = self.sessions.provider(&session).await?; + + Ok((provider, session)) +} + +#[inline] +pub fn unique_id((provider, session): (&Provider, &Session)) -> Result { + unique_id_parts((provider, session)).and_then(unique_id_iss_sub) +} + +#[inline] +pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result { + unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub) +} + +#[inline] +pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result { + unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub) +} + +pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result { + let hash = sha256::delimited([iss, sub].iter()); + let b64 = b64encode.encode(hash); + + Ok(b64) +} + +fn unique_id_parts<'a>( + (provider, session): (&'a Provider, &'a Session), +) -> Result<(&'a str, &'a str)> { + provider + .issuer_url + .as_ref() + .map(Url::as_str) + .ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider."))) + .and_then(|iss| unique_id_iss_parts((iss, session))) +} + +fn unique_id_sub_parts<'a>( + (provider, sub): (&'a Provider, &'a str), +) -> Result<(&'a str, &'a str)> { + provider + .issuer_url + .as_ref() + .map(Url::as_str) + .ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider."))) + .map(|iss| (iss, sub)) +} + +fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> { + session + .user_info + .as_ref() + .map(|user_info| user_info.sub.as_str()) + .ok_or_else(|| err!(Request(NotFound("user_info not found for this session.")))) + .map(|sub| (iss, sub)) +} diff --git a/src/service/oauth/providers.rs b/src/service/oauth/providers.rs new file mode 100644 index 00000000..cac599e7 --- /dev/null +++ b/src/service/oauth/providers.rs @@ -0,0 +1,246 @@ +use std::collections::BTreeMap; + +use serde_json::{Map as JsonObject, Value as JsonValue}; +use tokio::sync::RwLock; +pub use tuwunel_core::config::IdentityProvider as Provider; +use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement}; +use url::Url; + +use crate::SelfServices; + +#[derive(Default)] +pub struct Providers { + services: SelfServices, + providers: RwLock>, +} + +#[implement(Providers)] +pub(super) fn build(args: &crate::Args<'_>) -> Self { + Self { + services: args.services.clone(), + ..Default::default() + } +} + +/// Get the Provider configuration after any discovery and adjustments +/// made on top of the admin's configuration. This incurs network-based +/// discovery on the first call but responds from cache on subsequent calls. +#[implement(Providers)] +#[tracing::instrument(level = "debug", skip(self))] +pub async fn get(&self, id: &str) -> Result { + if let Some(provider) = self.providers.read().await.get(id).cloned() { + return Ok(provider); + } + + let config = self.get_config(id)?; + let mut map = self.providers.write().await; + let config = self.configure(config).await?; + + debug!(?id, ?config); + _ = map.insert(id.into(), config.clone()); + + Ok(config) +} + +/// Get the admin-configured Provider which exists prior to any +/// reconciliation with the well-known discovery (the server's config is +/// immutable); though it is important to note the server config can be +/// reloaded. This will Err NotFound for a non-existent idp. +#[implement(Providers)] +pub fn get_config(&self, id: &str) -> Result { + self.services + .config + .identity_provider + .iter() + .find(|config| config.id() == id) + .cloned() + .ok_or_else(|| err!(Request(NotFound("Unrecognized Identity Provider")))) +} + +/// Configure an identity provider; takes the admin-configured instance from the +/// server's config, queries the provider for discovery, and then returns an +/// updated config based on the proper reconciliation. This final config is then +/// cached in memory to avoid repeating this process. +#[implement(Providers)] +#[tracing::instrument( + level = INFO_SPAN_LEVEL, + ret(level = "debug"), + skip(self), +)] +async fn configure(&self, mut provider: Provider) -> Result { + _ = provider + .name + .get_or_insert_with(|| provider.brand.clone()); + + if provider.issuer_url.is_none() { + _ = provider + .issuer_url + .replace(match provider.brand.as_str() { + | "github" => "https://github.com".try_into()?, + | "gitlab" => "https://gitlab.com".try_into()?, + | "google" => "https://accounts.google.com".try_into()?, + | _ => return Err!(Config("issuer_url", "Required for this provider.")), + }); + } + + if provider.base_path.is_none() { + provider.base_path = match provider.brand.as_str() { + | "github" => Some("/login/oauth".to_owned()), + | _ => None, + }; + } + + let response = self + .discover(&provider) + .await + .and_then(|response| { + response.as_object().cloned().ok_or_else(|| { + err!(Request(NotJson("Expecting JSON object for discovery response"))) + }) + }) + .and_then(|response| check_issuer(response, &provider))?; + + if provider.authorization_url.is_none() { + response + .get("authorization_endpoint") + .and_then(JsonValue::as_str) + .map(Url::parse) + .transpose()? + .or_else(|| make_url(&provider, "/authorize").ok()) + .map(|url| provider.authorization_url.replace(url)); + } + + if provider.revocation_url.is_none() { + response + .get("revocation_endpoint") + .and_then(JsonValue::as_str) + .map(Url::parse) + .transpose()? + .or_else(|| make_url(&provider, "/revocation").ok()) + .map(|url| provider.revocation_url.replace(url)); + } + + if provider.introspection_url.is_none() { + response + .get("introspection_endpoint") + .and_then(JsonValue::as_str) + .map(Url::parse) + .transpose()? + .or_else(|| make_url(&provider, "/introspection").ok()) + .map(|url| provider.introspection_url.replace(url)); + } + + if provider.userinfo_url.is_none() { + response + .get("userinfo_endpoint") + .and_then(JsonValue::as_str) + .map(Url::parse) + .transpose()? + .or_else(|| match provider.brand.as_str() { + | "github" => "https://api.github.com/user".try_into().ok(), + | _ => make_url(&provider, "/userinfo").ok(), + }) + .map(|url| provider.userinfo_url.replace(url)); + } + + if provider.token_url.is_none() { + response + .get("token_endpoint") + .and_then(JsonValue::as_str) + .map(Url::parse) + .transpose()? + .or_else(|| { + let path = if provider.brand == "github" { + "/access_token" + } else { + "/token" + }; + + make_url(&provider, path).ok() + }) + .map(|url| provider.token_url.replace(url)); + } + + Ok(provider) +} + +#[implement(Providers)] +#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))] +pub async fn discover(&self, provider: &Provider) -> Result { + self.services + .client + .oauth + .get(discovery_url(provider)?) + .send() + .await? + .error_for_status()? + .json() + .await + .map_err(Into::into) +} + +fn discovery_url(provider: &Provider) -> Result { + let default_url = provider + .discovery + .then(|| make_url(provider, "/.well-known/openid-configuration")) + .transpose()?; + + let Some(url) = provider + .discovery_url + .clone() + .filter(|_| provider.discovery) + .or(default_url) + else { + return Err!(Config( + "discovery_url", + "Failed to determine URL for discovery of provider {}", + provider.id() + )); + }; + + Ok(url) +} + +/// Validate that the locally configured `issuer_url` matches the issuer claimed +/// in any response. todo: cryptographic validation is not yet implemented here. +fn check_issuer( + response: JsonObject, + provider: &Provider, +) -> Result> { + let expected = provider + .issuer_url + .as_ref() + .map(Url::as_str) + .map(|url| url.trim_end_matches('/')); + + let responded = response + .get("issuer") + .and_then(JsonValue::as_str) + .map(|url| url.trim_end_matches('/')); + + if expected != responded { + return Err!(Request(Unauthorized( + "Configured issuer_url {expected:?} does not match discovered {responded:?}", + ))); + } + + Ok(response) +} + +/// Generate a full URL for a request to the idp based on the idp's derived +/// configuration. +fn make_url(provider: &Provider, path: &str) -> Result { + let mut suffix = provider.base_path.clone().unwrap_or_default(); + + suffix.push_str(path); + let url = provider + .issuer_url + .as_ref() + .ok_or_else(|| { + let id = &provider.client_id; + err!(Config("issuer_url", "Provider {id:?} required field")) + })? + .join(&suffix)?; + + Ok(url) +} diff --git a/src/service/oauth/sessions.rs b/src/service/oauth/sessions.rs new file mode 100644 index 00000000..56e382a4 --- /dev/null +++ b/src/service/oauth/sessions.rs @@ -0,0 +1,234 @@ +use std::{sync::Arc, time::SystemTime}; + +use futures::{Stream, StreamExt, TryFutureExt}; +use ruma::{OwnedUserId, UserId}; +use serde::{Deserialize, Serialize}; +use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect}; +use tuwunel_database::{Cbor, Deserialized, Map}; +use url::Url; + +use super::{Provider, Providers, UserInfo, unique_id}; +use crate::SelfServices; + +pub struct Sessions { + _services: SelfServices, + providers: Arc, + db: Data, +} + +struct Data { + oauthid_session: Arc, + oauthuniqid_oauthid: Arc, + userid_oauthid: Arc, +} + +/// Session ultimately represents an OAuth authorization session yielding an +/// associated matrix user registration. +/// +/// Mixed-use structure capable of deserializing response values, maintaining +/// the state between authorization steps, and maintaining the association to +/// the matrix user until deactivation or revocation. +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct Session { + /// Identity Provider ID (the `client_id` in the configuration) associated + /// with this session. + pub idp_id: Option, + + /// Session ID used as the index key for this session itself. + pub sess_id: Option, + + /// Token type (bearer, mac, etc). + pub token_type: Option, + + /// Access token to the provider. + pub access_token: Option, + + /// Duration in seconds the access_token is valid for. + pub expires_in: Option, + + /// Point in time that the access_token expires. + pub expires_at: Option, + + /// Token used to refresh the access_token. + pub refresh_token: Option, + + /// Duration in seconds the refresh_token is valid for + pub refresh_token_expires_in: Option, + + /// Point in time that the refresh_token expires. + pub refresh_token_expires_at: Option, + + /// Access scope actually granted (if supported). + pub scope: Option, + + /// Redirect URL + pub redirect_url: Option, + + /// Challenge preimage + pub code_verifier: Option, + + /// Random string passed exclusively in the grant session cookie. + pub cookie_nonce: Option, + + /// Random single-use string passed in the provider redirect. + pub query_nonce: Option, + + /// Point in time the authorization grant session expires. + pub authorize_expires_at: Option, + + /// Associated User Id registration. + pub user_id: Option, + + /// Last userinfo response persisted here. + pub user_info: Option, +} + +/// Number of characters generated for our code_verifier. The code_verifier is a +/// random string which must be between 43 and 128 characters. +pub const CODE_VERIFIER_LENGTH: usize = 64; + +/// Number of characters we will generate for the Session ID. +pub const SESSION_ID_LENGTH: usize = 32; + +#[implement(Sessions)] +pub(super) fn build(args: &crate::Args<'_>, providers: Arc) -> Self { + Self { + _services: args.services.clone(), + providers, + db: Data { + oauthid_session: args.db["oauthid_session"].clone(), + oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(), + userid_oauthid: args.db["userid_oauthid"].clone(), + }, + } +} + +#[implement(Sessions)] +pub fn users(&self) -> impl Stream + Send { + self.db + .userid_oauthid + .keys() + .expect_ok() + .map(UserId::to_owned) +} + +/// Delete database state for the session. +#[implement(Sessions)] +#[tracing::instrument(level = "debug", skip(self))] +pub async fn delete(&self, sess_id: &str) { + let Ok(session) = self.get(sess_id).await else { + return; + }; + + // Check the user_id still points to this sess_id before deleting. If not, the + // association was updated to a newer session. + if let Some(user_id) = session.user_id.as_deref() { + if let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await { + if assoc_id == sess_id { + self.db.userid_oauthid.remove(user_id); + } + } + } + + // Check the unique identity still points to this sess_id before deleting. If + // not, the association was updated to a newer session. + if let Some(idp_id) = session.idp_id.as_ref() { + if let Ok(provider) = self.providers.get(idp_id).await { + if let Ok(unique_id) = unique_id((&provider, &session)) { + if let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await { + if assoc_id == sess_id { + self.db.oauthuniqid_oauthid.remove(&unique_id); + } + } + } + } + } + + self.db.oauthid_session.remove(sess_id); +} + +/// Create or overwrite database state for the session. +#[implement(Sessions)] +#[tracing::instrument(level = "info", skip(self))] +pub async fn put(&self, sess_id: &str, session: &Session) { + self.db + .oauthid_session + .raw_put(sess_id, Cbor(session)); + + if let Some(user_id) = session.user_id.as_deref() { + self.db.userid_oauthid.insert(user_id, sess_id); + } + + if let Some(idp_id) = session.idp_id.as_ref() { + if let Ok(provider) = self.providers.get(idp_id).await { + if let Ok(unique_id) = unique_id((&provider, session)) { + self.db + .oauthuniqid_oauthid + .insert(&unique_id, sess_id); + } + } + } +} + +/// Fetch database state for a session from its associated `sub`, in case +/// `sess_id` is not known. +#[implement(Sessions)] +#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))] +pub async fn get_by_unique_id(&self, unique_id: &str) -> Result { + self.get_sess_id_by_unique_id(unique_id) + .and_then(async |sess_id| self.get(&sess_id).await) + .await +} + +/// Fetch database state for a session from its associated `user_id`, in case +/// `sess_id` is not known. +#[implement(Sessions)] +#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))] +pub async fn get_by_user(&self, user_id: &UserId) -> Result { + self.get_sess_id_by_user(user_id) + .and_then(async |sess_id| self.get(&sess_id).await) + .await +} + +/// Fetch database state for a session from its `sess_id`. +#[implement(Sessions)] +#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))] +pub async fn get(&self, sess_id: &str) -> Result { + self.db + .oauthid_session + .get(sess_id) + .await + .deserialized::>() + .map(at!(0)) +} + +/// Resolve the `sess_id` from an associated `user_id`. +#[implement(Sessions)] +#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))] +pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result { + self.db + .userid_oauthid + .get(user_id) + .await + .deserialized() +} + +/// Resolve the `sess_id` from an associated provider subject id. +#[implement(Sessions)] +#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))] +pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result { + self.db + .oauthuniqid_oauthid + .get(unique_id) + .await + .deserialized() +} + +#[implement(Sessions)] +pub async fn provider(&self, session: &Session) -> Result { + let Some(idp_id) = session.idp_id.as_deref() else { + return Err!(Request(NotFound("No provider for this session"))); + }; + + self.providers.get(idp_id).await +} diff --git a/src/service/oauth/user_info.rs b/src/service/oauth/user_info.rs new file mode 100644 index 00000000..18eda0db --- /dev/null +++ b/src/service/oauth/user_info.rs @@ -0,0 +1,39 @@ +use serde::{Deserialize, Serialize}; + +/// Selection of userinfo response claims. +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct UserInfo { + /// Unique identifier number or login username. Usually a number on most + /// services. We consider a concatenation of the `iss` and `sub` to be a + /// universally unique identifier for some user/identity; we index that in + /// `oauthidpsub_oauthid`. + /// + /// Considered for user mxid only if none of the better fields are defined. + /// `login` alias intended for github. + #[serde(alias = "login")] + pub sub: String, + + /// The login username we first consider when defined. + pub preferred_username: Option, + + /// The login username considered if none preferred. + pub nickname: Option, + + /// Full name. + pub name: Option, + + /// First name. + pub given_name: Option, + + /// Last name. + pub family_name: Option, + + /// Email address (`email` scope). + pub email: Option, + + /// URL to pfp (github/gitlab) + pub avatar_url: Option, + + /// URL to pfp (google) + pub picture: Option, +} diff --git a/src/service/services.rs b/src/service/services.rs index cda9232c..98c55e7b 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -12,7 +12,7 @@ use crate::{ account_data, admin, appservice, client, config, deactivate, emergency, federation, globals, key_backups, manager::Manager, - media, membership, presence, pusher, resolver, rooms, sending, server_keys, + media, membership, oauth, presence, pusher, resolver, rooms, sending, server_keys, service::{Args, Service}, sync, transaction_ids, uiaa, users, }; @@ -58,6 +58,7 @@ pub struct Services { pub users: Arc, pub membership: Arc, pub deactivate: Arc, + pub oauth: Arc, manager: Mutex>>, pub server: Arc, @@ -115,6 +116,7 @@ pub async fn build(server: Arc) -> Result> { users: users::Service::build(&args)?, membership: membership::Service::build(&args)?, deactivate: deactivate::Service::build(&args)?, + oauth: oauth::Service::build(&args)?, manager: Mutex::new(None), server, @@ -173,6 +175,7 @@ pub(crate) fn services(&self) -> impl Iterator> + Send { cast!(self.users), cast!(self.membership), cast!(self.deactivate), + cast!(self.oauth), ] .into_iter() } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 981c02f9..5b652d23 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -14,10 +14,10 @@ use ruma::{ events::{GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent}, }; use tuwunel_core::{ - Err, Result, debug_warn, err, is_equal_to, + Err, Result, debug_warn, err, pdu::PduBuilder, trace, - utils::{self, ReadyExt, stream::TryIgnore}, + utils::{self, ReadyExt, result::LogErr, stream::TryIgnore}, warn, }; use tuwunel_database::{Deserialized, Json, Map}; @@ -132,6 +132,16 @@ impl Service { /// Deactivate account pub async fn deactivate_account(&self, user_id: &UserId) -> Result { + // Revoke any SSO authorizations + if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await { + self.services + .oauth + .revoke_token((&provider, &session)) + .await + .log_err() + .ok(); + } + // Remove all associated devices self.all_device_ids(user_id) .for_each(|device_id| self.remove_device(user_id, device_id)) @@ -227,18 +237,15 @@ impl Service { // Cannot change the password of a LDAP user. There are two special cases : // - a `None` password can be used to deactivate a LDAP user // - a "*" password is used as the default password of an active LDAP user - if cfg!(feature = "ldap") - && password.is_some() + // + // The above now applies to all non-password origin users including SSO + // users. Note that users with no origin are also password-origin users. + if password.is_some() && password != Some("*") - && self - .db - .userid_origin - .get(user_id) - .await - .deserialized::() - .is_ok_and(is_equal_to!("ldap")) + && let Ok(origin) = self.origin(user_id).await + && origin != "password" { - return Err!(Request(InvalidParam("Cannot change password of a LDAP user"))); + return Err!(Request(InvalidParam("Cannot change password of an {origin:?} user."))); } password diff --git a/tuwunel-example.toml b/tuwunel-example.toml index d974d477..0bd58ae8 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -1834,6 +1834,10 @@ # #one_time_key_limit = 256 +# This item is undocumented. Please contribute documentation for it. +# +#sso_aware_preferred = false + #[global.tls] @@ -2097,6 +2101,131 @@ +#[[global.identity_provider]] + +# The brand-name of the service (e.g. Apple, Facebook, GitHub, GitLab, +# Google) or the software (e.g. keycloak, MAS) providing the identity. +# When a brand is recognized we apply certain defaults to this config +# for your convenience. For certain brands we apply essential internal +# workarounds specific to that provider; it is important to configure this +# field properly when a provider needs to be recognized (like GitHub for +# example). Several configured providers can share the same brand name. It +# is not case-sensitive. +# +#brand = + +# The ID of your OAuth application which the provider generates upon +# registration. This ID then uniquely identifies this configuration +# instance itself, becoming the identity provider's ID and must be unique +# and remain unchanged. +# +#client_id = + +# Secret key the provider generated for you along with the `client_id` +# above. Unlike the `client_id`, the `client_secret` can be changed here +# whenever the provider regenerates one for you. +# +#client_secret = + +# The callback URL configured when registering the OAuth application with +# the provider. Tuwunel's callback URL must be strictly formatted exactly +# as instructed. The URL host must point directly at the matrix server and +# use the following path: +# `/_matrix/client/unstable/login/sso/callback/` where +# `` is the same one configured for this provider above. +# +#callback_url = + +# Optional display-name for this provider instance seen on the login page +# by users. It defaults to `brand`. When configuring multiple providers +# using the same `brand` this can be set to distinguish them. +# +#name = + +# Optional icon for the provider. The canonical providers have a default +# icon based on the `brand` supplied above when this is not supplied. Note +# that it uses an MXC url which is curious in the auth-media era and may +# not be reliable. +# +#icon = + +# Optional list of scopes to authorize. An empty array does not impose any +# restrictions from here, effectively defaulting to all scopes you +# configured for the OAuth application at the provider. This setting +# allows for restricting to a subset of those scopes for this instance. +# Note the user can further restrict scopes during their authorization. +# +#scope = [] + +# List of userinfo claims which shape and restrict the way we compute a +# Matrix UserId for new registrations. Reviewing Tuwunel's documentation +# will be necessary for a complete description in detail. An empty array +# imposes no restriction here, avoiding generated fallbacks as much as +# possible. For simplicity we reserve a claim called "unique" which can be +# listed alone to ensure *only* generated ID's are used for registrations. +# +#userid_claims = [] + +# Issuer URL the provider publishes for you. We have pre-supplied default +# values for some of the canonical providers, making this field optional +# based on the `brand` set above. Otherwise it is required for OIDC +# discovery to acquire additional provider configuration, and it must be +# correct to pass validations during various interactions. +# +#issuer_url = + +# Extra path components after the issuer_url leading to the location of +# the `.well-known` directory used for discovery. This will be empty for +# specification-compliant providers. We have supplied any known values +# based on `brand` (e.g. `/login/oauth` for GitHub). +# +#base_path = + +# Overrides the `.well-known` location where the provider's OIDC +# configuration is found. It is very unlikely you will need to set this; +# available for developers or special purposes only. +# +#discovery_url = + +# Overrides the authorize URL requested during the grant phase. This is +# generally discovered or derived automatically, but may be required as a +# workaround for any non-standard or undiscoverable provider. +# +#authorization_url = + +# Overrides the access token URL; the same caveats apply as with the other +# URL overrides. +# +#token_url = + +# Overrides the revocation URL; the same caveats apply as with the other +# URL overrides. +# +#revocation_url = + +# Overrides the introspection URL; the same caveats apply as with the +# other URL overrides. +# +#introspection_url = + +# Overrides the userinfo URL; the same caveats apply as with the other URL +# overrides. +# +#userinfo_url = + +# Whether to perform discovery and adjust this provider's configuration +# accordingly. This defaults to true. When true, it is an error when +# discovery fails and authorizations will not be attempted to the +# provider. +# +#discovery = true + +# The duration in seconds before a grant authorization session expires. +# +#grant_session_duration = + + + #[global.appservice.] # The URL for the application service.