Implement SSO/OIDC support. (closes #7)

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-12-23 14:55:29 +00:00
parent d665a34f30
commit 11309062a2
23 changed files with 1959 additions and 27 deletions

27
Cargo.lock generated
View File

@@ -383,6 +383,7 @@ dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-util",
"headers",
"http",
@@ -836,6 +837,17 @@ dependencies = [
"unicode-segmentation",
]
[[package]]
name = "cookie"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
dependencies = [
"percent-encoding",
"time",
"version_check",
]
[[package]]
name = "coolor"
version = "1.1.0"
@@ -1291,6 +1303,15 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]]
name = "enum-as-inner"
version = "0.6.1"
@@ -3582,6 +3603,7 @@ checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f"
dependencies = [
"base64",
"bytes",
"encoding_rs",
"futures-channel",
"futures-core",
"futures-util",
@@ -3595,6 +3617,7 @@ dependencies = [
"hyper-util",
"js-sys",
"log",
"mime",
"once_cell",
"percent-encoding",
"pin-project-lite",
@@ -5141,6 +5164,7 @@ dependencies = [
"tracing",
"tuwunel_core",
"tuwunel_service",
"url",
]
[[package]]
@@ -5185,6 +5209,8 @@ dependencies = [
"ruma",
"sanitize-filename",
"serde",
"serde_core",
"serde_html_form",
"serde_json",
"serde_regex",
"serde_yaml",
@@ -5295,6 +5321,7 @@ dependencies = [
"rustls",
"rustyline-async",
"serde",
"serde_html_form",
"serde_json",
"serde_yaml",
"sha2",

View File

@@ -76,6 +76,7 @@ version = "0.7"
version = "0.10"
default-features = false
features = [
"cookie",
"typed-header",
"tracing",
]
@@ -312,10 +313,12 @@ version = "1.12"
version = "0.12"
default-features = false
features = [
"rustls-tls-native-roots",
"socks",
"charset",
"hickory-dns",
"http2",
"json",
"rustls-tls-native-roots",
"socks",
]
[workspace.dependencies.ring]

View File

@@ -1,6 +1,7 @@
mod account_data;
mod appservice;
mod globals;
mod oauth;
mod presence;
mod pusher;
mod raw;
@@ -18,10 +19,10 @@ use tuwunel_core::Result;
use self::{
account_data::AccountDataCommand, appservice::AppserviceCommand, globals::GlobalsCommand,
presence::PresenceCommand, pusher::PusherCommand, raw::RawCommand, resolver::ResolverCommand,
room_alias::RoomAliasCommand, room_state_cache::RoomStateCacheCommand,
room_timeline::RoomTimelineCommand, sending::SendingCommand, short::ShortCommand,
sync::SyncCommand, users::UsersCommand,
oauth::OauthCommand, presence::PresenceCommand, pusher::PusherCommand, raw::RawCommand,
resolver::ResolverCommand, room_alias::RoomAliasCommand,
room_state_cache::RoomStateCacheCommand, room_timeline::RoomTimelineCommand,
sending::SendingCommand, short::ShortCommand, sync::SyncCommand, users::UsersCommand,
};
use crate::admin_command_dispatch;
@@ -81,6 +82,10 @@ pub(super) enum QueryCommand {
#[command(subcommand)]
Sync(SyncCommand),
/// - oauth service
#[command(subcommand)]
Oauth(OauthCommand),
/// - raw service
#[command(subcommand)]
Raw(RawCommand),

172
src/admin/query/oauth.rs Normal file
View File

@@ -0,0 +1,172 @@
use clap::Subcommand;
use futures::{StreamExt, TryStreamExt};
use ruma::OwnedUserId;
use tuwunel_core::{Err, Result, utils::stream::IterStream};
use tuwunel_service::{
Services,
oauth::{Provider, Session},
};
use crate::{admin_command, admin_command_dispatch};
#[admin_command_dispatch(handler_prefix = "oauth")]
#[derive(Debug, Subcommand)]
/// Query OAuth service state
pub(crate) enum OauthCommand {
/// List configured OAuth providers.
ListProviders,
/// List users associated with an OAuth provider
ListUsers,
/// Show active configuration of a provider.
ShowProvider {
id: String,
#[arg(long)]
config: bool,
},
/// Show session state
ShowSession {
id: String,
},
/// Token introspection request to provider.
TokenInfo {
id: String,
},
/// Revoke token for user_id or sess_id.
Revoke {
id: String,
},
/// Remove oauth state (DANGER!)
Remove {
id: String,
#[arg(long)]
force: bool,
},
}
#[admin_command]
pub(super) async fn oauth_list_providers(&self) -> Result {
self.services
.config
.identity_provider
.iter()
.try_stream()
.map_ok(Provider::id)
.map_ok(|id| format!("{id}\n"))
.try_for_each(async |id| self.write_str(&id).await)
.await
}
#[admin_command]
pub(super) async fn oauth_list_users(&self) -> Result {
self.services
.oauth
.sessions
.users()
.map(|id| format!("{id}\n"))
.map(Ok)
.try_for_each(async |id: String| self.write_str(&id).await)
.await
}
#[admin_command]
pub(super) async fn oauth_show_provider(&self, id: String, config: bool) -> Result {
if config {
let config = self.services.oauth.providers.get_config(&id)?;
self.write_str(&format!("{config:#?}\n")).await?;
return Ok(());
}
let provider = self.services.oauth.providers.get(&id).await?;
self.write_str(&format!("{provider:#?}\n")).await
}
#[admin_command]
pub(super) async fn oauth_show_session(&self, id: String) -> Result {
let session = find_session(self.services, &id).await?;
self.write_str(&format!("{session:#?}\n")).await
}
#[admin_command]
pub(super) async fn oauth_token_info(&self, id: String) -> Result {
let session = find_session(self.services, &id).await?;
let provider = self
.services
.oauth
.sessions
.provider(&session)
.await?;
let tokeninfo = self
.services
.oauth
.request_tokeninfo((&provider, &session))
.await;
self.write_str(&format!("{tokeninfo:#?}\n")).await
}
#[admin_command]
pub(super) async fn oauth_revoke(&self, id: String) -> Result {
let session = find_session(self.services, &id).await?;
let provider = self
.services
.oauth
.sessions
.provider(&session)
.await?;
self.services
.oauth
.revoke_token((&provider, &session))
.await?;
self.write_str("done").await
}
#[admin_command]
pub(super) async fn oauth_remove(&self, id: String, force: bool) -> Result {
let session = find_session(self.services, &id).await?;
let Some(sess_id) = session.sess_id else {
return Err!("Missing sess_id in oauth Session state");
};
if !force {
return Err!(
"Deleting these records can cause registration conflicts. Use --force to be sure."
);
}
self.services
.oauth
.sessions
.delete(&sess_id)
.await;
self.write_str("done").await
}
async fn find_session(services: &Services, id: &str) -> Result<Session> {
if let Ok(user_id) = OwnedUserId::parse(id) {
services
.oauth
.sessions
.get_by_user(&user_id)
.await
} else {
services.oauth.sessions.get(id).await
}
}

View File

@@ -101,6 +101,7 @@ tokio.workspace = true
tracing.workspace = true
tuwunel-core.workspace = true
tuwunel-service.workspace = true
url.workspace = true
[lints]
workspace = true

View File

@@ -4,6 +4,7 @@ mod ldap;
mod logout;
mod password;
mod refresh;
mod sso;
mod token;
use axum::extract::State;
@@ -12,8 +13,8 @@ use ruma::api::client::session::{
get_login_types::{
self,
v3::{
ApplicationServiceLoginType, JwtLoginType, LoginType, PasswordLoginType,
TokenLoginType,
ApplicationServiceLoginType, IdentityProvider, JwtLoginType, LoginType,
PasswordLoginType, SsoLoginType, TokenLoginType,
},
},
login::{
@@ -28,6 +29,7 @@ use self::{ldap::ldap_login, password::password_login};
pub(crate) use self::{
logout::{logout_all_route, logout_route},
refresh::refresh_token_route,
sso::{sso_callback_route, sso_login_route, sso_login_with_provider_route},
token::login_token_route,
};
use super::TOKEN_LENGTH;
@@ -50,6 +52,20 @@ pub(crate) async fn get_login_types_route(
LoginType::Token(TokenLoginType {
get_login_token: services.config.login_via_existing_session,
}),
LoginType::Sso(SsoLoginType {
identity_providers: services
.config
.identity_provider
.iter()
.cloned()
.map(|config| IdentityProvider {
id: config.id().to_owned(),
brand: Some(config.brand.clone().into()),
icon: config.icon,
name: config.name.unwrap_or(config.brand),
})
.collect(),
}),
]))
}
@@ -89,6 +105,10 @@ pub(crate) async fn login_route(
},
};
if !services.users.is_active_local(&user_id).await {
return Err!(Request(UserDeactivated("This user has been deactivated.")));
}
// Generate a new token for the device
let (access_token, expires_in) = services
.users

View File

@@ -0,0 +1,601 @@
use std::time::Duration;
use axum::extract::State;
use axum_client_ip::InsecureClientIp;
use axum_extra::extract::cookie::{Cookie, SameSite};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
use futures::{StreamExt, TryFutureExt, future::try_join};
use itertools::Itertools;
use reqwest::header::{CONTENT_TYPE, HeaderValue};
use ruma::{
Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId,
api::client::session::{sso_callback, sso_login, sso_login_with_provider},
};
use serde::{Deserialize, Serialize};
use tuwunel_core::{
Err, Result, at,
debug::INFO_SPAN_LEVEL,
debug_info, debug_warn, err, info, utils,
utils::{
content_disposition::make_content_disposition,
hash::sha256,
result::{FlatOk, LogErr},
string::{EMPTY, truncate_deterministic},
timepoint_from_now, timepoint_has_passed,
},
warn,
};
use tuwunel_service::{
Services,
media::MXC_LENGTH,
oauth::{
CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id,
unique_id_sub,
},
users::Register,
};
use url::Url;
use super::TOKEN_LENGTH;
use crate::Ruma;
/// Grant phase query string.
#[derive(Debug, Serialize)]
struct GrantQuery<'a> {
client_id: &'a str,
state: &'a str,
nonce: &'a str,
scope: &'a str,
response_type: &'a str,
access_type: &'a str,
code_challenge_method: &'a str,
code_challenge: &'a str,
redirect_uri: Option<&'a str>,
}
#[derive(Debug, Deserialize, Serialize)]
struct GrantCookie<'a> {
client_id: &'a str,
state: &'a str,
nonce: &'a str,
redirect_uri: &'a str,
}
static GRANT_SESSION_COOKIE: &str = "tuwunel_grant_session";
#[tracing::instrument(
name = "sso_login",
level = "debug",
skip_all,
fields(%client),
)]
pub(crate) async fn sso_login_route(
State(_services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
_body: Ruma<sso_login::v3::Request>,
) -> Result<sso_login::v3::Response> {
Err!(Request(NotImplemented(
"SSO login without specific provider has not been implemented."
)))
}
#[tracing::instrument(
name = "sso_login_with_provider",
level = "info",
skip_all,
ret(level = "debug")
fields(
%client,
idp_id = body.body.idp_id,
),
)]
pub(crate) async fn sso_login_with_provider_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<sso_login_with_provider::v3::Request>,
) -> Result<sso_login_with_provider::v3::Response> {
let sso_login_with_provider::v3::Request { idp_id, redirect_url } = body.body;
let Ok(redirect_url) = redirect_url.parse::<Url>() else {
return Err!(Request(InvalidParam("Invalid redirect_url")));
};
let provider = services.oauth.providers.get(&idp_id).await?;
let sess_id = utils::random_string(SESSION_ID_LENGTH);
let query_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
let cookie_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
let code_verifier = utils::random_string(CODE_VERIFIER_LENGTH);
let code_challenge = b64.encode(sha256::hash(code_verifier.as_bytes()));
let callback_uri = provider.callback_url.as_ref().map(Url::as_str);
let scope = provider.scope.iter().join(" ");
let query = GrantQuery {
client_id: &provider.client_id,
state: &sess_id,
nonce: &query_nonce,
access_type: "online",
response_type: "code",
code_challenge_method: "S256",
code_challenge: &code_challenge,
redirect_uri: callback_uri,
scope: scope
.is_empty()
.then_some("openid email profile")
.unwrap_or(scope.as_str()),
};
let location = provider
.authorization_url
.clone()
.map(|mut location| {
let query = serde_html_form::to_string(&query).ok();
location.set_query(query.as_deref());
location
})
.ok_or_else(|| {
err!(Config("authorization_url", "Missing required IdentityProvider config"))
})?;
let cookie_val = GrantCookie {
client_id: query.client_id,
state: query.state,
nonce: &cookie_nonce,
redirect_uri: redirect_url.as_str(),
};
let cookie_path = provider
.callback_url
.as_ref()
.map(Url::path)
.unwrap_or("/");
let cookie_max_age = provider
.grant_session_duration
.map(Duration::from_secs)
.expect("Defaulted to Some value during configure_idp()")
.try_into()
.expect("std::time::Duration to time::Duration conversion failure");
let cookie = Cookie::build((GRANT_SESSION_COOKIE, serde_html_form::to_string(&cookie_val)?))
.path(cookie_path)
.max_age(cookie_max_age)
.same_site(SameSite::None)
.secure(true)
.http_only(true)
.build()
.to_string()
.into();
let session = Session {
idp_id: Some(idp_id),
sess_id: Some(sess_id.clone()),
redirect_url: Some(redirect_url),
code_verifier: Some(code_verifier),
query_nonce: Some(query_nonce),
cookie_nonce: Some(cookie_nonce),
authorize_expires_at: provider
.grant_session_duration
.map(Duration::from_secs)
.map(timepoint_from_now)
.transpose()?,
..Default::default()
};
services
.oauth
.sessions
.put(&sess_id, &session)
.await;
Ok(sso_login_with_provider::v3::Response {
location: location.into(),
cookie: Some(cookie),
})
}
#[tracing::instrument(
name = "sso_callback"
level = "debug",
skip_all,
fields(
%client,
cookie = ?body.cookie,
body = ?body.body,
),
)]
pub(crate) async fn sso_callback_route(
State(services): State<crate::State>,
InsecureClientIp(client): InsecureClientIp,
body: Ruma<sso_callback::unstable::Request>,
) -> Result<sso_callback::unstable::Response> {
let sess_id = body
.body
.state
.as_deref()
.ok_or_else(|| err!(Request(Forbidden("Missing sess_id in callback."))))?;
let code = body
.body
.code
.as_deref()
.ok_or_else(|| err!(Request(Forbidden("Missing code in callback."))))?;
let session = services
.oauth
.sessions
.get(sess_id)
.map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));
let idp_id = body.body.idp_id.as_str();
let provider = services.oauth.providers.get(idp_id);
let (provider, session) = try_join(provider, session).await.log_err()?;
let client_id = &provider.client_id;
if session.idp_id.as_deref() != Some(idp_id) {
return Err!(Request(Unauthorized("Identity Provider {idp_id} session not recognized.")));
}
if session
.authorize_expires_at
.map(timepoint_has_passed)
.unwrap_or(false)
{
return Err!(Request(Unauthorized("Authorization grant session has expired.")));
}
let cookie = body
.cookie
.get(GRANT_SESSION_COOKIE)
.map(Cookie::value)
.map(serde_html_form::from_str::<GrantCookie<'_>>)
.transpose()?
.ok_or_else(|| err!(Request(Unauthorized("Missing cookie {GRANT_SESSION_COOKIE:?}"))))?;
if cookie.client_id != client_id {
return Err!(Request(Unauthorized("Client ID {client_id:?} cookie mismatch.")));
}
if Some(cookie.nonce) != session.cookie_nonce.as_deref() {
return Err!(Request(Unauthorized("Cookie nonce does not match session state.")));
}
if cookie.state != sess_id {
return Err!(Request(Unauthorized("Session ID {sess_id:?} cookie mismatch.")));
}
// Request access token.
let token_response = services
.oauth
.request_token((&provider, &session), code)
.await?;
let token_expires_at = token_response
.expires_in
.map(Duration::from_secs)
.map(timepoint_from_now)
.transpose()?;
let refresh_token_expires_at = token_response
.refresh_token_expires_in
.map(Duration::from_secs)
.map(timepoint_from_now)
.transpose()?;
// Update the session with access token results
let session = Session {
scope: token_response.scope,
token_type: token_response.token_type,
access_token: token_response.access_token,
expires_at: token_expires_at,
refresh_token: token_response.refresh_token,
refresh_token_expires_at,
..session
};
// Request userinfo claims.
let userinfo = services
.oauth
.request_userinfo((&provider, &session))
.await?;
// Check for an existing session from this identity. We want to maintain one
// session for each identity and keep the newer one which has up-to-date state
// and access.
let (user_id, old_sess_id) = match services
.oauth
.sessions
.get_by_unique_id(&unique_id_sub((&provider, &userinfo.sub))?)
.await
{
| Ok(session) => (session.user_id, session.sess_id),
| Err(error) if !error.is_not_found() => return Err(error),
| Err(_) => (None, None),
};
// Update the session with userinfo
let session = Session {
user_info: Some(userinfo.clone()),
..session
};
// Keep the user_id from the old session as best as possible.
let user_id = match user_id {
| Some(user_id) => user_id,
| None => decide_user_id(&services, &provider, &session, &userinfo).await?,
};
// Update the session with user_id
let session = Session {
user_id: Some(user_id.clone()),
..session
};
// Attempt to register a non-existing user.
if !services.users.exists(&user_id).await {
register_user(&services, &provider, &session, &userinfo, &user_id).await?;
}
// Commit the updated session.
services
.oauth
.sessions
.put(sess_id, &session)
.await;
// Delete any old session.
if let Some(old_sess_id) = old_sess_id
&& sess_id != old_sess_id
{
services.oauth.sessions.delete(&old_sess_id).await;
}
if !services.users.is_active_local(&user_id).await {
return Err!(Request(UserDeactivated("This user has been deactivated.")));
}
// Allow the user to login to Matrix.
let login_token = utils::random_string(TOKEN_LENGTH);
let _login_token_expires_in = services
.users
.create_login_token(&user_id, &login_token);
let location = session
.redirect_url
.as_ref()
.ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
.clone()
.query_pairs_mut()
.append_pair("loginToken", &login_token)
.finish()
.to_string();
let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
.removal()
.build()
.to_string()
.into();
Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
}
#[tracing::instrument(
name = "register",
level = INFO_SPAN_LEVEL,
skip_all,
fields(user_id, userinfo)
)]
async fn register_user(
services: &Services,
provider: &Provider,
session: &Session,
userinfo: &UserInfo,
user_id: &UserId,
) -> Result {
debug_info!(%user_id, "Creating new user account...");
services
.users
.full_register(Register {
user_id: Some(user_id),
password: Some("*"),
origin: Some("sso"),
displayname: userinfo.name.as_deref(),
..Default::default()
})
.await?;
if let Some(avatar_url) = userinfo
.avatar_url
.as_deref()
.or(userinfo.picture.as_deref())
{
set_avatar(services, provider, session, userinfo, user_id, avatar_url)
.await
.ok();
}
let idp_id = provider.id();
let idp_name = provider
.name
.as_deref()
.unwrap_or(provider.brand.as_str());
// log in conduit admin channel if a non-guest user registered
let notice =
format!("New user \"{user_id}\" registered on this server via {idp_name} ({idp_id})",);
info!("{notice}");
if services.server.config.admin_room_notices {
services.admin.notice(&notice).await;
}
Ok(())
}
#[tracing::instrument(level = "debug", skip_all, fields(user_id, avatar_url))]
async fn set_avatar(
services: &Services,
_provider: &Provider,
_session: &Session,
_userinfo: &UserInfo,
user_id: &UserId,
avatar_url: &str,
) -> Result {
use reqwest::Response;
let response = services
.client
.default
.get(avatar_url)
.send()
.await
.and_then(Response::error_for_status)?;
let content_type = response
.headers()
.get(CONTENT_TYPE)
.map(HeaderValue::to_str)
.flat_ok()
.map(ToOwned::to_owned);
let mxc = Mxc {
server_name: services.globals.server_name(),
media_id: &utils::random_string(MXC_LENGTH),
};
let content_disposition = make_content_disposition(None, content_type.as_deref(), None);
let bytes = response.bytes().await?;
services
.media
.create(&mxc, Some(user_id), Some(&content_disposition), content_type.as_deref(), &bytes)
.await?;
let all_joined_rooms: Vec<OwnedRoomId> = services
.state_cache
.rooms_joined(user_id)
.map(ToOwned::to_owned)
.collect()
.await;
let mxc_uri = mxc.to_string().into();
services
.users
.update_avatar_url(user_id, Some(mxc_uri), None, &all_joined_rooms)
.await;
Ok(())
}
#[tracing::instrument(
level = "debug",
ret(level = "debug")
skip_all,
fields(user),
)]
async fn decide_user_id(
services: &Services,
provider: &Provider,
session: &Session,
userinfo: &UserInfo,
) -> Result<OwnedUserId> {
let allowed =
|claim: &str| provider.userid_claims.is_empty() || provider.userid_claims.contains(claim);
let choices = [
userinfo
.preferred_username
.as_deref()
.map(str::to_lowercase)
.filter(|_| allowed("preferred_username")),
userinfo
.nickname
.as_deref()
.map(str::to_lowercase)
.filter(|_| allowed("nickname")),
provider
.brand
.eq(&"github")
.then_some(userinfo.sub.as_str())
.map(str::to_lowercase)
.filter(|_| allowed("login")),
userinfo
.email
.as_deref()
.and_then(|email| email.split_once('@'))
.map(at!(0))
.map(str::to_lowercase)
.filter(|_| allowed("email")),
];
for choice in choices.into_iter().flatten() {
if let Some(user_id) = try_user_id(services, &choice, false).await {
return Ok(user_id);
}
}
if let Ok(infallible) = unique_id((provider, session))
.map(|h| truncate_deterministic(&h, Some(15..23)).to_lowercase())
.log_err()
{
if let Some(user_id) = try_user_id(services, &infallible, true).await {
return Ok(user_id);
}
}
Err!(Request(UserInUse("User ID is not available.")))
}
#[tracing::instrument(level = "debug", skip_all, fields(username))]
async fn try_user_id(
services: &Services,
username: &str,
may_exist: bool,
) -> Option<OwnedUserId> {
let server_name = services.globals.server_name();
let user_id = parse_user_id(server_name, username)
.inspect_err(|e| warn!(?username, "Username invalid: {e}"))
.ok()?;
if services
.config
.forbidden_usernames
.is_match(username)
{
warn!(?username, "Username forbidden.");
return None;
}
if services.users.exists(&user_id).await {
debug_warn!(?username, "Username exists.");
if services
.users
.origin(&user_id)
.await
.ok()
.is_none_or(|origin| origin != "sso")
{
debug_warn!(?username, "Username has non-sso origin.");
return None;
}
if !may_exist {
return None;
}
}
Some(user_id)
}
fn parse_user_id(server_name: &ServerName, username: &str) -> Result<OwnedUserId> {
match UserId::parse_with_server_name(username, server_name) {
| Err(e) =>
Err!(Request(InvalidUsername(debug_error!("Username {username} is not valid: {e}")))),
| Ok(user_id) => match user_id.validate_strict() {
| Ok(()) => Ok(user_id),
| Err(e) => Err!(Request(InvalidUsername(debug_error!(
"Username {username} contains disallowed characters or spaces: {e}"
)))),
},
}
}

View File

@@ -38,6 +38,9 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
.ruma_route(&client::login_route)
.ruma_route(&client::login_token_route)
.ruma_route(&client::refresh_token_route)
.ruma_route(&client::sso_login_route)
.ruma_route(&client::sso_login_with_provider_route)
.ruma_route(&client::sso_callback_route)
.ruma_route(&client::whoami_route)
.ruma_route(&client::logout_route)
.ruma_route(&client::logout_all_route)

View File

@@ -1,6 +1,7 @@
use std::{fmt::Debug, mem, ops::Deref};
use axum::{body::Body, extract::FromRequest};
use axum_extra::extract::cookie::CookieJar;
use bytes::{BufMut, Bytes, BytesMut};
use ruma::{
CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
@@ -18,6 +19,9 @@ pub(crate) struct Args<T> {
/// Request struct body
pub(crate) body: T,
/// Cookies received from the useragent.
pub(crate) cookie: CookieJar,
/// Federation server authentication: X-Matrix origin
/// None when not a federation server.
pub(crate) origin: Option<OwnedServerName>,
@@ -110,6 +114,7 @@ where
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
Ok(Self {
body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
cookie: request.cookie,
origin: auth.origin,
sender_user: auth.sender_user,
sender_device: auth.sender_device,

View File

@@ -5,7 +5,7 @@ use ruma::{
client::uiaa::{AuthData, AuthFlow, AuthType, Jwt, UiaaInfo},
},
};
use tuwunel_core::{Err, Error, Result, err, utils};
use tuwunel_core::{Err, Error, Result, err, is_equal_to, utils};
use tuwunel_service::{Services, uiaa::SESSION_ID_LENGTH};
use crate::{Ruma, client::jwt};
@@ -67,9 +67,19 @@ where
| Some(ref json) => {
let sender_user = body
.sender_user
.as_ref()
.as_deref()
.ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?;
// Skip UIAA for SSO/OIDC users.
if services
.users
.origin(sender_user)
.await
.is_ok_and(is_equal_to!("sso"))
{
return Ok(sender_user.to_owned());
}
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services
.uiaa

View File

@@ -1,6 +1,7 @@
use std::str;
use axum::{RequestExt, RequestPartsExt, extract::Path};
use axum_extra::extract::cookie::CookieJar;
use bytes::Bytes;
use http::request::Parts;
use serde::Deserialize;
@@ -17,6 +18,7 @@ pub(super) type UserId = SmallString<[u8; 48]>;
#[derive(Debug)]
pub(super) struct Request {
pub(super) cookie: CookieJar,
pub(super) path: Path<PathParams>,
pub(super) query: QueryParams,
pub(super) body: Bytes,
@@ -33,6 +35,7 @@ pub(super) async fn from(
let limited = request.with_limited_body();
let (mut parts, body) = limited.into_parts();
let cookie: CookieJar = parts.extract().await?;
let path: Path<PathParams> = parts.extract().await?;
let query = parts.uri.query().unwrap_or_default();
let query = serde_html_form::from_str(query)
@@ -44,5 +47,5 @@ pub(super) async fn from(
.await
.map_err(|e| err!(Request(TooLarge("Request body too large: {e}"))))?;
Ok(Request { path, query, body, parts })
Ok(Request { cookie, path, query, body, parts })
}

View File

@@ -3,7 +3,8 @@ pub mod manager;
pub mod proxy;
use std::{
collections::{BTreeMap, BTreeSet},
collections::{BTreeMap, BTreeSet, HashSet},
hash::{Hash, Hasher},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
path::{Path, PathBuf},
};
@@ -16,7 +17,7 @@ use figment::providers::{Env, Format, Toml};
pub use figment::{Figment, value::Value as FigmentValue};
use regex::RegexSet;
use ruma::{
OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId,
OwnedMxcUri, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId,
api::client::discovery::discover_support::ContactRole,
};
use serde::{Deserialize, de::IgnoredAny};
@@ -28,6 +29,7 @@ pub use self::{check::check, manager::Manager};
use crate::{
Result, err,
error::Error,
utils,
utils::{string::EMPTY, sys},
};
@@ -57,7 +59,7 @@ use crate::{
### https://tuwunel.chat/configuration.html
"#,
ignore = "catchall well_known tls blurhashing allow_invalid_tls_certificates ldap jwt \
appservice"
appservice identity_provider"
)]
pub struct Config {
/// The server_name is the pretty name of this server. It is used as a
@@ -2153,6 +2155,13 @@ pub struct Config {
#[serde(default)]
pub appservice: BTreeMap<String, AppService>,
// external structure; separate sections
#[serde(default)]
pub identity_provider: HashSet<IdentityProvider>,
#[serde(default)]
pub sso_aware_preferred: bool,
#[serde(flatten)]
#[allow(clippy::zero_sized_map_values)]
// this is a catchall, the map shouldn't be zero at runtime
@@ -2468,6 +2477,134 @@ pub struct JwtConfig {
pub validate_signature: bool,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[config_example_generator(
filename = "tuwunel-example.toml",
section = "[global.identity_provider]"
)]
pub struct IdentityProvider {
/// The brand-name of the service (e.g. Apple, Facebook, GitHub, GitLab,
/// Google) or the software (e.g. keycloak, MAS) providing the identity.
/// When a brand is recognized we apply certain defaults to this config
/// for your convenience. For certain brands we apply essential internal
/// workarounds specific to that provider; it is important to configure this
/// field properly when a provider needs to be recognized (like GitHub for
/// example). Several configured providers can share the same brand name. It
/// is not case-sensitive.
#[serde(deserialize_with = "utils::string::de::to_lowercase")]
pub brand: String,
/// The ID of your OAuth application which the provider generates upon
/// registration. This ID then uniquely identifies this configuration
/// instance itself, becoming the identity provider's ID and must be unique
/// and remain unchanged.
pub client_id: String,
/// Secret key the provider generated for you along with the `client_id`
/// above. Unlike the `client_id`, the `client_secret` can be changed here
/// whenever the provider regenerates one for you.
pub client_secret: String,
/// The callback URL configured when registering the OAuth application with
/// the provider. Tuwunel's callback URL must be strictly formatted exactly
/// as instructed. The URL host must point directly at the matrix server and
/// use the following path:
/// `/_matrix/client/unstable/login/sso/callback/<client_id>` where
/// `<client_id>` is the same one configured for this provider above.
pub callback_url: Option<Url>,
/// Optional display-name for this provider instance seen on the login page
/// by users. It defaults to `brand`. When configuring multiple providers
/// using the same `brand` this can be set to distinguish them.
pub name: Option<String>,
/// Optional icon for the provider. The canonical providers have a default
/// icon based on the `brand` supplied above when this is not supplied. Note
/// that it uses an MXC url which is curious in the auth-media era and may
/// not be reliable.
pub icon: Option<OwnedMxcUri>,
/// Optional list of scopes to authorize. An empty array does not impose any
/// restrictions from here, effectively defaulting to all scopes you
/// configured for the OAuth application at the provider. This setting
/// allows for restricting to a subset of those scopes for this instance.
/// Note the user can further restrict scopes during their authorization.
///
/// default: []
#[serde(default)]
pub scope: BTreeSet<String>,
/// List of userinfo claims which shape and restrict the way we compute a
/// Matrix UserId for new registrations. Reviewing Tuwunel's documentation
/// will be necessary for a complete description in detail. An empty array
/// imposes no restriction here, avoiding generated fallbacks as much as
/// possible. For simplicity we reserve a claim called "unique" which can be
/// listed alone to ensure *only* generated ID's are used for registrations.
///
/// default: []
#[serde(default)]
pub userid_claims: BTreeSet<String>,
/// Issuer URL the provider publishes for you. We have pre-supplied default
/// values for some of the canonical providers, making this field optional
/// based on the `brand` set above. Otherwise it is required for OIDC
/// discovery to acquire additional provider configuration, and it must be
/// correct to pass validations during various interactions.
pub issuer_url: Option<Url>,
/// Extra path components after the issuer_url leading to the location of
/// the `.well-known` directory used for discovery. This will be empty for
/// specification-compliant providers. We have supplied any known values
/// based on `brand` (e.g. `/login/oauth` for GitHub).
pub base_path: Option<String>,
/// Overrides the `.well-known` location where the provider's OIDC
/// configuration is found. It is very unlikely you will need to set this;
/// available for developers or special purposes only.
pub discovery_url: Option<Url>,
/// Overrides the authorize URL requested during the grant phase. This is
/// generally discovered or derived automatically, but may be required as a
/// workaround for any non-standard or undiscoverable provider.
pub authorization_url: Option<Url>,
/// Overrides the access token URL; the same caveats apply as with the other
/// URL overrides.
pub token_url: Option<Url>,
/// Overrides the revocation URL; the same caveats apply as with the other
/// URL overrides.
pub revocation_url: Option<Url>,
/// Overrides the introspection URL; the same caveats apply as with the
/// other URL overrides.
pub introspection_url: Option<Url>,
/// Overrides the userinfo URL; the same caveats apply as with the other URL
/// overrides.
pub userinfo_url: Option<Url>,
/// Whether to perform discovery and adjust this provider's configuration
/// accordingly. This defaults to true. When true, it is an error when
/// discovery fails and authorizations will not be attempted to the
/// provider.
#[serde(default = "true_fn")]
pub discovery: bool,
/// The duration in seconds before a grant authorization session expires.
#[serde(default = "default_sso_grant_session_duration")]
pub grant_session_duration: Option<u64>,
}
impl IdentityProvider {
#[must_use]
pub fn id(&self) -> &str { self.client_id.as_str() }
}
impl Hash for IdentityProvider {
fn hash<H: Hasher>(&self, state: &mut H) { self.id().hash(state) }
}
#[derive(Clone, Debug, Default, Deserialize)]
#[config_example_generator(
filename = "tuwunel-example.toml",
@@ -3016,3 +3153,5 @@ fn default_one_time_key_limit() -> usize { 256 }
fn default_max_make_join_attempts_per_join_attempt() -> usize { 48 }
fn default_max_join_attempts_per_join_request() -> usize { 3 }
fn default_sso_grant_session_duration() -> Option<u64> { Some(180) }

View File

@@ -113,6 +113,14 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "mediaid_user",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "oauthid_session",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "oauthuniqid_oauthid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "onetimekeyid_onetimekeys",
..descriptor::RANDOM_SMALL
@@ -403,6 +411,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "userid_masterkeyid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_oauthid",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "userid_origin",
..descriptor::RANDOM

View File

@@ -110,6 +110,7 @@ ruma.workspace = true
rustls.workspace = true
rustyline-async.workspace = true
rustyline-async.optional = true
serde_html_form.workspace = true
serde_json.workspace = true
serde.workspace = true
serde_yaml.workspace = true

View File

@@ -21,6 +21,7 @@ pub struct Service {
pub sender: ClientLazylock,
pub appservice: ClientLazylock,
pub pusher: ClientLazylock,
pub oauth: ClientLazylock,
pub cidr_range_denylist: Vec<IPAddress>,
}
@@ -112,6 +113,11 @@ impl crate::Service for Service {
.pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout))
.redirect(redirect::Policy::limited(2))),
oauth: create_client!(config, services; base(config)?
.dns_resolver2(Arc::clone(&services.resolver.resolver))
.redirect(redirect::Policy::limited(0))
.pool_max_idle_per_host(1)),
cidr_range_denylist: config
.ip_range_denylist
.iter()

View File

@@ -19,6 +19,7 @@ pub mod globals;
pub mod key_backups;
pub mod media;
pub mod membership;
pub mod oauth;
pub mod presence;
pub mod pusher;
pub mod resolver;

265
src/service/oauth/mod.rs Normal file
View File

@@ -0,0 +1,265 @@
pub mod providers;
pub mod sessions;
pub mod user_info;
use std::sync::Arc;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
use reqwest::{Method, header::ACCEPT};
use ruma::UserId;
use serde::Serialize;
use serde_json::Value as JsonValue;
use tuwunel_core::{
Err, Result, err, implement,
utils::{hash::sha256, result::LogErr},
};
use url::Url;
pub use self::{
providers::Provider,
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session},
user_info::UserInfo,
};
use self::{providers::Providers, sessions::Sessions};
use crate::SelfServices;
pub struct Service {
services: SelfServices,
pub providers: Arc<Providers>,
pub sessions: Arc<Sessions>,
}
impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
let providers = Arc::new(Providers::build(args));
let sessions = Arc::new(Sessions::build(args, providers.clone()));
Ok(Arc::new(Self {
services: args.services.clone(),
sessions,
providers,
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip_all, ret)]
pub async fn request_userinfo(
&self,
(provider, session): (&Provider, &Session),
) -> Result<UserInfo> {
let url = provider
.userinfo_url
.clone()
.ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
self.request((Some(provider), Some(session)), Method::GET, url)
.await
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err()
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip_all, ret)]
pub async fn request_tokeninfo(
&self,
(provider, session): (&Provider, &Session),
) -> Result<UserInfo> {
let url = provider
.introspection_url
.clone()
.ok_or_else(|| {
err!(Config("introspection_url", "Missing introspection URL in config"))
})?;
self.request((Some(provider), Some(session)), Method::GET, url)
.await
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err()
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip_all, ret)]
pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
#[derive(Debug, Serialize)]
struct RevokeQuery<'a> {
client_id: &'a str,
client_secret: &'a str,
}
let query = RevokeQuery {
client_id: &provider.client_id,
client_secret: &provider.client_secret,
};
let query = serde_html_form::to_string(&query)?;
let url = provider
.revocation_url
.clone()
.map(|mut url| {
url.set_query(Some(&query));
url
})
.ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
self.request((Some(provider), Some(session)), Method::POST, url)
.await
.log_err()
.map(|_| ())
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip_all, ret)]
pub async fn request_token(
&self,
(provider, session): (&Provider, &Session),
code: &str,
) -> Result<Session> {
#[derive(Debug, Serialize)]
struct TokenQuery<'a> {
client_id: &'a str,
client_secret: &'a str,
grant_type: &'a str,
code: &'a str,
code_verifier: Option<&'a str>,
redirect_uri: Option<&'a str>,
}
let query = TokenQuery {
client_id: &provider.client_id,
client_secret: &provider.client_secret,
grant_type: "authorization_code",
code,
code_verifier: session.code_verifier.as_deref(),
redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
};
let query = serde_html_form::to_string(&query)?;
let url = provider
.token_url
.clone()
.map(|mut url| {
url.set_query(Some(&query));
url
})
.ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
self.request((Some(provider), Some(session)), Method::POST, url)
.await
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err()
}
/// Send a request to a provider; this is somewhat abstract since URL's are
/// formed prior to this call and could point at anything, however this function
/// uses the oauth-specific http client and is configured for JSON with special
/// casing for an `error` property in the response.
#[implement(Service)]
#[tracing::instrument(
name = "request",
level = "debug",
ret(level = "trace"),
skip(self)
)]
pub async fn request(
&self,
(provider, session): (Option<&Provider>, Option<&Session>),
method: Method,
url: Url,
) -> Result<JsonValue> {
let mut request = self
.services
.client
.oauth
.request(method, url)
.header(ACCEPT, "application/json");
if let Some(session) = session {
if let Some(access_token) = session.access_token.clone() {
request = request.bearer_auth(access_token);
}
}
let response: JsonValue = request
.send()
.await?
.error_for_status()?
.json()
.await?;
if let Some(response) = response.as_object().as_ref()
&& let Some(error) = response.get("error").and_then(JsonValue::as_str)
{
let description = response
.get("error_description")
.and_then(JsonValue::as_str)
.unwrap_or("(no description)");
return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
}
Ok(response)
}
#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> {
let session = self.sessions.get_by_user(user_id).await?;
let provider = self.sessions.provider(&session).await?;
Ok((provider, session))
}
#[inline]
pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
}
#[inline]
pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
}
#[inline]
pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
}
pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
let hash = sha256::delimited([iss, sub].iter());
let b64 = b64encode.encode(hash);
Ok(b64)
}
fn unique_id_parts<'a>(
(provider, session): (&'a Provider, &'a Session),
) -> Result<(&'a str, &'a str)> {
provider
.issuer_url
.as_ref()
.map(Url::as_str)
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
.and_then(|iss| unique_id_iss_parts((iss, session)))
}
fn unique_id_sub_parts<'a>(
(provider, sub): (&'a Provider, &'a str),
) -> Result<(&'a str, &'a str)> {
provider
.issuer_url
.as_ref()
.map(Url::as_str)
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
.map(|iss| (iss, sub))
}
fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
session
.user_info
.as_ref()
.map(|user_info| user_info.sub.as_str())
.ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
.map(|sub| (iss, sub))
}

View File

@@ -0,0 +1,246 @@
use std::collections::BTreeMap;
use serde_json::{Map as JsonObject, Value as JsonValue};
use tokio::sync::RwLock;
pub use tuwunel_core::config::IdentityProvider as Provider;
use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
use url::Url;
use crate::SelfServices;
#[derive(Default)]
pub struct Providers {
services: SelfServices,
providers: RwLock<BTreeMap<String, Provider>>,
}
#[implement(Providers)]
pub(super) fn build(args: &crate::Args<'_>) -> Self {
Self {
services: args.services.clone(),
..Default::default()
}
}
/// Get the Provider configuration after any discovery and adjustments
/// made on top of the admin's configuration. This incurs network-based
/// discovery on the first call but responds from cache on subsequent calls.
#[implement(Providers)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn get(&self, id: &str) -> Result<Provider> {
if let Some(provider) = self.providers.read().await.get(id).cloned() {
return Ok(provider);
}
let config = self.get_config(id)?;
let mut map = self.providers.write().await;
let config = self.configure(config).await?;
debug!(?id, ?config);
_ = map.insert(id.into(), config.clone());
Ok(config)
}
/// Get the admin-configured Provider which exists prior to any
/// reconciliation with the well-known discovery (the server's config is
/// immutable); though it is important to note the server config can be
/// reloaded. This will Err NotFound for a non-existent idp.
#[implement(Providers)]
pub fn get_config(&self, id: &str) -> Result<Provider> {
self.services
.config
.identity_provider
.iter()
.find(|config| config.id() == id)
.cloned()
.ok_or_else(|| err!(Request(NotFound("Unrecognized Identity Provider"))))
}
/// Configure an identity provider; takes the admin-configured instance from the
/// server's config, queries the provider for discovery, and then returns an
/// updated config based on the proper reconciliation. This final config is then
/// cached in memory to avoid repeating this process.
#[implement(Providers)]
#[tracing::instrument(
level = INFO_SPAN_LEVEL,
ret(level = "debug"),
skip(self),
)]
async fn configure(&self, mut provider: Provider) -> Result<Provider> {
_ = provider
.name
.get_or_insert_with(|| provider.brand.clone());
if provider.issuer_url.is_none() {
_ = provider
.issuer_url
.replace(match provider.brand.as_str() {
| "github" => "https://github.com".try_into()?,
| "gitlab" => "https://gitlab.com".try_into()?,
| "google" => "https://accounts.google.com".try_into()?,
| _ => return Err!(Config("issuer_url", "Required for this provider.")),
});
}
if provider.base_path.is_none() {
provider.base_path = match provider.brand.as_str() {
| "github" => Some("/login/oauth".to_owned()),
| _ => None,
};
}
let response = self
.discover(&provider)
.await
.and_then(|response| {
response.as_object().cloned().ok_or_else(|| {
err!(Request(NotJson("Expecting JSON object for discovery response")))
})
})
.and_then(|response| check_issuer(response, &provider))?;
if provider.authorization_url.is_none() {
response
.get("authorization_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| make_url(&provider, "/authorize").ok())
.map(|url| provider.authorization_url.replace(url));
}
if provider.revocation_url.is_none() {
response
.get("revocation_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| make_url(&provider, "/revocation").ok())
.map(|url| provider.revocation_url.replace(url));
}
if provider.introspection_url.is_none() {
response
.get("introspection_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| make_url(&provider, "/introspection").ok())
.map(|url| provider.introspection_url.replace(url));
}
if provider.userinfo_url.is_none() {
response
.get("userinfo_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| match provider.brand.as_str() {
| "github" => "https://api.github.com/user".try_into().ok(),
| _ => make_url(&provider, "/userinfo").ok(),
})
.map(|url| provider.userinfo_url.replace(url));
}
if provider.token_url.is_none() {
response
.get("token_endpoint")
.and_then(JsonValue::as_str)
.map(Url::parse)
.transpose()?
.or_else(|| {
let path = if provider.brand == "github" {
"/access_token"
} else {
"/token"
};
make_url(&provider, path).ok()
})
.map(|url| provider.token_url.replace(url));
}
Ok(provider)
}
#[implement(Providers)]
#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
self.services
.client
.oauth
.get(discovery_url(provider)?)
.send()
.await?
.error_for_status()?
.json()
.await
.map_err(Into::into)
}
fn discovery_url(provider: &Provider) -> Result<Url> {
let default_url = provider
.discovery
.then(|| make_url(provider, "/.well-known/openid-configuration"))
.transpose()?;
let Some(url) = provider
.discovery_url
.clone()
.filter(|_| provider.discovery)
.or(default_url)
else {
return Err!(Config(
"discovery_url",
"Failed to determine URL for discovery of provider {}",
provider.id()
));
};
Ok(url)
}
/// Validate that the locally configured `issuer_url` matches the issuer claimed
/// in any response. todo: cryptographic validation is not yet implemented here.
fn check_issuer(
response: JsonObject<String, JsonValue>,
provider: &Provider,
) -> Result<JsonObject<String, JsonValue>> {
let expected = provider
.issuer_url
.as_ref()
.map(Url::as_str)
.map(|url| url.trim_end_matches('/'));
let responded = response
.get("issuer")
.and_then(JsonValue::as_str)
.map(|url| url.trim_end_matches('/'));
if expected != responded {
return Err!(Request(Unauthorized(
"Configured issuer_url {expected:?} does not match discovered {responded:?}",
)));
}
Ok(response)
}
/// Generate a full URL for a request to the idp based on the idp's derived
/// configuration.
fn make_url(provider: &Provider, path: &str) -> Result<Url> {
let mut suffix = provider.base_path.clone().unwrap_or_default();
suffix.push_str(path);
let url = provider
.issuer_url
.as_ref()
.ok_or_else(|| {
let id = &provider.client_id;
err!(Config("issuer_url", "Provider {id:?} required field"))
})?
.join(&suffix)?;
Ok(url)
}

View File

@@ -0,0 +1,234 @@
use std::{sync::Arc, time::SystemTime};
use futures::{Stream, StreamExt, TryFutureExt};
use ruma::{OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect};
use tuwunel_database::{Cbor, Deserialized, Map};
use url::Url;
use super::{Provider, Providers, UserInfo, unique_id};
use crate::SelfServices;
pub struct Sessions {
_services: SelfServices,
providers: Arc<Providers>,
db: Data,
}
struct Data {
oauthid_session: Arc<Map>,
oauthuniqid_oauthid: Arc<Map>,
userid_oauthid: Arc<Map>,
}
/// Session ultimately represents an OAuth authorization session yielding an
/// associated matrix user registration.
///
/// Mixed-use structure capable of deserializing response values, maintaining
/// the state between authorization steps, and maintaining the association to
/// the matrix user until deactivation or revocation.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Session {
/// Identity Provider ID (the `client_id` in the configuration) associated
/// with this session.
pub idp_id: Option<String>,
/// Session ID used as the index key for this session itself.
pub sess_id: Option<String>,
/// Token type (bearer, mac, etc).
pub token_type: Option<String>,
/// Access token to the provider.
pub access_token: Option<String>,
/// Duration in seconds the access_token is valid for.
pub expires_in: Option<u64>,
/// Point in time that the access_token expires.
pub expires_at: Option<SystemTime>,
/// Token used to refresh the access_token.
pub refresh_token: Option<String>,
/// Duration in seconds the refresh_token is valid for
pub refresh_token_expires_in: Option<u64>,
/// Point in time that the refresh_token expires.
pub refresh_token_expires_at: Option<SystemTime>,
/// Access scope actually granted (if supported).
pub scope: Option<String>,
/// Redirect URL
pub redirect_url: Option<Url>,
/// Challenge preimage
pub code_verifier: Option<String>,
/// Random string passed exclusively in the grant session cookie.
pub cookie_nonce: Option<String>,
/// Random single-use string passed in the provider redirect.
pub query_nonce: Option<String>,
/// Point in time the authorization grant session expires.
pub authorize_expires_at: Option<SystemTime>,
/// Associated User Id registration.
pub user_id: Option<OwnedUserId>,
/// Last userinfo response persisted here.
pub user_info: Option<UserInfo>,
}
/// Number of characters generated for our code_verifier. The code_verifier is a
/// random string which must be between 43 and 128 characters.
pub const CODE_VERIFIER_LENGTH: usize = 64;
/// Number of characters we will generate for the Session ID.
pub const SESSION_ID_LENGTH: usize = 32;
#[implement(Sessions)]
pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
Self {
_services: args.services.clone(),
providers,
db: Data {
oauthid_session: args.db["oauthid_session"].clone(),
oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
userid_oauthid: args.db["userid_oauthid"].clone(),
},
}
}
#[implement(Sessions)]
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
self.db
.userid_oauthid
.keys()
.expect_ok()
.map(UserId::to_owned)
}
/// Delete database state for the session.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn delete(&self, sess_id: &str) {
let Ok(session) = self.get(sess_id).await else {
return;
};
// Check the user_id still points to this sess_id before deleting. If not, the
// association was updated to a newer session.
if let Some(user_id) = session.user_id.as_deref() {
if let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await {
if assoc_id == sess_id {
self.db.userid_oauthid.remove(user_id);
}
}
}
// Check the unique identity still points to this sess_id before deleting. If
// not, the association was updated to a newer session.
if let Some(idp_id) = session.idp_id.as_ref() {
if let Ok(provider) = self.providers.get(idp_id).await {
if let Ok(unique_id) = unique_id((&provider, &session)) {
if let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await {
if assoc_id == sess_id {
self.db.oauthuniqid_oauthid.remove(&unique_id);
}
}
}
}
}
self.db.oauthid_session.remove(sess_id);
}
/// Create or overwrite database state for the session.
#[implement(Sessions)]
#[tracing::instrument(level = "info", skip(self))]
pub async fn put(&self, sess_id: &str, session: &Session) {
self.db
.oauthid_session
.raw_put(sess_id, Cbor(session));
if let Some(user_id) = session.user_id.as_deref() {
self.db.userid_oauthid.insert(user_id, sess_id);
}
if let Some(idp_id) = session.idp_id.as_ref() {
if let Ok(provider) = self.providers.get(idp_id).await {
if let Ok(unique_id) = unique_id((&provider, session)) {
self.db
.oauthuniqid_oauthid
.insert(&unique_id, sess_id);
}
}
}
}
/// Fetch database state for a session from its associated `sub`, in case
/// `sess_id` is not known.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
self.get_sess_id_by_unique_id(unique_id)
.and_then(async |sess_id| self.get(&sess_id).await)
.await
}
/// Fetch database state for a session from its associated `user_id`, in case
/// `sess_id` is not known.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_by_user(&self, user_id: &UserId) -> Result<Session> {
self.get_sess_id_by_user(user_id)
.and_then(async |sess_id| self.get(&sess_id).await)
.await
}
/// Fetch database state for a session from its `sess_id`.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get(&self, sess_id: &str) -> Result<Session> {
self.db
.oauthid_session
.get(sess_id)
.await
.deserialized::<Cbor<_>>()
.map(at!(0))
}
/// Resolve the `sess_id` from an associated `user_id`.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result<String> {
self.db
.userid_oauthid
.get(user_id)
.await
.deserialized()
}
/// Resolve the `sess_id` from an associated provider subject id.
#[implement(Sessions)]
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
self.db
.oauthuniqid_oauthid
.get(unique_id)
.await
.deserialized()
}
#[implement(Sessions)]
pub async fn provider(&self, session: &Session) -> Result<Provider> {
let Some(idp_id) = session.idp_id.as_deref() else {
return Err!(Request(NotFound("No provider for this session")));
};
self.providers.get(idp_id).await
}

View File

@@ -0,0 +1,39 @@
use serde::{Deserialize, Serialize};
/// Selection of userinfo response claims.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct UserInfo {
/// Unique identifier number or login username. Usually a number on most
/// services. We consider a concatenation of the `iss` and `sub` to be a
/// universally unique identifier for some user/identity; we index that in
/// `oauthidpsub_oauthid`.
///
/// Considered for user mxid only if none of the better fields are defined.
/// `login` alias intended for github.
#[serde(alias = "login")]
pub sub: String,
/// The login username we first consider when defined.
pub preferred_username: Option<String>,
/// The login username considered if none preferred.
pub nickname: Option<String>,
/// Full name.
pub name: Option<String>,
/// First name.
pub given_name: Option<String>,
/// Last name.
pub family_name: Option<String>,
/// Email address (`email` scope).
pub email: Option<String>,
/// URL to pfp (github/gitlab)
pub avatar_url: Option<String>,
/// URL to pfp (google)
pub picture: Option<String>,
}

View File

@@ -12,7 +12,7 @@ use crate::{
account_data, admin, appservice, client, config, deactivate, emergency, federation, globals,
key_backups,
manager::Manager,
media, membership, presence, pusher, resolver, rooms, sending, server_keys,
media, membership, oauth, presence, pusher, resolver, rooms, sending, server_keys,
service::{Args, Service},
sync, transaction_ids, uiaa, users,
};
@@ -58,6 +58,7 @@ pub struct Services {
pub users: Arc<users::Service>,
pub membership: Arc<membership::Service>,
pub deactivate: Arc<deactivate::Service>,
pub oauth: Arc<oauth::Service>,
manager: Mutex<Option<Arc<Manager>>>,
pub server: Arc<Server>,
@@ -115,6 +116,7 @@ pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
users: users::Service::build(&args)?,
membership: membership::Service::build(&args)?,
deactivate: deactivate::Service::build(&args)?,
oauth: oauth::Service::build(&args)?,
manager: Mutex::new(None),
server,
@@ -173,6 +175,7 @@ pub(crate) fn services(&self) -> impl Iterator<Item = Arc<dyn Service>> + Send {
cast!(self.users),
cast!(self.membership),
cast!(self.deactivate),
cast!(self.oauth),
]
.into_iter()
}

View File

@@ -14,10 +14,10 @@ use ruma::{
events::{GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent},
};
use tuwunel_core::{
Err, Result, debug_warn, err, is_equal_to,
Err, Result, debug_warn, err,
pdu::PduBuilder,
trace,
utils::{self, ReadyExt, stream::TryIgnore},
utils::{self, ReadyExt, result::LogErr, stream::TryIgnore},
warn,
};
use tuwunel_database::{Deserialized, Json, Map};
@@ -132,6 +132,16 @@ impl Service {
/// Deactivate account
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
// Revoke any SSO authorizations
if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await {
self.services
.oauth
.revoke_token((&provider, &session))
.await
.log_err()
.ok();
}
// Remove all associated devices
self.all_device_ids(user_id)
.for_each(|device_id| self.remove_device(user_id, device_id))
@@ -227,18 +237,15 @@ impl Service {
// Cannot change the password of a LDAP user. There are two special cases :
// - a `None` password can be used to deactivate a LDAP user
// - a "*" password is used as the default password of an active LDAP user
if cfg!(feature = "ldap")
&& password.is_some()
//
// The above now applies to all non-password origin users including SSO
// users. Note that users with no origin are also password-origin users.
if password.is_some()
&& password != Some("*")
&& self
.db
.userid_origin
.get(user_id)
.await
.deserialized::<String>()
.is_ok_and(is_equal_to!("ldap"))
&& let Ok(origin) = self.origin(user_id).await
&& origin != "password"
{
return Err!(Request(InvalidParam("Cannot change password of a LDAP user")));
return Err!(Request(InvalidParam("Cannot change password of an {origin:?} user.")));
}
password

View File

@@ -1834,6 +1834,10 @@
#
#one_time_key_limit = 256
# This item is undocumented. Please contribute documentation for it.
#
#sso_aware_preferred = false
#[global.tls]
@@ -2097,6 +2101,131 @@
#[[global.identity_provider]]
# The brand-name of the service (e.g. Apple, Facebook, GitHub, GitLab,
# Google) or the software (e.g. keycloak, MAS) providing the identity.
# When a brand is recognized we apply certain defaults to this config
# for your convenience. For certain brands we apply essential internal
# workarounds specific to that provider; it is important to configure this
# field properly when a provider needs to be recognized (like GitHub for
# example). Several configured providers can share the same brand name. It
# is not case-sensitive.
#
#brand =
# The ID of your OAuth application which the provider generates upon
# registration. This ID then uniquely identifies this configuration
# instance itself, becoming the identity provider's ID and must be unique
# and remain unchanged.
#
#client_id =
# Secret key the provider generated for you along with the `client_id`
# above. Unlike the `client_id`, the `client_secret` can be changed here
# whenever the provider regenerates one for you.
#
#client_secret =
# The callback URL configured when registering the OAuth application with
# the provider. Tuwunel's callback URL must be strictly formatted exactly
# as instructed. The URL host must point directly at the matrix server and
# use the following path:
# `/_matrix/client/unstable/login/sso/callback/<client_id>` where
# `<client_id>` is the same one configured for this provider above.
#
#callback_url =
# Optional display-name for this provider instance seen on the login page
# by users. It defaults to `brand`. When configuring multiple providers
# using the same `brand` this can be set to distinguish them.
#
#name =
# Optional icon for the provider. The canonical providers have a default
# icon based on the `brand` supplied above when this is not supplied. Note
# that it uses an MXC url which is curious in the auth-media era and may
# not be reliable.
#
#icon =
# Optional list of scopes to authorize. An empty array does not impose any
# restrictions from here, effectively defaulting to all scopes you
# configured for the OAuth application at the provider. This setting
# allows for restricting to a subset of those scopes for this instance.
# Note the user can further restrict scopes during their authorization.
#
#scope = []
# List of userinfo claims which shape and restrict the way we compute a
# Matrix UserId for new registrations. Reviewing Tuwunel's documentation
# will be necessary for a complete description in detail. An empty array
# imposes no restriction here, avoiding generated fallbacks as much as
# possible. For simplicity we reserve a claim called "unique" which can be
# listed alone to ensure *only* generated ID's are used for registrations.
#
#userid_claims = []
# Issuer URL the provider publishes for you. We have pre-supplied default
# values for some of the canonical providers, making this field optional
# based on the `brand` set above. Otherwise it is required for OIDC
# discovery to acquire additional provider configuration, and it must be
# correct to pass validations during various interactions.
#
#issuer_url =
# Extra path components after the issuer_url leading to the location of
# the `.well-known` directory used for discovery. This will be empty for
# specification-compliant providers. We have supplied any known values
# based on `brand` (e.g. `/login/oauth` for GitHub).
#
#base_path =
# Overrides the `.well-known` location where the provider's OIDC
# configuration is found. It is very unlikely you will need to set this;
# available for developers or special purposes only.
#
#discovery_url =
# Overrides the authorize URL requested during the grant phase. This is
# generally discovered or derived automatically, but may be required as a
# workaround for any non-standard or undiscoverable provider.
#
#authorization_url =
# Overrides the access token URL; the same caveats apply as with the other
# URL overrides.
#
#token_url =
# Overrides the revocation URL; the same caveats apply as with the other
# URL overrides.
#
#revocation_url =
# Overrides the introspection URL; the same caveats apply as with the
# other URL overrides.
#
#introspection_url =
# Overrides the userinfo URL; the same caveats apply as with the other URL
# overrides.
#
#userinfo_url =
# Whether to perform discovery and adjust this provider's configuration
# accordingly. This defaults to true. When true, it is an error when
# discovery fails and authorizations will not be attempted to the
# provider.
#
#discovery = true
# The duration in seconds before a grant authorization session expires.
#
#grant_session_duration =
#[global.appservice.<ID>]
# The URL for the application service.