Implement SSO/OIDC support. (closes #7)
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
27
Cargo.lock
generated
27
Cargo.lock
generated
@@ -383,6 +383,7 @@ dependencies = [
|
||||
"axum",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"cookie",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http",
|
||||
@@ -836,6 +837,17 @@ dependencies = [
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cookie"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"time",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coolor"
|
||||
version = "1.1.0"
|
||||
@@ -1291,6 +1303,15 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
|
||||
|
||||
[[package]]
|
||||
name = "encoding_rs"
|
||||
version = "0.8.35"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "enum-as-inner"
|
||||
version = "0.6.1"
|
||||
@@ -3582,6 +3603,7 @@ checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
@@ -3595,6 +3617,7 @@ dependencies = [
|
||||
"hyper-util",
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
@@ -5141,6 +5164,7 @@ dependencies = [
|
||||
"tracing",
|
||||
"tuwunel_core",
|
||||
"tuwunel_service",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5185,6 +5209,8 @@ dependencies = [
|
||||
"ruma",
|
||||
"sanitize-filename",
|
||||
"serde",
|
||||
"serde_core",
|
||||
"serde_html_form",
|
||||
"serde_json",
|
||||
"serde_regex",
|
||||
"serde_yaml",
|
||||
@@ -5295,6 +5321,7 @@ dependencies = [
|
||||
"rustls",
|
||||
"rustyline-async",
|
||||
"serde",
|
||||
"serde_html_form",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"sha2",
|
||||
|
||||
@@ -76,6 +76,7 @@ version = "0.7"
|
||||
version = "0.10"
|
||||
default-features = false
|
||||
features = [
|
||||
"cookie",
|
||||
"typed-header",
|
||||
"tracing",
|
||||
]
|
||||
@@ -312,10 +313,12 @@ version = "1.12"
|
||||
version = "0.12"
|
||||
default-features = false
|
||||
features = [
|
||||
"rustls-tls-native-roots",
|
||||
"socks",
|
||||
"charset",
|
||||
"hickory-dns",
|
||||
"http2",
|
||||
"json",
|
||||
"rustls-tls-native-roots",
|
||||
"socks",
|
||||
]
|
||||
|
||||
[workspace.dependencies.ring]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod account_data;
|
||||
mod appservice;
|
||||
mod globals;
|
||||
mod oauth;
|
||||
mod presence;
|
||||
mod pusher;
|
||||
mod raw;
|
||||
@@ -18,10 +19,10 @@ use tuwunel_core::Result;
|
||||
|
||||
use self::{
|
||||
account_data::AccountDataCommand, appservice::AppserviceCommand, globals::GlobalsCommand,
|
||||
presence::PresenceCommand, pusher::PusherCommand, raw::RawCommand, resolver::ResolverCommand,
|
||||
room_alias::RoomAliasCommand, room_state_cache::RoomStateCacheCommand,
|
||||
room_timeline::RoomTimelineCommand, sending::SendingCommand, short::ShortCommand,
|
||||
sync::SyncCommand, users::UsersCommand,
|
||||
oauth::OauthCommand, presence::PresenceCommand, pusher::PusherCommand, raw::RawCommand,
|
||||
resolver::ResolverCommand, room_alias::RoomAliasCommand,
|
||||
room_state_cache::RoomStateCacheCommand, room_timeline::RoomTimelineCommand,
|
||||
sending::SendingCommand, short::ShortCommand, sync::SyncCommand, users::UsersCommand,
|
||||
};
|
||||
use crate::admin_command_dispatch;
|
||||
|
||||
@@ -81,6 +82,10 @@ pub(super) enum QueryCommand {
|
||||
#[command(subcommand)]
|
||||
Sync(SyncCommand),
|
||||
|
||||
/// - oauth service
|
||||
#[command(subcommand)]
|
||||
Oauth(OauthCommand),
|
||||
|
||||
/// - raw service
|
||||
#[command(subcommand)]
|
||||
Raw(RawCommand),
|
||||
|
||||
172
src/admin/query/oauth.rs
Normal file
172
src/admin/query/oauth.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
use clap::Subcommand;
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use ruma::OwnedUserId;
|
||||
use tuwunel_core::{Err, Result, utils::stream::IterStream};
|
||||
use tuwunel_service::{
|
||||
Services,
|
||||
oauth::{Provider, Session},
|
||||
};
|
||||
|
||||
use crate::{admin_command, admin_command_dispatch};
|
||||
|
||||
#[admin_command_dispatch(handler_prefix = "oauth")]
|
||||
#[derive(Debug, Subcommand)]
|
||||
/// Query OAuth service state
|
||||
pub(crate) enum OauthCommand {
|
||||
/// List configured OAuth providers.
|
||||
ListProviders,
|
||||
|
||||
/// List users associated with an OAuth provider
|
||||
ListUsers,
|
||||
|
||||
/// Show active configuration of a provider.
|
||||
ShowProvider {
|
||||
id: String,
|
||||
|
||||
#[arg(long)]
|
||||
config: bool,
|
||||
},
|
||||
|
||||
/// Show session state
|
||||
ShowSession {
|
||||
id: String,
|
||||
},
|
||||
|
||||
/// Token introspection request to provider.
|
||||
TokenInfo {
|
||||
id: String,
|
||||
},
|
||||
|
||||
/// Revoke token for user_id or sess_id.
|
||||
Revoke {
|
||||
id: String,
|
||||
},
|
||||
|
||||
/// Remove oauth state (DANGER!)
|
||||
Remove {
|
||||
id: String,
|
||||
|
||||
#[arg(long)]
|
||||
force: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[admin_command]
|
||||
pub(super) async fn oauth_list_providers(&self) -> Result {
|
||||
self.services
|
||||
.config
|
||||
.identity_provider
|
||||
.iter()
|
||||
.try_stream()
|
||||
.map_ok(Provider::id)
|
||||
.map_ok(|id| format!("{id}\n"))
|
||||
.try_for_each(async |id| self.write_str(&id).await)
|
||||
.await
|
||||
}
|
||||
|
||||
#[admin_command]
|
||||
pub(super) async fn oauth_list_users(&self) -> Result {
|
||||
self.services
|
||||
.oauth
|
||||
.sessions
|
||||
.users()
|
||||
.map(|id| format!("{id}\n"))
|
||||
.map(Ok)
|
||||
.try_for_each(async |id: String| self.write_str(&id).await)
|
||||
.await
|
||||
}
|
||||
|
||||
#[admin_command]
|
||||
pub(super) async fn oauth_show_provider(&self, id: String, config: bool) -> Result {
|
||||
if config {
|
||||
let config = self.services.oauth.providers.get_config(&id)?;
|
||||
|
||||
self.write_str(&format!("{config:#?}\n")).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let provider = self.services.oauth.providers.get(&id).await?;
|
||||
|
||||
self.write_str(&format!("{provider:#?}\n")).await
|
||||
}
|
||||
|
||||
#[admin_command]
|
||||
pub(super) async fn oauth_show_session(&self, id: String) -> Result {
|
||||
let session = find_session(self.services, &id).await?;
|
||||
|
||||
self.write_str(&format!("{session:#?}\n")).await
|
||||
}
|
||||
|
||||
#[admin_command]
|
||||
pub(super) async fn oauth_token_info(&self, id: String) -> Result {
|
||||
let session = find_session(self.services, &id).await?;
|
||||
|
||||
let provider = self
|
||||
.services
|
||||
.oauth
|
||||
.sessions
|
||||
.provider(&session)
|
||||
.await?;
|
||||
|
||||
let tokeninfo = self
|
||||
.services
|
||||
.oauth
|
||||
.request_tokeninfo((&provider, &session))
|
||||
.await;
|
||||
|
||||
self.write_str(&format!("{tokeninfo:#?}\n")).await
|
||||
}
|
||||
|
||||
#[admin_command]
|
||||
pub(super) async fn oauth_revoke(&self, id: String) -> Result {
|
||||
let session = find_session(self.services, &id).await?;
|
||||
|
||||
let provider = self
|
||||
.services
|
||||
.oauth
|
||||
.sessions
|
||||
.provider(&session)
|
||||
.await?;
|
||||
|
||||
self.services
|
||||
.oauth
|
||||
.revoke_token((&provider, &session))
|
||||
.await?;
|
||||
|
||||
self.write_str("done").await
|
||||
}
|
||||
|
||||
#[admin_command]
|
||||
pub(super) async fn oauth_remove(&self, id: String, force: bool) -> Result {
|
||||
let session = find_session(self.services, &id).await?;
|
||||
|
||||
let Some(sess_id) = session.sess_id else {
|
||||
return Err!("Missing sess_id in oauth Session state");
|
||||
};
|
||||
|
||||
if !force {
|
||||
return Err!(
|
||||
"Deleting these records can cause registration conflicts. Use --force to be sure."
|
||||
);
|
||||
}
|
||||
|
||||
self.services
|
||||
.oauth
|
||||
.sessions
|
||||
.delete(&sess_id)
|
||||
.await;
|
||||
|
||||
self.write_str("done").await
|
||||
}
|
||||
|
||||
async fn find_session(services: &Services, id: &str) -> Result<Session> {
|
||||
if let Ok(user_id) = OwnedUserId::parse(id) {
|
||||
services
|
||||
.oauth
|
||||
.sessions
|
||||
.get_by_user(&user_id)
|
||||
.await
|
||||
} else {
|
||||
services.oauth.sessions.get(id).await
|
||||
}
|
||||
}
|
||||
@@ -101,6 +101,7 @@ tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
tuwunel-core.workspace = true
|
||||
tuwunel-service.workspace = true
|
||||
url.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -4,6 +4,7 @@ mod ldap;
|
||||
mod logout;
|
||||
mod password;
|
||||
mod refresh;
|
||||
mod sso;
|
||||
mod token;
|
||||
|
||||
use axum::extract::State;
|
||||
@@ -12,8 +13,8 @@ use ruma::api::client::session::{
|
||||
get_login_types::{
|
||||
self,
|
||||
v3::{
|
||||
ApplicationServiceLoginType, JwtLoginType, LoginType, PasswordLoginType,
|
||||
TokenLoginType,
|
||||
ApplicationServiceLoginType, IdentityProvider, JwtLoginType, LoginType,
|
||||
PasswordLoginType, SsoLoginType, TokenLoginType,
|
||||
},
|
||||
},
|
||||
login::{
|
||||
@@ -28,6 +29,7 @@ use self::{ldap::ldap_login, password::password_login};
|
||||
pub(crate) use self::{
|
||||
logout::{logout_all_route, logout_route},
|
||||
refresh::refresh_token_route,
|
||||
sso::{sso_callback_route, sso_login_route, sso_login_with_provider_route},
|
||||
token::login_token_route,
|
||||
};
|
||||
use super::TOKEN_LENGTH;
|
||||
@@ -50,6 +52,20 @@ pub(crate) async fn get_login_types_route(
|
||||
LoginType::Token(TokenLoginType {
|
||||
get_login_token: services.config.login_via_existing_session,
|
||||
}),
|
||||
LoginType::Sso(SsoLoginType {
|
||||
identity_providers: services
|
||||
.config
|
||||
.identity_provider
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|config| IdentityProvider {
|
||||
id: config.id().to_owned(),
|
||||
brand: Some(config.brand.clone().into()),
|
||||
icon: config.icon,
|
||||
name: config.name.unwrap_or(config.brand),
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
]))
|
||||
}
|
||||
|
||||
@@ -89,6 +105,10 @@ pub(crate) async fn login_route(
|
||||
},
|
||||
};
|
||||
|
||||
if !services.users.is_active_local(&user_id).await {
|
||||
return Err!(Request(UserDeactivated("This user has been deactivated.")));
|
||||
}
|
||||
|
||||
// Generate a new token for the device
|
||||
let (access_token, expires_in) = services
|
||||
.users
|
||||
|
||||
601
src/api/client/session/sso.rs
Normal file
601
src/api/client/session/sso.rs
Normal file
@@ -0,0 +1,601 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
|
||||
use futures::{StreamExt, TryFutureExt, future::try_join};
|
||||
use itertools::Itertools;
|
||||
use reqwest::header::{CONTENT_TYPE, HeaderValue};
|
||||
use ruma::{
|
||||
Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId,
|
||||
api::client::session::{sso_callback, sso_login, sso_login_with_provider},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{
|
||||
Err, Result, at,
|
||||
debug::INFO_SPAN_LEVEL,
|
||||
debug_info, debug_warn, err, info, utils,
|
||||
utils::{
|
||||
content_disposition::make_content_disposition,
|
||||
hash::sha256,
|
||||
result::{FlatOk, LogErr},
|
||||
string::{EMPTY, truncate_deterministic},
|
||||
timepoint_from_now, timepoint_has_passed,
|
||||
},
|
||||
warn,
|
||||
};
|
||||
use tuwunel_service::{
|
||||
Services,
|
||||
media::MXC_LENGTH,
|
||||
oauth::{
|
||||
CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id,
|
||||
unique_id_sub,
|
||||
},
|
||||
users::Register,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use super::TOKEN_LENGTH;
|
||||
use crate::Ruma;
|
||||
|
||||
/// Grant phase query string.
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GrantQuery<'a> {
|
||||
client_id: &'a str,
|
||||
state: &'a str,
|
||||
nonce: &'a str,
|
||||
scope: &'a str,
|
||||
response_type: &'a str,
|
||||
access_type: &'a str,
|
||||
code_challenge_method: &'a str,
|
||||
code_challenge: &'a str,
|
||||
redirect_uri: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct GrantCookie<'a> {
|
||||
client_id: &'a str,
|
||||
state: &'a str,
|
||||
nonce: &'a str,
|
||||
redirect_uri: &'a str,
|
||||
}
|
||||
|
||||
static GRANT_SESSION_COOKIE: &str = "tuwunel_grant_session";
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "sso_login",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(%client),
|
||||
)]
|
||||
pub(crate) async fn sso_login_route(
|
||||
State(_services): State<crate::State>,
|
||||
InsecureClientIp(client): InsecureClientIp,
|
||||
_body: Ruma<sso_login::v3::Request>,
|
||||
) -> Result<sso_login::v3::Response> {
|
||||
Err!(Request(NotImplemented(
|
||||
"SSO login without specific provider has not been implemented."
|
||||
)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "sso_login_with_provider",
|
||||
level = "info",
|
||||
skip_all,
|
||||
ret(level = "debug")
|
||||
fields(
|
||||
%client,
|
||||
idp_id = body.body.idp_id,
|
||||
),
|
||||
)]
|
||||
pub(crate) async fn sso_login_with_provider_route(
|
||||
State(services): State<crate::State>,
|
||||
InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<sso_login_with_provider::v3::Request>,
|
||||
) -> Result<sso_login_with_provider::v3::Response> {
|
||||
let sso_login_with_provider::v3::Request { idp_id, redirect_url } = body.body;
|
||||
let Ok(redirect_url) = redirect_url.parse::<Url>() else {
|
||||
return Err!(Request(InvalidParam("Invalid redirect_url")));
|
||||
};
|
||||
|
||||
let provider = services.oauth.providers.get(&idp_id).await?;
|
||||
let sess_id = utils::random_string(SESSION_ID_LENGTH);
|
||||
let query_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||
let cookie_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||
let code_verifier = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||
let code_challenge = b64.encode(sha256::hash(code_verifier.as_bytes()));
|
||||
let callback_uri = provider.callback_url.as_ref().map(Url::as_str);
|
||||
let scope = provider.scope.iter().join(" ");
|
||||
|
||||
let query = GrantQuery {
|
||||
client_id: &provider.client_id,
|
||||
state: &sess_id,
|
||||
nonce: &query_nonce,
|
||||
access_type: "online",
|
||||
response_type: "code",
|
||||
code_challenge_method: "S256",
|
||||
code_challenge: &code_challenge,
|
||||
redirect_uri: callback_uri,
|
||||
scope: scope
|
||||
.is_empty()
|
||||
.then_some("openid email profile")
|
||||
.unwrap_or(scope.as_str()),
|
||||
};
|
||||
|
||||
let location = provider
|
||||
.authorization_url
|
||||
.clone()
|
||||
.map(|mut location| {
|
||||
let query = serde_html_form::to_string(&query).ok();
|
||||
location.set_query(query.as_deref());
|
||||
location
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
err!(Config("authorization_url", "Missing required IdentityProvider config"))
|
||||
})?;
|
||||
|
||||
let cookie_val = GrantCookie {
|
||||
client_id: query.client_id,
|
||||
state: query.state,
|
||||
nonce: &cookie_nonce,
|
||||
redirect_uri: redirect_url.as_str(),
|
||||
};
|
||||
|
||||
let cookie_path = provider
|
||||
.callback_url
|
||||
.as_ref()
|
||||
.map(Url::path)
|
||||
.unwrap_or("/");
|
||||
|
||||
let cookie_max_age = provider
|
||||
.grant_session_duration
|
||||
.map(Duration::from_secs)
|
||||
.expect("Defaulted to Some value during configure_idp()")
|
||||
.try_into()
|
||||
.expect("std::time::Duration to time::Duration conversion failure");
|
||||
|
||||
let cookie = Cookie::build((GRANT_SESSION_COOKIE, serde_html_form::to_string(&cookie_val)?))
|
||||
.path(cookie_path)
|
||||
.max_age(cookie_max_age)
|
||||
.same_site(SameSite::None)
|
||||
.secure(true)
|
||||
.http_only(true)
|
||||
.build()
|
||||
.to_string()
|
||||
.into();
|
||||
|
||||
let session = Session {
|
||||
idp_id: Some(idp_id),
|
||||
sess_id: Some(sess_id.clone()),
|
||||
redirect_url: Some(redirect_url),
|
||||
code_verifier: Some(code_verifier),
|
||||
query_nonce: Some(query_nonce),
|
||||
cookie_nonce: Some(cookie_nonce),
|
||||
authorize_expires_at: provider
|
||||
.grant_session_duration
|
||||
.map(Duration::from_secs)
|
||||
.map(timepoint_from_now)
|
||||
.transpose()?,
|
||||
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
services
|
||||
.oauth
|
||||
.sessions
|
||||
.put(&sess_id, &session)
|
||||
.await;
|
||||
|
||||
Ok(sso_login_with_provider::v3::Response {
|
||||
location: location.into(),
|
||||
cookie: Some(cookie),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "sso_callback"
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
%client,
|
||||
cookie = ?body.cookie,
|
||||
body = ?body.body,
|
||||
),
|
||||
)]
|
||||
pub(crate) async fn sso_callback_route(
|
||||
State(services): State<crate::State>,
|
||||
InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<sso_callback::unstable::Request>,
|
||||
) -> Result<sso_callback::unstable::Response> {
|
||||
let sess_id = body
|
||||
.body
|
||||
.state
|
||||
.as_deref()
|
||||
.ok_or_else(|| err!(Request(Forbidden("Missing sess_id in callback."))))?;
|
||||
|
||||
let code = body
|
||||
.body
|
||||
.code
|
||||
.as_deref()
|
||||
.ok_or_else(|| err!(Request(Forbidden("Missing code in callback."))))?;
|
||||
|
||||
let session = services
|
||||
.oauth
|
||||
.sessions
|
||||
.get(sess_id)
|
||||
.map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));
|
||||
|
||||
let idp_id = body.body.idp_id.as_str();
|
||||
let provider = services.oauth.providers.get(idp_id);
|
||||
let (provider, session) = try_join(provider, session).await.log_err()?;
|
||||
let client_id = &provider.client_id;
|
||||
|
||||
if session.idp_id.as_deref() != Some(idp_id) {
|
||||
return Err!(Request(Unauthorized("Identity Provider {idp_id} session not recognized.")));
|
||||
}
|
||||
|
||||
if session
|
||||
.authorize_expires_at
|
||||
.map(timepoint_has_passed)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Err!(Request(Unauthorized("Authorization grant session has expired.")));
|
||||
}
|
||||
|
||||
let cookie = body
|
||||
.cookie
|
||||
.get(GRANT_SESSION_COOKIE)
|
||||
.map(Cookie::value)
|
||||
.map(serde_html_form::from_str::<GrantCookie<'_>>)
|
||||
.transpose()?
|
||||
.ok_or_else(|| err!(Request(Unauthorized("Missing cookie {GRANT_SESSION_COOKIE:?}"))))?;
|
||||
|
||||
if cookie.client_id != client_id {
|
||||
return Err!(Request(Unauthorized("Client ID {client_id:?} cookie mismatch.")));
|
||||
}
|
||||
|
||||
if Some(cookie.nonce) != session.cookie_nonce.as_deref() {
|
||||
return Err!(Request(Unauthorized("Cookie nonce does not match session state.")));
|
||||
}
|
||||
|
||||
if cookie.state != sess_id {
|
||||
return Err!(Request(Unauthorized("Session ID {sess_id:?} cookie mismatch.")));
|
||||
}
|
||||
|
||||
// Request access token.
|
||||
let token_response = services
|
||||
.oauth
|
||||
.request_token((&provider, &session), code)
|
||||
.await?;
|
||||
|
||||
let token_expires_at = token_response
|
||||
.expires_in
|
||||
.map(Duration::from_secs)
|
||||
.map(timepoint_from_now)
|
||||
.transpose()?;
|
||||
|
||||
let refresh_token_expires_at = token_response
|
||||
.refresh_token_expires_in
|
||||
.map(Duration::from_secs)
|
||||
.map(timepoint_from_now)
|
||||
.transpose()?;
|
||||
|
||||
// Update the session with access token results
|
||||
let session = Session {
|
||||
scope: token_response.scope,
|
||||
token_type: token_response.token_type,
|
||||
access_token: token_response.access_token,
|
||||
expires_at: token_expires_at,
|
||||
refresh_token: token_response.refresh_token,
|
||||
refresh_token_expires_at,
|
||||
..session
|
||||
};
|
||||
|
||||
// Request userinfo claims.
|
||||
let userinfo = services
|
||||
.oauth
|
||||
.request_userinfo((&provider, &session))
|
||||
.await?;
|
||||
|
||||
// Check for an existing session from this identity. We want to maintain one
|
||||
// session for each identity and keep the newer one which has up-to-date state
|
||||
// and access.
|
||||
let (user_id, old_sess_id) = match services
|
||||
.oauth
|
||||
.sessions
|
||||
.get_by_unique_id(&unique_id_sub((&provider, &userinfo.sub))?)
|
||||
.await
|
||||
{
|
||||
| Ok(session) => (session.user_id, session.sess_id),
|
||||
| Err(error) if !error.is_not_found() => return Err(error),
|
||||
| Err(_) => (None, None),
|
||||
};
|
||||
|
||||
// Update the session with userinfo
|
||||
let session = Session {
|
||||
user_info: Some(userinfo.clone()),
|
||||
..session
|
||||
};
|
||||
|
||||
// Keep the user_id from the old session as best as possible.
|
||||
let user_id = match user_id {
|
||||
| Some(user_id) => user_id,
|
||||
| None => decide_user_id(&services, &provider, &session, &userinfo).await?,
|
||||
};
|
||||
|
||||
// Update the session with user_id
|
||||
let session = Session {
|
||||
user_id: Some(user_id.clone()),
|
||||
..session
|
||||
};
|
||||
|
||||
// Attempt to register a non-existing user.
|
||||
if !services.users.exists(&user_id).await {
|
||||
register_user(&services, &provider, &session, &userinfo, &user_id).await?;
|
||||
}
|
||||
|
||||
// Commit the updated session.
|
||||
services
|
||||
.oauth
|
||||
.sessions
|
||||
.put(sess_id, &session)
|
||||
.await;
|
||||
|
||||
// Delete any old session.
|
||||
if let Some(old_sess_id) = old_sess_id
|
||||
&& sess_id != old_sess_id
|
||||
{
|
||||
services.oauth.sessions.delete(&old_sess_id).await;
|
||||
}
|
||||
|
||||
if !services.users.is_active_local(&user_id).await {
|
||||
return Err!(Request(UserDeactivated("This user has been deactivated.")));
|
||||
}
|
||||
|
||||
// Allow the user to login to Matrix.
|
||||
let login_token = utils::random_string(TOKEN_LENGTH);
|
||||
let _login_token_expires_in = services
|
||||
.users
|
||||
.create_login_token(&user_id, &login_token);
|
||||
|
||||
let location = session
|
||||
.redirect_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
|
||||
.clone()
|
||||
.query_pairs_mut()
|
||||
.append_pair("loginToken", &login_token)
|
||||
.finish()
|
||||
.to_string();
|
||||
|
||||
let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
|
||||
.removal()
|
||||
.build()
|
||||
.to_string()
|
||||
.into();
|
||||
|
||||
Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "register",
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
fields(user_id, userinfo)
|
||||
)]
|
||||
async fn register_user(
|
||||
services: &Services,
|
||||
provider: &Provider,
|
||||
session: &Session,
|
||||
userinfo: &UserInfo,
|
||||
user_id: &UserId,
|
||||
) -> Result {
|
||||
debug_info!(%user_id, "Creating new user account...");
|
||||
|
||||
services
|
||||
.users
|
||||
.full_register(Register {
|
||||
user_id: Some(user_id),
|
||||
password: Some("*"),
|
||||
origin: Some("sso"),
|
||||
displayname: userinfo.name.as_deref(),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
|
||||
if let Some(avatar_url) = userinfo
|
||||
.avatar_url
|
||||
.as_deref()
|
||||
.or(userinfo.picture.as_deref())
|
||||
{
|
||||
set_avatar(services, provider, session, userinfo, user_id, avatar_url)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
|
||||
let idp_id = provider.id();
|
||||
let idp_name = provider
|
||||
.name
|
||||
.as_deref()
|
||||
.unwrap_or(provider.brand.as_str());
|
||||
|
||||
// log in conduit admin channel if a non-guest user registered
|
||||
let notice =
|
||||
format!("New user \"{user_id}\" registered on this server via {idp_name} ({idp_id})",);
|
||||
|
||||
info!("{notice}");
|
||||
if services.server.config.admin_room_notices {
|
||||
services.admin.notice(¬ice).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all, fields(user_id, avatar_url))]
|
||||
async fn set_avatar(
|
||||
services: &Services,
|
||||
_provider: &Provider,
|
||||
_session: &Session,
|
||||
_userinfo: &UserInfo,
|
||||
user_id: &UserId,
|
||||
avatar_url: &str,
|
||||
) -> Result {
|
||||
use reqwest::Response;
|
||||
|
||||
let response = services
|
||||
.client
|
||||
.default
|
||||
.get(avatar_url)
|
||||
.send()
|
||||
.await
|
||||
.and_then(Response::error_for_status)?;
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.map(HeaderValue::to_str)
|
||||
.flat_ok()
|
||||
.map(ToOwned::to_owned);
|
||||
|
||||
let mxc = Mxc {
|
||||
server_name: services.globals.server_name(),
|
||||
media_id: &utils::random_string(MXC_LENGTH),
|
||||
};
|
||||
|
||||
let content_disposition = make_content_disposition(None, content_type.as_deref(), None);
|
||||
let bytes = response.bytes().await?;
|
||||
services
|
||||
.media
|
||||
.create(&mxc, Some(user_id), Some(&content_disposition), content_type.as_deref(), &bytes)
|
||||
.await?;
|
||||
|
||||
let all_joined_rooms: Vec<OwnedRoomId> = services
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
.map(ToOwned::to_owned)
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let mxc_uri = mxc.to_string().into();
|
||||
services
|
||||
.users
|
||||
.update_avatar_url(user_id, Some(mxc_uri), None, &all_joined_rooms)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
level = "debug",
|
||||
ret(level = "debug")
|
||||
skip_all,
|
||||
fields(user),
|
||||
)]
|
||||
async fn decide_user_id(
|
||||
services: &Services,
|
||||
provider: &Provider,
|
||||
session: &Session,
|
||||
userinfo: &UserInfo,
|
||||
) -> Result<OwnedUserId> {
|
||||
let allowed =
|
||||
|claim: &str| provider.userid_claims.is_empty() || provider.userid_claims.contains(claim);
|
||||
|
||||
let choices = [
|
||||
userinfo
|
||||
.preferred_username
|
||||
.as_deref()
|
||||
.map(str::to_lowercase)
|
||||
.filter(|_| allowed("preferred_username")),
|
||||
userinfo
|
||||
.nickname
|
||||
.as_deref()
|
||||
.map(str::to_lowercase)
|
||||
.filter(|_| allowed("nickname")),
|
||||
provider
|
||||
.brand
|
||||
.eq(&"github")
|
||||
.then_some(userinfo.sub.as_str())
|
||||
.map(str::to_lowercase)
|
||||
.filter(|_| allowed("login")),
|
||||
userinfo
|
||||
.email
|
||||
.as_deref()
|
||||
.and_then(|email| email.split_once('@'))
|
||||
.map(at!(0))
|
||||
.map(str::to_lowercase)
|
||||
.filter(|_| allowed("email")),
|
||||
];
|
||||
|
||||
for choice in choices.into_iter().flatten() {
|
||||
if let Some(user_id) = try_user_id(services, &choice, false).await {
|
||||
return Ok(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(infallible) = unique_id((provider, session))
|
||||
.map(|h| truncate_deterministic(&h, Some(15..23)).to_lowercase())
|
||||
.log_err()
|
||||
{
|
||||
if let Some(user_id) = try_user_id(services, &infallible, true).await {
|
||||
return Ok(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
Err!(Request(UserInUse("User ID is not available.")))
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all, fields(username))]
|
||||
async fn try_user_id(
|
||||
services: &Services,
|
||||
username: &str,
|
||||
may_exist: bool,
|
||||
) -> Option<OwnedUserId> {
|
||||
let server_name = services.globals.server_name();
|
||||
let user_id = parse_user_id(server_name, username)
|
||||
.inspect_err(|e| warn!(?username, "Username invalid: {e}"))
|
||||
.ok()?;
|
||||
|
||||
if services
|
||||
.config
|
||||
.forbidden_usernames
|
||||
.is_match(username)
|
||||
{
|
||||
warn!(?username, "Username forbidden.");
|
||||
return None;
|
||||
}
|
||||
|
||||
if services.users.exists(&user_id).await {
|
||||
debug_warn!(?username, "Username exists.");
|
||||
|
||||
if services
|
||||
.users
|
||||
.origin(&user_id)
|
||||
.await
|
||||
.ok()
|
||||
.is_none_or(|origin| origin != "sso")
|
||||
{
|
||||
debug_warn!(?username, "Username has non-sso origin.");
|
||||
return None;
|
||||
}
|
||||
|
||||
if !may_exist {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(user_id)
|
||||
}
|
||||
|
||||
fn parse_user_id(server_name: &ServerName, username: &str) -> Result<OwnedUserId> {
|
||||
match UserId::parse_with_server_name(username, server_name) {
|
||||
| Err(e) =>
|
||||
Err!(Request(InvalidUsername(debug_error!("Username {username} is not valid: {e}")))),
|
||||
| Ok(user_id) => match user_id.validate_strict() {
|
||||
| Ok(()) => Ok(user_id),
|
||||
| Err(e) => Err!(Request(InvalidUsername(debug_error!(
|
||||
"Username {username} contains disallowed characters or spaces: {e}"
|
||||
)))),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,9 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
|
||||
.ruma_route(&client::login_route)
|
||||
.ruma_route(&client::login_token_route)
|
||||
.ruma_route(&client::refresh_token_route)
|
||||
.ruma_route(&client::sso_login_route)
|
||||
.ruma_route(&client::sso_login_with_provider_route)
|
||||
.ruma_route(&client::sso_callback_route)
|
||||
.ruma_route(&client::whoami_route)
|
||||
.ruma_route(&client::logout_route)
|
||||
.ruma_route(&client::logout_all_route)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::{fmt::Debug, mem, ops::Deref};
|
||||
|
||||
use axum::{body::Body, extract::FromRequest};
|
||||
use axum_extra::extract::cookie::CookieJar;
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use ruma::{
|
||||
CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
|
||||
@@ -18,6 +19,9 @@ pub(crate) struct Args<T> {
|
||||
/// Request struct body
|
||||
pub(crate) body: T,
|
||||
|
||||
/// Cookies received from the useragent.
|
||||
pub(crate) cookie: CookieJar,
|
||||
|
||||
/// Federation server authentication: X-Matrix origin
|
||||
/// None when not a federation server.
|
||||
pub(crate) origin: Option<OwnedServerName>,
|
||||
@@ -110,6 +114,7 @@ where
|
||||
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
|
||||
Ok(Self {
|
||||
body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
|
||||
cookie: request.cookie,
|
||||
origin: auth.origin,
|
||||
sender_user: auth.sender_user,
|
||||
sender_device: auth.sender_device,
|
||||
|
||||
@@ -5,7 +5,7 @@ use ruma::{
|
||||
client::uiaa::{AuthData, AuthFlow, AuthType, Jwt, UiaaInfo},
|
||||
},
|
||||
};
|
||||
use tuwunel_core::{Err, Error, Result, err, utils};
|
||||
use tuwunel_core::{Err, Error, Result, err, is_equal_to, utils};
|
||||
use tuwunel_service::{Services, uiaa::SESSION_ID_LENGTH};
|
||||
|
||||
use crate::{Ruma, client::jwt};
|
||||
@@ -67,9 +67,19 @@ where
|
||||
| Some(ref json) => {
|
||||
let sender_user = body
|
||||
.sender_user
|
||||
.as_ref()
|
||||
.as_deref()
|
||||
.ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?;
|
||||
|
||||
// Skip UIAA for SSO/OIDC users.
|
||||
if services
|
||||
.users
|
||||
.origin(sender_user)
|
||||
.await
|
||||
.is_ok_and(is_equal_to!("sso"))
|
||||
{
|
||||
return Ok(sender_user.to_owned());
|
||||
}
|
||||
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services
|
||||
.uiaa
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::str;
|
||||
|
||||
use axum::{RequestExt, RequestPartsExt, extract::Path};
|
||||
use axum_extra::extract::cookie::CookieJar;
|
||||
use bytes::Bytes;
|
||||
use http::request::Parts;
|
||||
use serde::Deserialize;
|
||||
@@ -17,6 +18,7 @@ pub(super) type UserId = SmallString<[u8; 48]>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct Request {
|
||||
pub(super) cookie: CookieJar,
|
||||
pub(super) path: Path<PathParams>,
|
||||
pub(super) query: QueryParams,
|
||||
pub(super) body: Bytes,
|
||||
@@ -33,6 +35,7 @@ pub(super) async fn from(
|
||||
let limited = request.with_limited_body();
|
||||
let (mut parts, body) = limited.into_parts();
|
||||
|
||||
let cookie: CookieJar = parts.extract().await?;
|
||||
let path: Path<PathParams> = parts.extract().await?;
|
||||
let query = parts.uri.query().unwrap_or_default();
|
||||
let query = serde_html_form::from_str(query)
|
||||
@@ -44,5 +47,5 @@ pub(super) async fn from(
|
||||
.await
|
||||
.map_err(|e| err!(Request(TooLarge("Request body too large: {e}"))))?;
|
||||
|
||||
Ok(Request { path, query, body, parts })
|
||||
Ok(Request { cookie, path, query, body, parts })
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ pub mod manager;
|
||||
pub mod proxy;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
collections::{BTreeMap, BTreeSet, HashSet},
|
||||
hash::{Hash, Hasher},
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
@@ -16,7 +17,7 @@ use figment::providers::{Env, Format, Toml};
|
||||
pub use figment::{Figment, value::Value as FigmentValue};
|
||||
use regex::RegexSet;
|
||||
use ruma::{
|
||||
OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId,
|
||||
OwnedMxcUri, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId,
|
||||
api::client::discovery::discover_support::ContactRole,
|
||||
};
|
||||
use serde::{Deserialize, de::IgnoredAny};
|
||||
@@ -28,6 +29,7 @@ pub use self::{check::check, manager::Manager};
|
||||
use crate::{
|
||||
Result, err,
|
||||
error::Error,
|
||||
utils,
|
||||
utils::{string::EMPTY, sys},
|
||||
};
|
||||
|
||||
@@ -57,7 +59,7 @@ use crate::{
|
||||
### https://tuwunel.chat/configuration.html
|
||||
"#,
|
||||
ignore = "catchall well_known tls blurhashing allow_invalid_tls_certificates ldap jwt \
|
||||
appservice"
|
||||
appservice identity_provider"
|
||||
)]
|
||||
pub struct Config {
|
||||
/// The server_name is the pretty name of this server. It is used as a
|
||||
@@ -2153,6 +2155,13 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub appservice: BTreeMap<String, AppService>,
|
||||
|
||||
// external structure; separate sections
|
||||
#[serde(default)]
|
||||
pub identity_provider: HashSet<IdentityProvider>,
|
||||
|
||||
#[serde(default)]
|
||||
pub sso_aware_preferred: bool,
|
||||
|
||||
#[serde(flatten)]
|
||||
#[allow(clippy::zero_sized_map_values)]
|
||||
// this is a catchall, the map shouldn't be zero at runtime
|
||||
@@ -2468,6 +2477,134 @@ pub struct JwtConfig {
|
||||
pub validate_signature: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[config_example_generator(
|
||||
filename = "tuwunel-example.toml",
|
||||
section = "[global.identity_provider]"
|
||||
)]
|
||||
pub struct IdentityProvider {
|
||||
/// The brand-name of the service (e.g. Apple, Facebook, GitHub, GitLab,
|
||||
/// Google) or the software (e.g. keycloak, MAS) providing the identity.
|
||||
/// When a brand is recognized we apply certain defaults to this config
|
||||
/// for your convenience. For certain brands we apply essential internal
|
||||
/// workarounds specific to that provider; it is important to configure this
|
||||
/// field properly when a provider needs to be recognized (like GitHub for
|
||||
/// example). Several configured providers can share the same brand name. It
|
||||
/// is not case-sensitive.
|
||||
#[serde(deserialize_with = "utils::string::de::to_lowercase")]
|
||||
pub brand: String,
|
||||
|
||||
/// The ID of your OAuth application which the provider generates upon
|
||||
/// registration. This ID then uniquely identifies this configuration
|
||||
/// instance itself, becoming the identity provider's ID and must be unique
|
||||
/// and remain unchanged.
|
||||
pub client_id: String,
|
||||
|
||||
/// Secret key the provider generated for you along with the `client_id`
|
||||
/// above. Unlike the `client_id`, the `client_secret` can be changed here
|
||||
/// whenever the provider regenerates one for you.
|
||||
pub client_secret: String,
|
||||
|
||||
/// The callback URL configured when registering the OAuth application with
|
||||
/// the provider. Tuwunel's callback URL must be strictly formatted exactly
|
||||
/// as instructed. The URL host must point directly at the matrix server and
|
||||
/// use the following path:
|
||||
/// `/_matrix/client/unstable/login/sso/callback/<client_id>` where
|
||||
/// `<client_id>` is the same one configured for this provider above.
|
||||
pub callback_url: Option<Url>,
|
||||
|
||||
/// Optional display-name for this provider instance seen on the login page
|
||||
/// by users. It defaults to `brand`. When configuring multiple providers
|
||||
/// using the same `brand` this can be set to distinguish them.
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Optional icon for the provider. The canonical providers have a default
|
||||
/// icon based on the `brand` supplied above when this is not supplied. Note
|
||||
/// that it uses an MXC url which is curious in the auth-media era and may
|
||||
/// not be reliable.
|
||||
pub icon: Option<OwnedMxcUri>,
|
||||
|
||||
/// Optional list of scopes to authorize. An empty array does not impose any
|
||||
/// restrictions from here, effectively defaulting to all scopes you
|
||||
/// configured for the OAuth application at the provider. This setting
|
||||
/// allows for restricting to a subset of those scopes for this instance.
|
||||
/// Note the user can further restrict scopes during their authorization.
|
||||
///
|
||||
/// default: []
|
||||
#[serde(default)]
|
||||
pub scope: BTreeSet<String>,
|
||||
|
||||
/// List of userinfo claims which shape and restrict the way we compute a
|
||||
/// Matrix UserId for new registrations. Reviewing Tuwunel's documentation
|
||||
/// will be necessary for a complete description in detail. An empty array
|
||||
/// imposes no restriction here, avoiding generated fallbacks as much as
|
||||
/// possible. For simplicity we reserve a claim called "unique" which can be
|
||||
/// listed alone to ensure *only* generated ID's are used for registrations.
|
||||
///
|
||||
/// default: []
|
||||
#[serde(default)]
|
||||
pub userid_claims: BTreeSet<String>,
|
||||
|
||||
/// Issuer URL the provider publishes for you. We have pre-supplied default
|
||||
/// values for some of the canonical providers, making this field optional
|
||||
/// based on the `brand` set above. Otherwise it is required for OIDC
|
||||
/// discovery to acquire additional provider configuration, and it must be
|
||||
/// correct to pass validations during various interactions.
|
||||
pub issuer_url: Option<Url>,
|
||||
|
||||
/// Extra path components after the issuer_url leading to the location of
|
||||
/// the `.well-known` directory used for discovery. This will be empty for
|
||||
/// specification-compliant providers. We have supplied any known values
|
||||
/// based on `brand` (e.g. `/login/oauth` for GitHub).
|
||||
pub base_path: Option<String>,
|
||||
|
||||
/// Overrides the `.well-known` location where the provider's OIDC
|
||||
/// configuration is found. It is very unlikely you will need to set this;
|
||||
/// available for developers or special purposes only.
|
||||
pub discovery_url: Option<Url>,
|
||||
|
||||
/// Overrides the authorize URL requested during the grant phase. This is
|
||||
/// generally discovered or derived automatically, but may be required as a
|
||||
/// workaround for any non-standard or undiscoverable provider.
|
||||
pub authorization_url: Option<Url>,
|
||||
|
||||
/// Overrides the access token URL; the same caveats apply as with the other
|
||||
/// URL overrides.
|
||||
pub token_url: Option<Url>,
|
||||
|
||||
/// Overrides the revocation URL; the same caveats apply as with the other
|
||||
/// URL overrides.
|
||||
pub revocation_url: Option<Url>,
|
||||
|
||||
/// Overrides the introspection URL; the same caveats apply as with the
|
||||
/// other URL overrides.
|
||||
pub introspection_url: Option<Url>,
|
||||
|
||||
/// Overrides the userinfo URL; the same caveats apply as with the other URL
|
||||
/// overrides.
|
||||
pub userinfo_url: Option<Url>,
|
||||
|
||||
/// Whether to perform discovery and adjust this provider's configuration
|
||||
/// accordingly. This defaults to true. When true, it is an error when
|
||||
/// discovery fails and authorizations will not be attempted to the
|
||||
/// provider.
|
||||
#[serde(default = "true_fn")]
|
||||
pub discovery: bool,
|
||||
|
||||
/// The duration in seconds before a grant authorization session expires.
|
||||
#[serde(default = "default_sso_grant_session_duration")]
|
||||
pub grant_session_duration: Option<u64>,
|
||||
}
|
||||
|
||||
impl IdentityProvider {
|
||||
#[must_use]
|
||||
pub fn id(&self) -> &str { self.client_id.as_str() }
|
||||
}
|
||||
|
||||
impl Hash for IdentityProvider {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) { self.id().hash(state) }
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Deserialize)]
|
||||
#[config_example_generator(
|
||||
filename = "tuwunel-example.toml",
|
||||
@@ -3016,3 +3153,5 @@ fn default_one_time_key_limit() -> usize { 256 }
|
||||
fn default_max_make_join_attempts_per_join_attempt() -> usize { 48 }
|
||||
|
||||
fn default_max_join_attempts_per_join_request() -> usize { 3 }
|
||||
|
||||
fn default_sso_grant_session_duration() -> Option<u64> { Some(180) }
|
||||
|
||||
@@ -113,6 +113,14 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "mediaid_user",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oauthid_session",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oauthuniqid_oauthid",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "onetimekeyid_onetimekeys",
|
||||
..descriptor::RANDOM_SMALL
|
||||
@@ -403,6 +411,10 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "userid_masterkeyid",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "userid_oauthid",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "userid_origin",
|
||||
..descriptor::RANDOM
|
||||
|
||||
@@ -110,6 +110,7 @@ ruma.workspace = true
|
||||
rustls.workspace = true
|
||||
rustyline-async.workspace = true
|
||||
rustyline-async.optional = true
|
||||
serde_html_form.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde.workspace = true
|
||||
serde_yaml.workspace = true
|
||||
|
||||
@@ -21,6 +21,7 @@ pub struct Service {
|
||||
pub sender: ClientLazylock,
|
||||
pub appservice: ClientLazylock,
|
||||
pub pusher: ClientLazylock,
|
||||
pub oauth: ClientLazylock,
|
||||
|
||||
pub cidr_range_denylist: Vec<IPAddress>,
|
||||
}
|
||||
@@ -112,6 +113,11 @@ impl crate::Service for Service {
|
||||
.pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout))
|
||||
.redirect(redirect::Policy::limited(2))),
|
||||
|
||||
oauth: create_client!(config, services; base(config)?
|
||||
.dns_resolver2(Arc::clone(&services.resolver.resolver))
|
||||
.redirect(redirect::Policy::limited(0))
|
||||
.pool_max_idle_per_host(1)),
|
||||
|
||||
cidr_range_denylist: config
|
||||
.ip_range_denylist
|
||||
.iter()
|
||||
|
||||
@@ -19,6 +19,7 @@ pub mod globals;
|
||||
pub mod key_backups;
|
||||
pub mod media;
|
||||
pub mod membership;
|
||||
pub mod oauth;
|
||||
pub mod presence;
|
||||
pub mod pusher;
|
||||
pub mod resolver;
|
||||
|
||||
265
src/service/oauth/mod.rs
Normal file
265
src/service/oauth/mod.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
pub mod providers;
|
||||
pub mod sessions;
|
||||
pub mod user_info;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
|
||||
use reqwest::{Method, header::ACCEPT};
|
||||
use ruma::UserId;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tuwunel_core::{
|
||||
Err, Result, err, implement,
|
||||
utils::{hash::sha256, result::LogErr},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
pub use self::{
|
||||
providers::Provider,
|
||||
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session},
|
||||
user_info::UserInfo,
|
||||
};
|
||||
use self::{providers::Providers, sessions::Sessions};
|
||||
use crate::SelfServices;
|
||||
|
||||
pub struct Service {
|
||||
services: SelfServices,
|
||||
pub providers: Arc<Providers>,
|
||||
pub sessions: Arc<Sessions>,
|
||||
}
|
||||
|
||||
impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
let providers = Arc::new(Providers::build(args));
|
||||
let sessions = Arc::new(Sessions::build(args, providers.clone()));
|
||||
Ok(Arc::new(Self {
|
||||
services: args.services.clone(),
|
||||
sessions,
|
||||
providers,
|
||||
}))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||
pub async fn request_userinfo(
|
||||
&self,
|
||||
(provider, session): (&Provider, &Session),
|
||||
) -> Result<UserInfo> {
|
||||
let url = provider
|
||||
.userinfo_url
|
||||
.clone()
|
||||
.ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
|
||||
|
||||
self.request((Some(provider), Some(session)), Method::GET, url)
|
||||
.await
|
||||
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||
.log_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||
pub async fn request_tokeninfo(
|
||||
&self,
|
||||
(provider, session): (&Provider, &Session),
|
||||
) -> Result<UserInfo> {
|
||||
let url = provider
|
||||
.introspection_url
|
||||
.clone()
|
||||
.ok_or_else(|| {
|
||||
err!(Config("introspection_url", "Missing introspection URL in config"))
|
||||
})?;
|
||||
|
||||
self.request((Some(provider), Some(session)), Method::GET, url)
|
||||
.await
|
||||
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||
.log_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||
pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RevokeQuery<'a> {
|
||||
client_id: &'a str,
|
||||
client_secret: &'a str,
|
||||
}
|
||||
|
||||
let query = RevokeQuery {
|
||||
client_id: &provider.client_id,
|
||||
client_secret: &provider.client_secret,
|
||||
};
|
||||
|
||||
let query = serde_html_form::to_string(&query)?;
|
||||
let url = provider
|
||||
.revocation_url
|
||||
.clone()
|
||||
.map(|mut url| {
|
||||
url.set_query(Some(&query));
|
||||
url
|
||||
})
|
||||
.ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
|
||||
|
||||
self.request((Some(provider), Some(session)), Method::POST, url)
|
||||
.await
|
||||
.log_err()
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||
pub async fn request_token(
|
||||
&self,
|
||||
(provider, session): (&Provider, &Session),
|
||||
code: &str,
|
||||
) -> Result<Session> {
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenQuery<'a> {
|
||||
client_id: &'a str,
|
||||
client_secret: &'a str,
|
||||
grant_type: &'a str,
|
||||
code: &'a str,
|
||||
code_verifier: Option<&'a str>,
|
||||
redirect_uri: Option<&'a str>,
|
||||
}
|
||||
|
||||
let query = TokenQuery {
|
||||
client_id: &provider.client_id,
|
||||
client_secret: &provider.client_secret,
|
||||
grant_type: "authorization_code",
|
||||
code,
|
||||
code_verifier: session.code_verifier.as_deref(),
|
||||
redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
|
||||
};
|
||||
|
||||
let query = serde_html_form::to_string(&query)?;
|
||||
let url = provider
|
||||
.token_url
|
||||
.clone()
|
||||
.map(|mut url| {
|
||||
url.set_query(Some(&query));
|
||||
url
|
||||
})
|
||||
.ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
|
||||
|
||||
self.request((Some(provider), Some(session)), Method::POST, url)
|
||||
.await
|
||||
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||
.log_err()
|
||||
}
|
||||
|
||||
/// Send a request to a provider; this is somewhat abstract since URL's are
|
||||
/// formed prior to this call and could point at anything, however this function
|
||||
/// uses the oauth-specific http client and is configured for JSON with special
|
||||
/// casing for an `error` property in the response.
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(
|
||||
name = "request",
|
||||
level = "debug",
|
||||
ret(level = "trace"),
|
||||
skip(self)
|
||||
)]
|
||||
pub async fn request(
|
||||
&self,
|
||||
(provider, session): (Option<&Provider>, Option<&Session>),
|
||||
method: Method,
|
||||
url: Url,
|
||||
) -> Result<JsonValue> {
|
||||
let mut request = self
|
||||
.services
|
||||
.client
|
||||
.oauth
|
||||
.request(method, url)
|
||||
.header(ACCEPT, "application/json");
|
||||
|
||||
if let Some(session) = session {
|
||||
if let Some(access_token) = session.access_token.clone() {
|
||||
request = request.bearer_auth(access_token);
|
||||
}
|
||||
}
|
||||
|
||||
let response: JsonValue = request
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if let Some(response) = response.as_object().as_ref()
|
||||
&& let Some(error) = response.get("error").and_then(JsonValue::as_str)
|
||||
{
|
||||
let description = response
|
||||
.get("error_description")
|
||||
.and_then(JsonValue::as_str)
|
||||
.unwrap_or("(no description)");
|
||||
|
||||
return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> {
|
||||
let session = self.sessions.get_by_user(user_id).await?;
|
||||
let provider = self.sessions.provider(&session).await?;
|
||||
|
||||
Ok((provider, session))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
|
||||
unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
|
||||
unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
|
||||
unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
|
||||
}
|
||||
|
||||
pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
|
||||
let hash = sha256::delimited([iss, sub].iter());
|
||||
let b64 = b64encode.encode(hash);
|
||||
|
||||
Ok(b64)
|
||||
}
|
||||
|
||||
fn unique_id_parts<'a>(
|
||||
(provider, session): (&'a Provider, &'a Session),
|
||||
) -> Result<(&'a str, &'a str)> {
|
||||
provider
|
||||
.issuer_url
|
||||
.as_ref()
|
||||
.map(Url::as_str)
|
||||
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
|
||||
.and_then(|iss| unique_id_iss_parts((iss, session)))
|
||||
}
|
||||
|
||||
fn unique_id_sub_parts<'a>(
|
||||
(provider, sub): (&'a Provider, &'a str),
|
||||
) -> Result<(&'a str, &'a str)> {
|
||||
provider
|
||||
.issuer_url
|
||||
.as_ref()
|
||||
.map(Url::as_str)
|
||||
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
|
||||
.map(|iss| (iss, sub))
|
||||
}
|
||||
|
||||
fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
|
||||
session
|
||||
.user_info
|
||||
.as_ref()
|
||||
.map(|user_info| user_info.sub.as_str())
|
||||
.ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
|
||||
.map(|sub| (iss, sub))
|
||||
}
|
||||
246
src/service/oauth/providers.rs
Normal file
246
src/service/oauth/providers.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use serde_json::{Map as JsonObject, Value as JsonValue};
|
||||
use tokio::sync::RwLock;
|
||||
pub use tuwunel_core::config::IdentityProvider as Provider;
|
||||
use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
|
||||
use url::Url;
|
||||
|
||||
use crate::SelfServices;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Providers {
|
||||
services: SelfServices,
|
||||
providers: RwLock<BTreeMap<String, Provider>>,
|
||||
}
|
||||
|
||||
#[implement(Providers)]
|
||||
pub(super) fn build(args: &crate::Args<'_>) -> Self {
|
||||
Self {
|
||||
services: args.services.clone(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the Provider configuration after any discovery and adjustments
|
||||
/// made on top of the admin's configuration. This incurs network-based
|
||||
/// discovery on the first call but responds from cache on subsequent calls.
|
||||
#[implement(Providers)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn get(&self, id: &str) -> Result<Provider> {
|
||||
if let Some(provider) = self.providers.read().await.get(id).cloned() {
|
||||
return Ok(provider);
|
||||
}
|
||||
|
||||
let config = self.get_config(id)?;
|
||||
let mut map = self.providers.write().await;
|
||||
let config = self.configure(config).await?;
|
||||
|
||||
debug!(?id, ?config);
|
||||
_ = map.insert(id.into(), config.clone());
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Get the admin-configured Provider which exists prior to any
|
||||
/// reconciliation with the well-known discovery (the server's config is
|
||||
/// immutable); though it is important to note the server config can be
|
||||
/// reloaded. This will Err NotFound for a non-existent idp.
|
||||
#[implement(Providers)]
|
||||
pub fn get_config(&self, id: &str) -> Result<Provider> {
|
||||
self.services
|
||||
.config
|
||||
.identity_provider
|
||||
.iter()
|
||||
.find(|config| config.id() == id)
|
||||
.cloned()
|
||||
.ok_or_else(|| err!(Request(NotFound("Unrecognized Identity Provider"))))
|
||||
}
|
||||
|
||||
/// Configure an identity provider; takes the admin-configured instance from the
|
||||
/// server's config, queries the provider for discovery, and then returns an
|
||||
/// updated config based on the proper reconciliation. This final config is then
|
||||
/// cached in memory to avoid repeating this process.
|
||||
#[implement(Providers)]
|
||||
#[tracing::instrument(
|
||||
level = INFO_SPAN_LEVEL,
|
||||
ret(level = "debug"),
|
||||
skip(self),
|
||||
)]
|
||||
async fn configure(&self, mut provider: Provider) -> Result<Provider> {
|
||||
_ = provider
|
||||
.name
|
||||
.get_or_insert_with(|| provider.brand.clone());
|
||||
|
||||
if provider.issuer_url.is_none() {
|
||||
_ = provider
|
||||
.issuer_url
|
||||
.replace(match provider.brand.as_str() {
|
||||
| "github" => "https://github.com".try_into()?,
|
||||
| "gitlab" => "https://gitlab.com".try_into()?,
|
||||
| "google" => "https://accounts.google.com".try_into()?,
|
||||
| _ => return Err!(Config("issuer_url", "Required for this provider.")),
|
||||
});
|
||||
}
|
||||
|
||||
if provider.base_path.is_none() {
|
||||
provider.base_path = match provider.brand.as_str() {
|
||||
| "github" => Some("/login/oauth".to_owned()),
|
||||
| _ => None,
|
||||
};
|
||||
}
|
||||
|
||||
let response = self
|
||||
.discover(&provider)
|
||||
.await
|
||||
.and_then(|response| {
|
||||
response.as_object().cloned().ok_or_else(|| {
|
||||
err!(Request(NotJson("Expecting JSON object for discovery response")))
|
||||
})
|
||||
})
|
||||
.and_then(|response| check_issuer(response, &provider))?;
|
||||
|
||||
if provider.authorization_url.is_none() {
|
||||
response
|
||||
.get("authorization_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| make_url(&provider, "/authorize").ok())
|
||||
.map(|url| provider.authorization_url.replace(url));
|
||||
}
|
||||
|
||||
if provider.revocation_url.is_none() {
|
||||
response
|
||||
.get("revocation_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| make_url(&provider, "/revocation").ok())
|
||||
.map(|url| provider.revocation_url.replace(url));
|
||||
}
|
||||
|
||||
if provider.introspection_url.is_none() {
|
||||
response
|
||||
.get("introspection_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| make_url(&provider, "/introspection").ok())
|
||||
.map(|url| provider.introspection_url.replace(url));
|
||||
}
|
||||
|
||||
if provider.userinfo_url.is_none() {
|
||||
response
|
||||
.get("userinfo_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| match provider.brand.as_str() {
|
||||
| "github" => "https://api.github.com/user".try_into().ok(),
|
||||
| _ => make_url(&provider, "/userinfo").ok(),
|
||||
})
|
||||
.map(|url| provider.userinfo_url.replace(url));
|
||||
}
|
||||
|
||||
if provider.token_url.is_none() {
|
||||
response
|
||||
.get("token_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| {
|
||||
let path = if provider.brand == "github" {
|
||||
"/access_token"
|
||||
} else {
|
||||
"/token"
|
||||
};
|
||||
|
||||
make_url(&provider, path).ok()
|
||||
})
|
||||
.map(|url| provider.token_url.replace(url));
|
||||
}
|
||||
|
||||
Ok(provider)
|
||||
}
|
||||
|
||||
#[implement(Providers)]
|
||||
#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
|
||||
pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
|
||||
self.services
|
||||
.client
|
||||
.oauth
|
||||
.get(discovery_url(provider)?)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn discovery_url(provider: &Provider) -> Result<Url> {
|
||||
let default_url = provider
|
||||
.discovery
|
||||
.then(|| make_url(provider, "/.well-known/openid-configuration"))
|
||||
.transpose()?;
|
||||
|
||||
let Some(url) = provider
|
||||
.discovery_url
|
||||
.clone()
|
||||
.filter(|_| provider.discovery)
|
||||
.or(default_url)
|
||||
else {
|
||||
return Err!(Config(
|
||||
"discovery_url",
|
||||
"Failed to determine URL for discovery of provider {}",
|
||||
provider.id()
|
||||
));
|
||||
};
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
/// Validate that the locally configured `issuer_url` matches the issuer claimed
|
||||
/// in any response. todo: cryptographic validation is not yet implemented here.
|
||||
fn check_issuer(
|
||||
response: JsonObject<String, JsonValue>,
|
||||
provider: &Provider,
|
||||
) -> Result<JsonObject<String, JsonValue>> {
|
||||
let expected = provider
|
||||
.issuer_url
|
||||
.as_ref()
|
||||
.map(Url::as_str)
|
||||
.map(|url| url.trim_end_matches('/'));
|
||||
|
||||
let responded = response
|
||||
.get("issuer")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(|url| url.trim_end_matches('/'));
|
||||
|
||||
if expected != responded {
|
||||
return Err!(Request(Unauthorized(
|
||||
"Configured issuer_url {expected:?} does not match discovered {responded:?}",
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Generate a full URL for a request to the idp based on the idp's derived
|
||||
/// configuration.
|
||||
fn make_url(provider: &Provider, path: &str) -> Result<Url> {
|
||||
let mut suffix = provider.base_path.clone().unwrap_or_default();
|
||||
|
||||
suffix.push_str(path);
|
||||
let url = provider
|
||||
.issuer_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
let id = &provider.client_id;
|
||||
err!(Config("issuer_url", "Provider {id:?} required field"))
|
||||
})?
|
||||
.join(&suffix)?;
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
234
src/service/oauth/sessions.rs
Normal file
234
src/service/oauth/sessions.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
use std::{sync::Arc, time::SystemTime};
|
||||
|
||||
use futures::{Stream, StreamExt, TryFutureExt};
|
||||
use ruma::{OwnedUserId, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect};
|
||||
use tuwunel_database::{Cbor, Deserialized, Map};
|
||||
use url::Url;
|
||||
|
||||
use super::{Provider, Providers, UserInfo, unique_id};
|
||||
use crate::SelfServices;
|
||||
|
||||
pub struct Sessions {
|
||||
_services: SelfServices,
|
||||
providers: Arc<Providers>,
|
||||
db: Data,
|
||||
}
|
||||
|
||||
struct Data {
|
||||
oauthid_session: Arc<Map>,
|
||||
oauthuniqid_oauthid: Arc<Map>,
|
||||
userid_oauthid: Arc<Map>,
|
||||
}
|
||||
|
||||
/// Session ultimately represents an OAuth authorization session yielding an
|
||||
/// associated matrix user registration.
|
||||
///
|
||||
/// Mixed-use structure capable of deserializing response values, maintaining
|
||||
/// the state between authorization steps, and maintaining the association to
|
||||
/// the matrix user until deactivation or revocation.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Session {
|
||||
/// Identity Provider ID (the `client_id` in the configuration) associated
|
||||
/// with this session.
|
||||
pub idp_id: Option<String>,
|
||||
|
||||
/// Session ID used as the index key for this session itself.
|
||||
pub sess_id: Option<String>,
|
||||
|
||||
/// Token type (bearer, mac, etc).
|
||||
pub token_type: Option<String>,
|
||||
|
||||
/// Access token to the provider.
|
||||
pub access_token: Option<String>,
|
||||
|
||||
/// Duration in seconds the access_token is valid for.
|
||||
pub expires_in: Option<u64>,
|
||||
|
||||
/// Point in time that the access_token expires.
|
||||
pub expires_at: Option<SystemTime>,
|
||||
|
||||
/// Token used to refresh the access_token.
|
||||
pub refresh_token: Option<String>,
|
||||
|
||||
/// Duration in seconds the refresh_token is valid for
|
||||
pub refresh_token_expires_in: Option<u64>,
|
||||
|
||||
/// Point in time that the refresh_token expires.
|
||||
pub refresh_token_expires_at: Option<SystemTime>,
|
||||
|
||||
/// Access scope actually granted (if supported).
|
||||
pub scope: Option<String>,
|
||||
|
||||
/// Redirect URL
|
||||
pub redirect_url: Option<Url>,
|
||||
|
||||
/// Challenge preimage
|
||||
pub code_verifier: Option<String>,
|
||||
|
||||
/// Random string passed exclusively in the grant session cookie.
|
||||
pub cookie_nonce: Option<String>,
|
||||
|
||||
/// Random single-use string passed in the provider redirect.
|
||||
pub query_nonce: Option<String>,
|
||||
|
||||
/// Point in time the authorization grant session expires.
|
||||
pub authorize_expires_at: Option<SystemTime>,
|
||||
|
||||
/// Associated User Id registration.
|
||||
pub user_id: Option<OwnedUserId>,
|
||||
|
||||
/// Last userinfo response persisted here.
|
||||
pub user_info: Option<UserInfo>,
|
||||
}
|
||||
|
||||
/// Number of characters generated for our code_verifier. The code_verifier is a
|
||||
/// random string which must be between 43 and 128 characters.
|
||||
pub const CODE_VERIFIER_LENGTH: usize = 64;
|
||||
|
||||
/// Number of characters we will generate for the Session ID.
|
||||
pub const SESSION_ID_LENGTH: usize = 32;
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
|
||||
Self {
|
||||
_services: args.services.clone(),
|
||||
providers,
|
||||
db: Data {
|
||||
oauthid_session: args.db["oauthid_session"].clone(),
|
||||
oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
|
||||
userid_oauthid: args.db["userid_oauthid"].clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
|
||||
self.db
|
||||
.userid_oauthid
|
||||
.keys()
|
||||
.expect_ok()
|
||||
.map(UserId::to_owned)
|
||||
}
|
||||
|
||||
/// Delete database state for the session.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn delete(&self, sess_id: &str) {
|
||||
let Ok(session) = self.get(sess_id).await else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Check the user_id still points to this sess_id before deleting. If not, the
|
||||
// association was updated to a newer session.
|
||||
if let Some(user_id) = session.user_id.as_deref() {
|
||||
if let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await {
|
||||
if assoc_id == sess_id {
|
||||
self.db.userid_oauthid.remove(user_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check the unique identity still points to this sess_id before deleting. If
|
||||
// not, the association was updated to a newer session.
|
||||
if let Some(idp_id) = session.idp_id.as_ref() {
|
||||
if let Ok(provider) = self.providers.get(idp_id).await {
|
||||
if let Ok(unique_id) = unique_id((&provider, &session)) {
|
||||
if let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await {
|
||||
if assoc_id == sess_id {
|
||||
self.db.oauthuniqid_oauthid.remove(&unique_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.db.oauthid_session.remove(sess_id);
|
||||
}
|
||||
|
||||
/// Create or overwrite database state for the session.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "info", skip(self))]
|
||||
pub async fn put(&self, sess_id: &str, session: &Session) {
|
||||
self.db
|
||||
.oauthid_session
|
||||
.raw_put(sess_id, Cbor(session));
|
||||
|
||||
if let Some(user_id) = session.user_id.as_deref() {
|
||||
self.db.userid_oauthid.insert(user_id, sess_id);
|
||||
}
|
||||
|
||||
if let Some(idp_id) = session.idp_id.as_ref() {
|
||||
if let Ok(provider) = self.providers.get(idp_id).await {
|
||||
if let Ok(unique_id) = unique_id((&provider, session)) {
|
||||
self.db
|
||||
.oauthuniqid_oauthid
|
||||
.insert(&unique_id, sess_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its associated `sub`, in case
|
||||
/// `sess_id` is not known.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
|
||||
self.get_sess_id_by_unique_id(unique_id)
|
||||
.and_then(async |sess_id| self.get(&sess_id).await)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its associated `user_id`, in case
|
||||
/// `sess_id` is not known.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_by_user(&self, user_id: &UserId) -> Result<Session> {
|
||||
self.get_sess_id_by_user(user_id)
|
||||
.and_then(async |sess_id| self.get(&sess_id).await)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its `sess_id`.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get(&self, sess_id: &str) -> Result<Session> {
|
||||
self.db
|
||||
.oauthid_session
|
||||
.get(sess_id)
|
||||
.await
|
||||
.deserialized::<Cbor<_>>()
|
||||
.map(at!(0))
|
||||
}
|
||||
|
||||
/// Resolve the `sess_id` from an associated `user_id`.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result<String> {
|
||||
self.db
|
||||
.userid_oauthid
|
||||
.get(user_id)
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
/// Resolve the `sess_id` from an associated provider subject id.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
|
||||
self.db
|
||||
.oauthuniqid_oauthid
|
||||
.get(unique_id)
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub async fn provider(&self, session: &Session) -> Result<Provider> {
|
||||
let Some(idp_id) = session.idp_id.as_deref() else {
|
||||
return Err!(Request(NotFound("No provider for this session")));
|
||||
};
|
||||
|
||||
self.providers.get(idp_id).await
|
||||
}
|
||||
39
src/service/oauth/user_info.rs
Normal file
39
src/service/oauth/user_info.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Selection of userinfo response claims.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct UserInfo {
|
||||
/// Unique identifier number or login username. Usually a number on most
|
||||
/// services. We consider a concatenation of the `iss` and `sub` to be a
|
||||
/// universally unique identifier for some user/identity; we index that in
|
||||
/// `oauthidpsub_oauthid`.
|
||||
///
|
||||
/// Considered for user mxid only if none of the better fields are defined.
|
||||
/// `login` alias intended for github.
|
||||
#[serde(alias = "login")]
|
||||
pub sub: String,
|
||||
|
||||
/// The login username we first consider when defined.
|
||||
pub preferred_username: Option<String>,
|
||||
|
||||
/// The login username considered if none preferred.
|
||||
pub nickname: Option<String>,
|
||||
|
||||
/// Full name.
|
||||
pub name: Option<String>,
|
||||
|
||||
/// First name.
|
||||
pub given_name: Option<String>,
|
||||
|
||||
/// Last name.
|
||||
pub family_name: Option<String>,
|
||||
|
||||
/// Email address (`email` scope).
|
||||
pub email: Option<String>,
|
||||
|
||||
/// URL to pfp (github/gitlab)
|
||||
pub avatar_url: Option<String>,
|
||||
|
||||
/// URL to pfp (google)
|
||||
pub picture: Option<String>,
|
||||
}
|
||||
@@ -12,7 +12,7 @@ use crate::{
|
||||
account_data, admin, appservice, client, config, deactivate, emergency, federation, globals,
|
||||
key_backups,
|
||||
manager::Manager,
|
||||
media, membership, presence, pusher, resolver, rooms, sending, server_keys,
|
||||
media, membership, oauth, presence, pusher, resolver, rooms, sending, server_keys,
|
||||
service::{Args, Service},
|
||||
sync, transaction_ids, uiaa, users,
|
||||
};
|
||||
@@ -58,6 +58,7 @@ pub struct Services {
|
||||
pub users: Arc<users::Service>,
|
||||
pub membership: Arc<membership::Service>,
|
||||
pub deactivate: Arc<deactivate::Service>,
|
||||
pub oauth: Arc<oauth::Service>,
|
||||
|
||||
manager: Mutex<Option<Arc<Manager>>>,
|
||||
pub server: Arc<Server>,
|
||||
@@ -115,6 +116,7 @@ pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
|
||||
users: users::Service::build(&args)?,
|
||||
membership: membership::Service::build(&args)?,
|
||||
deactivate: deactivate::Service::build(&args)?,
|
||||
oauth: oauth::Service::build(&args)?,
|
||||
|
||||
manager: Mutex::new(None),
|
||||
server,
|
||||
@@ -173,6 +175,7 @@ pub(crate) fn services(&self) -> impl Iterator<Item = Arc<dyn Service>> + Send {
|
||||
cast!(self.users),
|
||||
cast!(self.membership),
|
||||
cast!(self.deactivate),
|
||||
cast!(self.oauth),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ use ruma::{
|
||||
events::{GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent},
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Err, Result, debug_warn, err, is_equal_to,
|
||||
Err, Result, debug_warn, err,
|
||||
pdu::PduBuilder,
|
||||
trace,
|
||||
utils::{self, ReadyExt, stream::TryIgnore},
|
||||
utils::{self, ReadyExt, result::LogErr, stream::TryIgnore},
|
||||
warn,
|
||||
};
|
||||
use tuwunel_database::{Deserialized, Json, Map};
|
||||
@@ -132,6 +132,16 @@ impl Service {
|
||||
|
||||
/// Deactivate account
|
||||
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
|
||||
// Revoke any SSO authorizations
|
||||
if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await {
|
||||
self.services
|
||||
.oauth
|
||||
.revoke_token((&provider, &session))
|
||||
.await
|
||||
.log_err()
|
||||
.ok();
|
||||
}
|
||||
|
||||
// Remove all associated devices
|
||||
self.all_device_ids(user_id)
|
||||
.for_each(|device_id| self.remove_device(user_id, device_id))
|
||||
@@ -227,18 +237,15 @@ impl Service {
|
||||
// Cannot change the password of a LDAP user. There are two special cases :
|
||||
// - a `None` password can be used to deactivate a LDAP user
|
||||
// - a "*" password is used as the default password of an active LDAP user
|
||||
if cfg!(feature = "ldap")
|
||||
&& password.is_some()
|
||||
//
|
||||
// The above now applies to all non-password origin users including SSO
|
||||
// users. Note that users with no origin are also password-origin users.
|
||||
if password.is_some()
|
||||
&& password != Some("*")
|
||||
&& self
|
||||
.db
|
||||
.userid_origin
|
||||
.get(user_id)
|
||||
.await
|
||||
.deserialized::<String>()
|
||||
.is_ok_and(is_equal_to!("ldap"))
|
||||
&& let Ok(origin) = self.origin(user_id).await
|
||||
&& origin != "password"
|
||||
{
|
||||
return Err!(Request(InvalidParam("Cannot change password of a LDAP user")));
|
||||
return Err!(Request(InvalidParam("Cannot change password of an {origin:?} user.")));
|
||||
}
|
||||
|
||||
password
|
||||
|
||||
@@ -1834,6 +1834,10 @@
|
||||
#
|
||||
#one_time_key_limit = 256
|
||||
|
||||
# This item is undocumented. Please contribute documentation for it.
|
||||
#
|
||||
#sso_aware_preferred = false
|
||||
|
||||
|
||||
|
||||
#[global.tls]
|
||||
@@ -2097,6 +2101,131 @@
|
||||
|
||||
|
||||
|
||||
#[[global.identity_provider]]
|
||||
|
||||
# The brand-name of the service (e.g. Apple, Facebook, GitHub, GitLab,
|
||||
# Google) or the software (e.g. keycloak, MAS) providing the identity.
|
||||
# When a brand is recognized we apply certain defaults to this config
|
||||
# for your convenience. For certain brands we apply essential internal
|
||||
# workarounds specific to that provider; it is important to configure this
|
||||
# field properly when a provider needs to be recognized (like GitHub for
|
||||
# example). Several configured providers can share the same brand name. It
|
||||
# is not case-sensitive.
|
||||
#
|
||||
#brand =
|
||||
|
||||
# The ID of your OAuth application which the provider generates upon
|
||||
# registration. This ID then uniquely identifies this configuration
|
||||
# instance itself, becoming the identity provider's ID and must be unique
|
||||
# and remain unchanged.
|
||||
#
|
||||
#client_id =
|
||||
|
||||
# Secret key the provider generated for you along with the `client_id`
|
||||
# above. Unlike the `client_id`, the `client_secret` can be changed here
|
||||
# whenever the provider regenerates one for you.
|
||||
#
|
||||
#client_secret =
|
||||
|
||||
# The callback URL configured when registering the OAuth application with
|
||||
# the provider. Tuwunel's callback URL must be strictly formatted exactly
|
||||
# as instructed. The URL host must point directly at the matrix server and
|
||||
# use the following path:
|
||||
# `/_matrix/client/unstable/login/sso/callback/<client_id>` where
|
||||
# `<client_id>` is the same one configured for this provider above.
|
||||
#
|
||||
#callback_url =
|
||||
|
||||
# Optional display-name for this provider instance seen on the login page
|
||||
# by users. It defaults to `brand`. When configuring multiple providers
|
||||
# using the same `brand` this can be set to distinguish them.
|
||||
#
|
||||
#name =
|
||||
|
||||
# Optional icon for the provider. The canonical providers have a default
|
||||
# icon based on the `brand` supplied above when this is not supplied. Note
|
||||
# that it uses an MXC url which is curious in the auth-media era and may
|
||||
# not be reliable.
|
||||
#
|
||||
#icon =
|
||||
|
||||
# Optional list of scopes to authorize. An empty array does not impose any
|
||||
# restrictions from here, effectively defaulting to all scopes you
|
||||
# configured for the OAuth application at the provider. This setting
|
||||
# allows for restricting to a subset of those scopes for this instance.
|
||||
# Note the user can further restrict scopes during their authorization.
|
||||
#
|
||||
#scope = []
|
||||
|
||||
# List of userinfo claims which shape and restrict the way we compute a
|
||||
# Matrix UserId for new registrations. Reviewing Tuwunel's documentation
|
||||
# will be necessary for a complete description in detail. An empty array
|
||||
# imposes no restriction here, avoiding generated fallbacks as much as
|
||||
# possible. For simplicity we reserve a claim called "unique" which can be
|
||||
# listed alone to ensure *only* generated ID's are used for registrations.
|
||||
#
|
||||
#userid_claims = []
|
||||
|
||||
# Issuer URL the provider publishes for you. We have pre-supplied default
|
||||
# values for some of the canonical providers, making this field optional
|
||||
# based on the `brand` set above. Otherwise it is required for OIDC
|
||||
# discovery to acquire additional provider configuration, and it must be
|
||||
# correct to pass validations during various interactions.
|
||||
#
|
||||
#issuer_url =
|
||||
|
||||
# Extra path components after the issuer_url leading to the location of
|
||||
# the `.well-known` directory used for discovery. This will be empty for
|
||||
# specification-compliant providers. We have supplied any known values
|
||||
# based on `brand` (e.g. `/login/oauth` for GitHub).
|
||||
#
|
||||
#base_path =
|
||||
|
||||
# Overrides the `.well-known` location where the provider's OIDC
|
||||
# configuration is found. It is very unlikely you will need to set this;
|
||||
# available for developers or special purposes only.
|
||||
#
|
||||
#discovery_url =
|
||||
|
||||
# Overrides the authorize URL requested during the grant phase. This is
|
||||
# generally discovered or derived automatically, but may be required as a
|
||||
# workaround for any non-standard or undiscoverable provider.
|
||||
#
|
||||
#authorization_url =
|
||||
|
||||
# Overrides the access token URL; the same caveats apply as with the other
|
||||
# URL overrides.
|
||||
#
|
||||
#token_url =
|
||||
|
||||
# Overrides the revocation URL; the same caveats apply as with the other
|
||||
# URL overrides.
|
||||
#
|
||||
#revocation_url =
|
||||
|
||||
# Overrides the introspection URL; the same caveats apply as with the
|
||||
# other URL overrides.
|
||||
#
|
||||
#introspection_url =
|
||||
|
||||
# Overrides the userinfo URL; the same caveats apply as with the other URL
|
||||
# overrides.
|
||||
#
|
||||
#userinfo_url =
|
||||
|
||||
# Whether to perform discovery and adjust this provider's configuration
|
||||
# accordingly. This defaults to true. When true, it is an error when
|
||||
# discovery fails and authorizations will not be attempted to the
|
||||
# provider.
|
||||
#
|
||||
#discovery = true
|
||||
|
||||
# The duration in seconds before a grant authorization session expires.
|
||||
#
|
||||
#grant_session_duration =
|
||||
|
||||
|
||||
|
||||
#[global.appservice.<ID>]
|
||||
|
||||
# The URL for the application service.
|
||||
|
||||
Reference in New Issue
Block a user