Implement SSO/OIDC support. (closes #7)
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
27
Cargo.lock
generated
27
Cargo.lock
generated
@@ -383,6 +383,7 @@ dependencies = [
|
|||||||
"axum",
|
"axum",
|
||||||
"axum-core",
|
"axum-core",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
"cookie",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"headers",
|
"headers",
|
||||||
"http",
|
"http",
|
||||||
@@ -836,6 +837,17 @@ dependencies = [
|
|||||||
"unicode-segmentation",
|
"unicode-segmentation",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cookie"
|
||||||
|
version = "0.18.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
|
||||||
|
dependencies = [
|
||||||
|
"percent-encoding",
|
||||||
|
"time",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "coolor"
|
name = "coolor"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
@@ -1291,6 +1303,15 @@ version = "1.0.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
|
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "encoding_rs"
|
||||||
|
version = "0.8.35"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "enum-as-inner"
|
name = "enum-as-inner"
|
||||||
version = "0.6.1"
|
version = "0.6.1"
|
||||||
@@ -3582,6 +3603,7 @@ checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"base64",
|
"base64",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
"encoding_rs",
|
||||||
"futures-channel",
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
@@ -3595,6 +3617,7 @@ dependencies = [
|
|||||||
"hyper-util",
|
"hyper-util",
|
||||||
"js-sys",
|
"js-sys",
|
||||||
"log",
|
"log",
|
||||||
|
"mime",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
@@ -5141,6 +5164,7 @@ dependencies = [
|
|||||||
"tracing",
|
"tracing",
|
||||||
"tuwunel_core",
|
"tuwunel_core",
|
||||||
"tuwunel_service",
|
"tuwunel_service",
|
||||||
|
"url",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -5185,6 +5209,8 @@ dependencies = [
|
|||||||
"ruma",
|
"ruma",
|
||||||
"sanitize-filename",
|
"sanitize-filename",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_core",
|
||||||
|
"serde_html_form",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_regex",
|
"serde_regex",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
@@ -5295,6 +5321,7 @@ dependencies = [
|
|||||||
"rustls",
|
"rustls",
|
||||||
"rustyline-async",
|
"rustyline-async",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_html_form",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
"sha2",
|
"sha2",
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ version = "0.7"
|
|||||||
version = "0.10"
|
version = "0.10"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = [
|
features = [
|
||||||
|
"cookie",
|
||||||
"typed-header",
|
"typed-header",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
@@ -312,10 +313,12 @@ version = "1.12"
|
|||||||
version = "0.12"
|
version = "0.12"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = [
|
features = [
|
||||||
"rustls-tls-native-roots",
|
"charset",
|
||||||
"socks",
|
|
||||||
"hickory-dns",
|
"hickory-dns",
|
||||||
"http2",
|
"http2",
|
||||||
|
"json",
|
||||||
|
"rustls-tls-native-roots",
|
||||||
|
"socks",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.dependencies.ring]
|
[workspace.dependencies.ring]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
mod account_data;
|
mod account_data;
|
||||||
mod appservice;
|
mod appservice;
|
||||||
mod globals;
|
mod globals;
|
||||||
|
mod oauth;
|
||||||
mod presence;
|
mod presence;
|
||||||
mod pusher;
|
mod pusher;
|
||||||
mod raw;
|
mod raw;
|
||||||
@@ -18,10 +19,10 @@ use tuwunel_core::Result;
|
|||||||
|
|
||||||
use self::{
|
use self::{
|
||||||
account_data::AccountDataCommand, appservice::AppserviceCommand, globals::GlobalsCommand,
|
account_data::AccountDataCommand, appservice::AppserviceCommand, globals::GlobalsCommand,
|
||||||
presence::PresenceCommand, pusher::PusherCommand, raw::RawCommand, resolver::ResolverCommand,
|
oauth::OauthCommand, presence::PresenceCommand, pusher::PusherCommand, raw::RawCommand,
|
||||||
room_alias::RoomAliasCommand, room_state_cache::RoomStateCacheCommand,
|
resolver::ResolverCommand, room_alias::RoomAliasCommand,
|
||||||
room_timeline::RoomTimelineCommand, sending::SendingCommand, short::ShortCommand,
|
room_state_cache::RoomStateCacheCommand, room_timeline::RoomTimelineCommand,
|
||||||
sync::SyncCommand, users::UsersCommand,
|
sending::SendingCommand, short::ShortCommand, sync::SyncCommand, users::UsersCommand,
|
||||||
};
|
};
|
||||||
use crate::admin_command_dispatch;
|
use crate::admin_command_dispatch;
|
||||||
|
|
||||||
@@ -81,6 +82,10 @@ pub(super) enum QueryCommand {
|
|||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
Sync(SyncCommand),
|
Sync(SyncCommand),
|
||||||
|
|
||||||
|
/// - oauth service
|
||||||
|
#[command(subcommand)]
|
||||||
|
Oauth(OauthCommand),
|
||||||
|
|
||||||
/// - raw service
|
/// - raw service
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
Raw(RawCommand),
|
Raw(RawCommand),
|
||||||
|
|||||||
172
src/admin/query/oauth.rs
Normal file
172
src/admin/query/oauth.rs
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
use clap::Subcommand;
|
||||||
|
use futures::{StreamExt, TryStreamExt};
|
||||||
|
use ruma::OwnedUserId;
|
||||||
|
use tuwunel_core::{Err, Result, utils::stream::IterStream};
|
||||||
|
use tuwunel_service::{
|
||||||
|
Services,
|
||||||
|
oauth::{Provider, Session},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{admin_command, admin_command_dispatch};
|
||||||
|
|
||||||
|
#[admin_command_dispatch(handler_prefix = "oauth")]
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
/// Query OAuth service state
|
||||||
|
pub(crate) enum OauthCommand {
|
||||||
|
/// List configured OAuth providers.
|
||||||
|
ListProviders,
|
||||||
|
|
||||||
|
/// List users associated with an OAuth provider
|
||||||
|
ListUsers,
|
||||||
|
|
||||||
|
/// Show active configuration of a provider.
|
||||||
|
ShowProvider {
|
||||||
|
id: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config: bool,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Show session state
|
||||||
|
ShowSession {
|
||||||
|
id: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Token introspection request to provider.
|
||||||
|
TokenInfo {
|
||||||
|
id: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Revoke token for user_id or sess_id.
|
||||||
|
Revoke {
|
||||||
|
id: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Remove oauth state (DANGER!)
|
||||||
|
Remove {
|
||||||
|
id: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
force: bool,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[admin_command]
|
||||||
|
pub(super) async fn oauth_list_providers(&self) -> Result {
|
||||||
|
self.services
|
||||||
|
.config
|
||||||
|
.identity_provider
|
||||||
|
.iter()
|
||||||
|
.try_stream()
|
||||||
|
.map_ok(Provider::id)
|
||||||
|
.map_ok(|id| format!("{id}\n"))
|
||||||
|
.try_for_each(async |id| self.write_str(&id).await)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[admin_command]
|
||||||
|
pub(super) async fn oauth_list_users(&self) -> Result {
|
||||||
|
self.services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.users()
|
||||||
|
.map(|id| format!("{id}\n"))
|
||||||
|
.map(Ok)
|
||||||
|
.try_for_each(async |id: String| self.write_str(&id).await)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[admin_command]
|
||||||
|
pub(super) async fn oauth_show_provider(&self, id: String, config: bool) -> Result {
|
||||||
|
if config {
|
||||||
|
let config = self.services.oauth.providers.get_config(&id)?;
|
||||||
|
|
||||||
|
self.write_str(&format!("{config:#?}\n")).await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let provider = self.services.oauth.providers.get(&id).await?;
|
||||||
|
|
||||||
|
self.write_str(&format!("{provider:#?}\n")).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[admin_command]
|
||||||
|
pub(super) async fn oauth_show_session(&self, id: String) -> Result {
|
||||||
|
let session = find_session(self.services, &id).await?;
|
||||||
|
|
||||||
|
self.write_str(&format!("{session:#?}\n")).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[admin_command]
|
||||||
|
pub(super) async fn oauth_token_info(&self, id: String) -> Result {
|
||||||
|
let session = find_session(self.services, &id).await?;
|
||||||
|
|
||||||
|
let provider = self
|
||||||
|
.services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.provider(&session)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let tokeninfo = self
|
||||||
|
.services
|
||||||
|
.oauth
|
||||||
|
.request_tokeninfo((&provider, &session))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
self.write_str(&format!("{tokeninfo:#?}\n")).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[admin_command]
|
||||||
|
pub(super) async fn oauth_revoke(&self, id: String) -> Result {
|
||||||
|
let session = find_session(self.services, &id).await?;
|
||||||
|
|
||||||
|
let provider = self
|
||||||
|
.services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.provider(&session)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.services
|
||||||
|
.oauth
|
||||||
|
.revoke_token((&provider, &session))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.write_str("done").await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[admin_command]
|
||||||
|
pub(super) async fn oauth_remove(&self, id: String, force: bool) -> Result {
|
||||||
|
let session = find_session(self.services, &id).await?;
|
||||||
|
|
||||||
|
let Some(sess_id) = session.sess_id else {
|
||||||
|
return Err!("Missing sess_id in oauth Session state");
|
||||||
|
};
|
||||||
|
|
||||||
|
if !force {
|
||||||
|
return Err!(
|
||||||
|
"Deleting these records can cause registration conflicts. Use --force to be sure."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.delete(&sess_id)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
self.write_str("done").await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn find_session(services: &Services, id: &str) -> Result<Session> {
|
||||||
|
if let Ok(user_id) = OwnedUserId::parse(id) {
|
||||||
|
services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.get_by_user(&user_id)
|
||||||
|
.await
|
||||||
|
} else {
|
||||||
|
services.oauth.sessions.get(id).await
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -101,6 +101,7 @@ tokio.workspace = true
|
|||||||
tracing.workspace = true
|
tracing.workspace = true
|
||||||
tuwunel-core.workspace = true
|
tuwunel-core.workspace = true
|
||||||
tuwunel-service.workspace = true
|
tuwunel-service.workspace = true
|
||||||
|
url.workspace = true
|
||||||
|
|
||||||
[lints]
|
[lints]
|
||||||
workspace = true
|
workspace = true
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ mod ldap;
|
|||||||
mod logout;
|
mod logout;
|
||||||
mod password;
|
mod password;
|
||||||
mod refresh;
|
mod refresh;
|
||||||
|
mod sso;
|
||||||
mod token;
|
mod token;
|
||||||
|
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
@@ -12,8 +13,8 @@ use ruma::api::client::session::{
|
|||||||
get_login_types::{
|
get_login_types::{
|
||||||
self,
|
self,
|
||||||
v3::{
|
v3::{
|
||||||
ApplicationServiceLoginType, JwtLoginType, LoginType, PasswordLoginType,
|
ApplicationServiceLoginType, IdentityProvider, JwtLoginType, LoginType,
|
||||||
TokenLoginType,
|
PasswordLoginType, SsoLoginType, TokenLoginType,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
login::{
|
login::{
|
||||||
@@ -28,6 +29,7 @@ use self::{ldap::ldap_login, password::password_login};
|
|||||||
pub(crate) use self::{
|
pub(crate) use self::{
|
||||||
logout::{logout_all_route, logout_route},
|
logout::{logout_all_route, logout_route},
|
||||||
refresh::refresh_token_route,
|
refresh::refresh_token_route,
|
||||||
|
sso::{sso_callback_route, sso_login_route, sso_login_with_provider_route},
|
||||||
token::login_token_route,
|
token::login_token_route,
|
||||||
};
|
};
|
||||||
use super::TOKEN_LENGTH;
|
use super::TOKEN_LENGTH;
|
||||||
@@ -50,6 +52,20 @@ pub(crate) async fn get_login_types_route(
|
|||||||
LoginType::Token(TokenLoginType {
|
LoginType::Token(TokenLoginType {
|
||||||
get_login_token: services.config.login_via_existing_session,
|
get_login_token: services.config.login_via_existing_session,
|
||||||
}),
|
}),
|
||||||
|
LoginType::Sso(SsoLoginType {
|
||||||
|
identity_providers: services
|
||||||
|
.config
|
||||||
|
.identity_provider
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.map(|config| IdentityProvider {
|
||||||
|
id: config.id().to_owned(),
|
||||||
|
brand: Some(config.brand.clone().into()),
|
||||||
|
icon: config.icon,
|
||||||
|
name: config.name.unwrap_or(config.brand),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
}),
|
||||||
]))
|
]))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +105,10 @@ pub(crate) async fn login_route(
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if !services.users.is_active_local(&user_id).await {
|
||||||
|
return Err!(Request(UserDeactivated("This user has been deactivated.")));
|
||||||
|
}
|
||||||
|
|
||||||
// Generate a new token for the device
|
// Generate a new token for the device
|
||||||
let (access_token, expires_in) = services
|
let (access_token, expires_in) = services
|
||||||
.users
|
.users
|
||||||
|
|||||||
601
src/api/client/session/sso.rs
Normal file
601
src/api/client/session/sso.rs
Normal file
@@ -0,0 +1,601 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum_client_ip::InsecureClientIp;
|
||||||
|
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||||||
|
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
|
||||||
|
use futures::{StreamExt, TryFutureExt, future::try_join};
|
||||||
|
use itertools::Itertools;
|
||||||
|
use reqwest::header::{CONTENT_TYPE, HeaderValue};
|
||||||
|
use ruma::{
|
||||||
|
Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId,
|
||||||
|
api::client::session::{sso_callback, sso_login, sso_login_with_provider},
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tuwunel_core::{
|
||||||
|
Err, Result, at,
|
||||||
|
debug::INFO_SPAN_LEVEL,
|
||||||
|
debug_info, debug_warn, err, info, utils,
|
||||||
|
utils::{
|
||||||
|
content_disposition::make_content_disposition,
|
||||||
|
hash::sha256,
|
||||||
|
result::{FlatOk, LogErr},
|
||||||
|
string::{EMPTY, truncate_deterministic},
|
||||||
|
timepoint_from_now, timepoint_has_passed,
|
||||||
|
},
|
||||||
|
warn,
|
||||||
|
};
|
||||||
|
use tuwunel_service::{
|
||||||
|
Services,
|
||||||
|
media::MXC_LENGTH,
|
||||||
|
oauth::{
|
||||||
|
CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id,
|
||||||
|
unique_id_sub,
|
||||||
|
},
|
||||||
|
users::Register,
|
||||||
|
};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
use super::TOKEN_LENGTH;
|
||||||
|
use crate::Ruma;
|
||||||
|
|
||||||
|
/// Grant phase query string.
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct GrantQuery<'a> {
|
||||||
|
client_id: &'a str,
|
||||||
|
state: &'a str,
|
||||||
|
nonce: &'a str,
|
||||||
|
scope: &'a str,
|
||||||
|
response_type: &'a str,
|
||||||
|
access_type: &'a str,
|
||||||
|
code_challenge_method: &'a str,
|
||||||
|
code_challenge: &'a str,
|
||||||
|
redirect_uri: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
struct GrantCookie<'a> {
|
||||||
|
client_id: &'a str,
|
||||||
|
state: &'a str,
|
||||||
|
nonce: &'a str,
|
||||||
|
redirect_uri: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
static GRANT_SESSION_COOKIE: &str = "tuwunel_grant_session";
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "sso_login",
|
||||||
|
level = "debug",
|
||||||
|
skip_all,
|
||||||
|
fields(%client),
|
||||||
|
)]
|
||||||
|
pub(crate) async fn sso_login_route(
|
||||||
|
State(_services): State<crate::State>,
|
||||||
|
InsecureClientIp(client): InsecureClientIp,
|
||||||
|
_body: Ruma<sso_login::v3::Request>,
|
||||||
|
) -> Result<sso_login::v3::Response> {
|
||||||
|
Err!(Request(NotImplemented(
|
||||||
|
"SSO login without specific provider has not been implemented."
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "sso_login_with_provider",
|
||||||
|
level = "info",
|
||||||
|
skip_all,
|
||||||
|
ret(level = "debug")
|
||||||
|
fields(
|
||||||
|
%client,
|
||||||
|
idp_id = body.body.idp_id,
|
||||||
|
),
|
||||||
|
)]
|
||||||
|
pub(crate) async fn sso_login_with_provider_route(
|
||||||
|
State(services): State<crate::State>,
|
||||||
|
InsecureClientIp(client): InsecureClientIp,
|
||||||
|
body: Ruma<sso_login_with_provider::v3::Request>,
|
||||||
|
) -> Result<sso_login_with_provider::v3::Response> {
|
||||||
|
let sso_login_with_provider::v3::Request { idp_id, redirect_url } = body.body;
|
||||||
|
let Ok(redirect_url) = redirect_url.parse::<Url>() else {
|
||||||
|
return Err!(Request(InvalidParam("Invalid redirect_url")));
|
||||||
|
};
|
||||||
|
|
||||||
|
let provider = services.oauth.providers.get(&idp_id).await?;
|
||||||
|
let sess_id = utils::random_string(SESSION_ID_LENGTH);
|
||||||
|
let query_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||||
|
let cookie_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||||
|
let code_verifier = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||||
|
let code_challenge = b64.encode(sha256::hash(code_verifier.as_bytes()));
|
||||||
|
let callback_uri = provider.callback_url.as_ref().map(Url::as_str);
|
||||||
|
let scope = provider.scope.iter().join(" ");
|
||||||
|
|
||||||
|
let query = GrantQuery {
|
||||||
|
client_id: &provider.client_id,
|
||||||
|
state: &sess_id,
|
||||||
|
nonce: &query_nonce,
|
||||||
|
access_type: "online",
|
||||||
|
response_type: "code",
|
||||||
|
code_challenge_method: "S256",
|
||||||
|
code_challenge: &code_challenge,
|
||||||
|
redirect_uri: callback_uri,
|
||||||
|
scope: scope
|
||||||
|
.is_empty()
|
||||||
|
.then_some("openid email profile")
|
||||||
|
.unwrap_or(scope.as_str()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let location = provider
|
||||||
|
.authorization_url
|
||||||
|
.clone()
|
||||||
|
.map(|mut location| {
|
||||||
|
let query = serde_html_form::to_string(&query).ok();
|
||||||
|
location.set_query(query.as_deref());
|
||||||
|
location
|
||||||
|
})
|
||||||
|
.ok_or_else(|| {
|
||||||
|
err!(Config("authorization_url", "Missing required IdentityProvider config"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let cookie_val = GrantCookie {
|
||||||
|
client_id: query.client_id,
|
||||||
|
state: query.state,
|
||||||
|
nonce: &cookie_nonce,
|
||||||
|
redirect_uri: redirect_url.as_str(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let cookie_path = provider
|
||||||
|
.callback_url
|
||||||
|
.as_ref()
|
||||||
|
.map(Url::path)
|
||||||
|
.unwrap_or("/");
|
||||||
|
|
||||||
|
let cookie_max_age = provider
|
||||||
|
.grant_session_duration
|
||||||
|
.map(Duration::from_secs)
|
||||||
|
.expect("Defaulted to Some value during configure_idp()")
|
||||||
|
.try_into()
|
||||||
|
.expect("std::time::Duration to time::Duration conversion failure");
|
||||||
|
|
||||||
|
let cookie = Cookie::build((GRANT_SESSION_COOKIE, serde_html_form::to_string(&cookie_val)?))
|
||||||
|
.path(cookie_path)
|
||||||
|
.max_age(cookie_max_age)
|
||||||
|
.same_site(SameSite::None)
|
||||||
|
.secure(true)
|
||||||
|
.http_only(true)
|
||||||
|
.build()
|
||||||
|
.to_string()
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let session = Session {
|
||||||
|
idp_id: Some(idp_id),
|
||||||
|
sess_id: Some(sess_id.clone()),
|
||||||
|
redirect_url: Some(redirect_url),
|
||||||
|
code_verifier: Some(code_verifier),
|
||||||
|
query_nonce: Some(query_nonce),
|
||||||
|
cookie_nonce: Some(cookie_nonce),
|
||||||
|
authorize_expires_at: provider
|
||||||
|
.grant_session_duration
|
||||||
|
.map(Duration::from_secs)
|
||||||
|
.map(timepoint_from_now)
|
||||||
|
.transpose()?,
|
||||||
|
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.put(&sess_id, &session)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(sso_login_with_provider::v3::Response {
|
||||||
|
location: location.into(),
|
||||||
|
cookie: Some(cookie),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "sso_callback"
|
||||||
|
level = "debug",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
%client,
|
||||||
|
cookie = ?body.cookie,
|
||||||
|
body = ?body.body,
|
||||||
|
),
|
||||||
|
)]
|
||||||
|
pub(crate) async fn sso_callback_route(
|
||||||
|
State(services): State<crate::State>,
|
||||||
|
InsecureClientIp(client): InsecureClientIp,
|
||||||
|
body: Ruma<sso_callback::unstable::Request>,
|
||||||
|
) -> Result<sso_callback::unstable::Response> {
|
||||||
|
let sess_id = body
|
||||||
|
.body
|
||||||
|
.state
|
||||||
|
.as_deref()
|
||||||
|
.ok_or_else(|| err!(Request(Forbidden("Missing sess_id in callback."))))?;
|
||||||
|
|
||||||
|
let code = body
|
||||||
|
.body
|
||||||
|
.code
|
||||||
|
.as_deref()
|
||||||
|
.ok_or_else(|| err!(Request(Forbidden("Missing code in callback."))))?;
|
||||||
|
|
||||||
|
let session = services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.get(sess_id)
|
||||||
|
.map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));
|
||||||
|
|
||||||
|
let idp_id = body.body.idp_id.as_str();
|
||||||
|
let provider = services.oauth.providers.get(idp_id);
|
||||||
|
let (provider, session) = try_join(provider, session).await.log_err()?;
|
||||||
|
let client_id = &provider.client_id;
|
||||||
|
|
||||||
|
if session.idp_id.as_deref() != Some(idp_id) {
|
||||||
|
return Err!(Request(Unauthorized("Identity Provider {idp_id} session not recognized.")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if session
|
||||||
|
.authorize_expires_at
|
||||||
|
.map(timepoint_has_passed)
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
return Err!(Request(Unauthorized("Authorization grant session has expired.")));
|
||||||
|
}
|
||||||
|
|
||||||
|
let cookie = body
|
||||||
|
.cookie
|
||||||
|
.get(GRANT_SESSION_COOKIE)
|
||||||
|
.map(Cookie::value)
|
||||||
|
.map(serde_html_form::from_str::<GrantCookie<'_>>)
|
||||||
|
.transpose()?
|
||||||
|
.ok_or_else(|| err!(Request(Unauthorized("Missing cookie {GRANT_SESSION_COOKIE:?}"))))?;
|
||||||
|
|
||||||
|
if cookie.client_id != client_id {
|
||||||
|
return Err!(Request(Unauthorized("Client ID {client_id:?} cookie mismatch.")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if Some(cookie.nonce) != session.cookie_nonce.as_deref() {
|
||||||
|
return Err!(Request(Unauthorized("Cookie nonce does not match session state.")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if cookie.state != sess_id {
|
||||||
|
return Err!(Request(Unauthorized("Session ID {sess_id:?} cookie mismatch.")));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request access token.
|
||||||
|
let token_response = services
|
||||||
|
.oauth
|
||||||
|
.request_token((&provider, &session), code)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let token_expires_at = token_response
|
||||||
|
.expires_in
|
||||||
|
.map(Duration::from_secs)
|
||||||
|
.map(timepoint_from_now)
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let refresh_token_expires_at = token_response
|
||||||
|
.refresh_token_expires_in
|
||||||
|
.map(Duration::from_secs)
|
||||||
|
.map(timepoint_from_now)
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
// Update the session with access token results
|
||||||
|
let session = Session {
|
||||||
|
scope: token_response.scope,
|
||||||
|
token_type: token_response.token_type,
|
||||||
|
access_token: token_response.access_token,
|
||||||
|
expires_at: token_expires_at,
|
||||||
|
refresh_token: token_response.refresh_token,
|
||||||
|
refresh_token_expires_at,
|
||||||
|
..session
|
||||||
|
};
|
||||||
|
|
||||||
|
// Request userinfo claims.
|
||||||
|
let userinfo = services
|
||||||
|
.oauth
|
||||||
|
.request_userinfo((&provider, &session))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Check for an existing session from this identity. We want to maintain one
|
||||||
|
// session for each identity and keep the newer one which has up-to-date state
|
||||||
|
// and access.
|
||||||
|
let (user_id, old_sess_id) = match services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.get_by_unique_id(&unique_id_sub((&provider, &userinfo.sub))?)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
| Ok(session) => (session.user_id, session.sess_id),
|
||||||
|
| Err(error) if !error.is_not_found() => return Err(error),
|
||||||
|
| Err(_) => (None, None),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update the session with userinfo
|
||||||
|
let session = Session {
|
||||||
|
user_info: Some(userinfo.clone()),
|
||||||
|
..session
|
||||||
|
};
|
||||||
|
|
||||||
|
// Keep the user_id from the old session as best as possible.
|
||||||
|
let user_id = match user_id {
|
||||||
|
| Some(user_id) => user_id,
|
||||||
|
| None => decide_user_id(&services, &provider, &session, &userinfo).await?,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update the session with user_id
|
||||||
|
let session = Session {
|
||||||
|
user_id: Some(user_id.clone()),
|
||||||
|
..session
|
||||||
|
};
|
||||||
|
|
||||||
|
// Attempt to register a non-existing user.
|
||||||
|
if !services.users.exists(&user_id).await {
|
||||||
|
register_user(&services, &provider, &session, &userinfo, &user_id).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit the updated session.
|
||||||
|
services
|
||||||
|
.oauth
|
||||||
|
.sessions
|
||||||
|
.put(sess_id, &session)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Delete any old session.
|
||||||
|
if let Some(old_sess_id) = old_sess_id
|
||||||
|
&& sess_id != old_sess_id
|
||||||
|
{
|
||||||
|
services.oauth.sessions.delete(&old_sess_id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !services.users.is_active_local(&user_id).await {
|
||||||
|
return Err!(Request(UserDeactivated("This user has been deactivated.")));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow the user to login to Matrix.
|
||||||
|
let login_token = utils::random_string(TOKEN_LENGTH);
|
||||||
|
let _login_token_expires_in = services
|
||||||
|
.users
|
||||||
|
.create_login_token(&user_id, &login_token);
|
||||||
|
|
||||||
|
let location = session
|
||||||
|
.redirect_url
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
|
||||||
|
.clone()
|
||||||
|
.query_pairs_mut()
|
||||||
|
.append_pair("loginToken", &login_token)
|
||||||
|
.finish()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
|
||||||
|
.removal()
|
||||||
|
.build()
|
||||||
|
.to_string()
|
||||||
|
.into();
|
||||||
|
|
||||||
|
Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "register",
|
||||||
|
level = INFO_SPAN_LEVEL,
|
||||||
|
skip_all,
|
||||||
|
fields(user_id, userinfo)
|
||||||
|
)]
|
||||||
|
async fn register_user(
|
||||||
|
services: &Services,
|
||||||
|
provider: &Provider,
|
||||||
|
session: &Session,
|
||||||
|
userinfo: &UserInfo,
|
||||||
|
user_id: &UserId,
|
||||||
|
) -> Result {
|
||||||
|
debug_info!(%user_id, "Creating new user account...");
|
||||||
|
|
||||||
|
services
|
||||||
|
.users
|
||||||
|
.full_register(Register {
|
||||||
|
user_id: Some(user_id),
|
||||||
|
password: Some("*"),
|
||||||
|
origin: Some("sso"),
|
||||||
|
displayname: userinfo.name.as_deref(),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some(avatar_url) = userinfo
|
||||||
|
.avatar_url
|
||||||
|
.as_deref()
|
||||||
|
.or(userinfo.picture.as_deref())
|
||||||
|
{
|
||||||
|
set_avatar(services, provider, session, userinfo, user_id, avatar_url)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
let idp_id = provider.id();
|
||||||
|
let idp_name = provider
|
||||||
|
.name
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(provider.brand.as_str());
|
||||||
|
|
||||||
|
// log in conduit admin channel if a non-guest user registered
|
||||||
|
let notice =
|
||||||
|
format!("New user \"{user_id}\" registered on this server via {idp_name} ({idp_id})",);
|
||||||
|
|
||||||
|
info!("{notice}");
|
||||||
|
if services.server.config.admin_room_notices {
|
||||||
|
services.admin.notice(¬ice).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(level = "debug", skip_all, fields(user_id, avatar_url))]
|
||||||
|
async fn set_avatar(
|
||||||
|
services: &Services,
|
||||||
|
_provider: &Provider,
|
||||||
|
_session: &Session,
|
||||||
|
_userinfo: &UserInfo,
|
||||||
|
user_id: &UserId,
|
||||||
|
avatar_url: &str,
|
||||||
|
) -> Result {
|
||||||
|
use reqwest::Response;
|
||||||
|
|
||||||
|
let response = services
|
||||||
|
.client
|
||||||
|
.default
|
||||||
|
.get(avatar_url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.and_then(Response::error_for_status)?;
|
||||||
|
|
||||||
|
let content_type = response
|
||||||
|
.headers()
|
||||||
|
.get(CONTENT_TYPE)
|
||||||
|
.map(HeaderValue::to_str)
|
||||||
|
.flat_ok()
|
||||||
|
.map(ToOwned::to_owned);
|
||||||
|
|
||||||
|
let mxc = Mxc {
|
||||||
|
server_name: services.globals.server_name(),
|
||||||
|
media_id: &utils::random_string(MXC_LENGTH),
|
||||||
|
};
|
||||||
|
|
||||||
|
let content_disposition = make_content_disposition(None, content_type.as_deref(), None);
|
||||||
|
let bytes = response.bytes().await?;
|
||||||
|
services
|
||||||
|
.media
|
||||||
|
.create(&mxc, Some(user_id), Some(&content_disposition), content_type.as_deref(), &bytes)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let all_joined_rooms: Vec<OwnedRoomId> = services
|
||||||
|
.state_cache
|
||||||
|
.rooms_joined(user_id)
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.collect()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mxc_uri = mxc.to_string().into();
|
||||||
|
services
|
||||||
|
.users
|
||||||
|
.update_avatar_url(user_id, Some(mxc_uri), None, &all_joined_rooms)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
level = "debug",
|
||||||
|
ret(level = "debug")
|
||||||
|
skip_all,
|
||||||
|
fields(user),
|
||||||
|
)]
|
||||||
|
async fn decide_user_id(
|
||||||
|
services: &Services,
|
||||||
|
provider: &Provider,
|
||||||
|
session: &Session,
|
||||||
|
userinfo: &UserInfo,
|
||||||
|
) -> Result<OwnedUserId> {
|
||||||
|
let allowed =
|
||||||
|
|claim: &str| provider.userid_claims.is_empty() || provider.userid_claims.contains(claim);
|
||||||
|
|
||||||
|
let choices = [
|
||||||
|
userinfo
|
||||||
|
.preferred_username
|
||||||
|
.as_deref()
|
||||||
|
.map(str::to_lowercase)
|
||||||
|
.filter(|_| allowed("preferred_username")),
|
||||||
|
userinfo
|
||||||
|
.nickname
|
||||||
|
.as_deref()
|
||||||
|
.map(str::to_lowercase)
|
||||||
|
.filter(|_| allowed("nickname")),
|
||||||
|
provider
|
||||||
|
.brand
|
||||||
|
.eq(&"github")
|
||||||
|
.then_some(userinfo.sub.as_str())
|
||||||
|
.map(str::to_lowercase)
|
||||||
|
.filter(|_| allowed("login")),
|
||||||
|
userinfo
|
||||||
|
.email
|
||||||
|
.as_deref()
|
||||||
|
.and_then(|email| email.split_once('@'))
|
||||||
|
.map(at!(0))
|
||||||
|
.map(str::to_lowercase)
|
||||||
|
.filter(|_| allowed("email")),
|
||||||
|
];
|
||||||
|
|
||||||
|
for choice in choices.into_iter().flatten() {
|
||||||
|
if let Some(user_id) = try_user_id(services, &choice, false).await {
|
||||||
|
return Ok(user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(infallible) = unique_id((provider, session))
|
||||||
|
.map(|h| truncate_deterministic(&h, Some(15..23)).to_lowercase())
|
||||||
|
.log_err()
|
||||||
|
{
|
||||||
|
if let Some(user_id) = try_user_id(services, &infallible, true).await {
|
||||||
|
return Ok(user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err!(Request(UserInUse("User ID is not available.")))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(level = "debug", skip_all, fields(username))]
|
||||||
|
async fn try_user_id(
|
||||||
|
services: &Services,
|
||||||
|
username: &str,
|
||||||
|
may_exist: bool,
|
||||||
|
) -> Option<OwnedUserId> {
|
||||||
|
let server_name = services.globals.server_name();
|
||||||
|
let user_id = parse_user_id(server_name, username)
|
||||||
|
.inspect_err(|e| warn!(?username, "Username invalid: {e}"))
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
if services
|
||||||
|
.config
|
||||||
|
.forbidden_usernames
|
||||||
|
.is_match(username)
|
||||||
|
{
|
||||||
|
warn!(?username, "Username forbidden.");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if services.users.exists(&user_id).await {
|
||||||
|
debug_warn!(?username, "Username exists.");
|
||||||
|
|
||||||
|
if services
|
||||||
|
.users
|
||||||
|
.origin(&user_id)
|
||||||
|
.await
|
||||||
|
.ok()
|
||||||
|
.is_none_or(|origin| origin != "sso")
|
||||||
|
{
|
||||||
|
debug_warn!(?username, "Username has non-sso origin.");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !may_exist {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(user_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_user_id(server_name: &ServerName, username: &str) -> Result<OwnedUserId> {
|
||||||
|
match UserId::parse_with_server_name(username, server_name) {
|
||||||
|
| Err(e) =>
|
||||||
|
Err!(Request(InvalidUsername(debug_error!("Username {username} is not valid: {e}")))),
|
||||||
|
| Ok(user_id) => match user_id.validate_strict() {
|
||||||
|
| Ok(()) => Ok(user_id),
|
||||||
|
| Err(e) => Err!(Request(InvalidUsername(debug_error!(
|
||||||
|
"Username {username} contains disallowed characters or spaces: {e}"
|
||||||
|
)))),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -38,6 +38,9 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
|
|||||||
.ruma_route(&client::login_route)
|
.ruma_route(&client::login_route)
|
||||||
.ruma_route(&client::login_token_route)
|
.ruma_route(&client::login_token_route)
|
||||||
.ruma_route(&client::refresh_token_route)
|
.ruma_route(&client::refresh_token_route)
|
||||||
|
.ruma_route(&client::sso_login_route)
|
||||||
|
.ruma_route(&client::sso_login_with_provider_route)
|
||||||
|
.ruma_route(&client::sso_callback_route)
|
||||||
.ruma_route(&client::whoami_route)
|
.ruma_route(&client::whoami_route)
|
||||||
.ruma_route(&client::logout_route)
|
.ruma_route(&client::logout_route)
|
||||||
.ruma_route(&client::logout_all_route)
|
.ruma_route(&client::logout_all_route)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::{fmt::Debug, mem, ops::Deref};
|
use std::{fmt::Debug, mem, ops::Deref};
|
||||||
|
|
||||||
use axum::{body::Body, extract::FromRequest};
|
use axum::{body::Body, extract::FromRequest};
|
||||||
|
use axum_extra::extract::cookie::CookieJar;
|
||||||
use bytes::{BufMut, Bytes, BytesMut};
|
use bytes::{BufMut, Bytes, BytesMut};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
|
CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
|
||||||
@@ -18,6 +19,9 @@ pub(crate) struct Args<T> {
|
|||||||
/// Request struct body
|
/// Request struct body
|
||||||
pub(crate) body: T,
|
pub(crate) body: T,
|
||||||
|
|
||||||
|
/// Cookies received from the useragent.
|
||||||
|
pub(crate) cookie: CookieJar,
|
||||||
|
|
||||||
/// Federation server authentication: X-Matrix origin
|
/// Federation server authentication: X-Matrix origin
|
||||||
/// None when not a federation server.
|
/// None when not a federation server.
|
||||||
pub(crate) origin: Option<OwnedServerName>,
|
pub(crate) origin: Option<OwnedServerName>,
|
||||||
@@ -110,6 +114,7 @@ where
|
|||||||
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
|
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
|
body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
|
||||||
|
cookie: request.cookie,
|
||||||
origin: auth.origin,
|
origin: auth.origin,
|
||||||
sender_user: auth.sender_user,
|
sender_user: auth.sender_user,
|
||||||
sender_device: auth.sender_device,
|
sender_device: auth.sender_device,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use ruma::{
|
|||||||
client::uiaa::{AuthData, AuthFlow, AuthType, Jwt, UiaaInfo},
|
client::uiaa::{AuthData, AuthFlow, AuthType, Jwt, UiaaInfo},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use tuwunel_core::{Err, Error, Result, err, utils};
|
use tuwunel_core::{Err, Error, Result, err, is_equal_to, utils};
|
||||||
use tuwunel_service::{Services, uiaa::SESSION_ID_LENGTH};
|
use tuwunel_service::{Services, uiaa::SESSION_ID_LENGTH};
|
||||||
|
|
||||||
use crate::{Ruma, client::jwt};
|
use crate::{Ruma, client::jwt};
|
||||||
@@ -67,9 +67,19 @@ where
|
|||||||
| Some(ref json) => {
|
| Some(ref json) => {
|
||||||
let sender_user = body
|
let sender_user = body
|
||||||
.sender_user
|
.sender_user
|
||||||
.as_ref()
|
.as_deref()
|
||||||
.ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?;
|
.ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?;
|
||||||
|
|
||||||
|
// Skip UIAA for SSO/OIDC users.
|
||||||
|
if services
|
||||||
|
.users
|
||||||
|
.origin(sender_user)
|
||||||
|
.await
|
||||||
|
.is_ok_and(is_equal_to!("sso"))
|
||||||
|
{
|
||||||
|
return Ok(sender_user.to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||||
services
|
services
|
||||||
.uiaa
|
.uiaa
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::str;
|
use std::str;
|
||||||
|
|
||||||
use axum::{RequestExt, RequestPartsExt, extract::Path};
|
use axum::{RequestExt, RequestPartsExt, extract::Path};
|
||||||
|
use axum_extra::extract::cookie::CookieJar;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::request::Parts;
|
use http::request::Parts;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
@@ -17,6 +18,7 @@ pub(super) type UserId = SmallString<[u8; 48]>;
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(super) struct Request {
|
pub(super) struct Request {
|
||||||
|
pub(super) cookie: CookieJar,
|
||||||
pub(super) path: Path<PathParams>,
|
pub(super) path: Path<PathParams>,
|
||||||
pub(super) query: QueryParams,
|
pub(super) query: QueryParams,
|
||||||
pub(super) body: Bytes,
|
pub(super) body: Bytes,
|
||||||
@@ -33,6 +35,7 @@ pub(super) async fn from(
|
|||||||
let limited = request.with_limited_body();
|
let limited = request.with_limited_body();
|
||||||
let (mut parts, body) = limited.into_parts();
|
let (mut parts, body) = limited.into_parts();
|
||||||
|
|
||||||
|
let cookie: CookieJar = parts.extract().await?;
|
||||||
let path: Path<PathParams> = parts.extract().await?;
|
let path: Path<PathParams> = parts.extract().await?;
|
||||||
let query = parts.uri.query().unwrap_or_default();
|
let query = parts.uri.query().unwrap_or_default();
|
||||||
let query = serde_html_form::from_str(query)
|
let query = serde_html_form::from_str(query)
|
||||||
@@ -44,5 +47,5 @@ pub(super) async fn from(
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| err!(Request(TooLarge("Request body too large: {e}"))))?;
|
.map_err(|e| err!(Request(TooLarge("Request body too large: {e}"))))?;
|
||||||
|
|
||||||
Ok(Request { path, query, body, parts })
|
Ok(Request { cookie, path, query, body, parts })
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ pub mod manager;
|
|||||||
pub mod proxy;
|
pub mod proxy;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeMap, BTreeSet},
|
collections::{BTreeMap, BTreeSet, HashSet},
|
||||||
|
hash::{Hash, Hasher},
|
||||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
};
|
};
|
||||||
@@ -16,7 +17,7 @@ use figment::providers::{Env, Format, Toml};
|
|||||||
pub use figment::{Figment, value::Value as FigmentValue};
|
pub use figment::{Figment, value::Value as FigmentValue};
|
||||||
use regex::RegexSet;
|
use regex::RegexSet;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId,
|
OwnedMxcUri, OwnedRoomOrAliasId, OwnedServerName, OwnedUserId, RoomVersionId,
|
||||||
api::client::discovery::discover_support::ContactRole,
|
api::client::discovery::discover_support::ContactRole,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, de::IgnoredAny};
|
use serde::{Deserialize, de::IgnoredAny};
|
||||||
@@ -28,6 +29,7 @@ pub use self::{check::check, manager::Manager};
|
|||||||
use crate::{
|
use crate::{
|
||||||
Result, err,
|
Result, err,
|
||||||
error::Error,
|
error::Error,
|
||||||
|
utils,
|
||||||
utils::{string::EMPTY, sys},
|
utils::{string::EMPTY, sys},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -57,7 +59,7 @@ use crate::{
|
|||||||
### https://tuwunel.chat/configuration.html
|
### https://tuwunel.chat/configuration.html
|
||||||
"#,
|
"#,
|
||||||
ignore = "catchall well_known tls blurhashing allow_invalid_tls_certificates ldap jwt \
|
ignore = "catchall well_known tls blurhashing allow_invalid_tls_certificates ldap jwt \
|
||||||
appservice"
|
appservice identity_provider"
|
||||||
)]
|
)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
/// The server_name is the pretty name of this server. It is used as a
|
/// The server_name is the pretty name of this server. It is used as a
|
||||||
@@ -2153,6 +2155,13 @@ pub struct Config {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub appservice: BTreeMap<String, AppService>,
|
pub appservice: BTreeMap<String, AppService>,
|
||||||
|
|
||||||
|
// external structure; separate sections
|
||||||
|
#[serde(default)]
|
||||||
|
pub identity_provider: HashSet<IdentityProvider>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub sso_aware_preferred: bool,
|
||||||
|
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
#[allow(clippy::zero_sized_map_values)]
|
#[allow(clippy::zero_sized_map_values)]
|
||||||
// this is a catchall, the map shouldn't be zero at runtime
|
// this is a catchall, the map shouldn't be zero at runtime
|
||||||
@@ -2468,6 +2477,134 @@ pub struct JwtConfig {
|
|||||||
pub validate_signature: bool,
|
pub validate_signature: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||||
|
#[config_example_generator(
|
||||||
|
filename = "tuwunel-example.toml",
|
||||||
|
section = "[global.identity_provider]"
|
||||||
|
)]
|
||||||
|
pub struct IdentityProvider {
|
||||||
|
/// The brand-name of the service (e.g. Apple, Facebook, GitHub, GitLab,
|
||||||
|
/// Google) or the software (e.g. keycloak, MAS) providing the identity.
|
||||||
|
/// When a brand is recognized we apply certain defaults to this config
|
||||||
|
/// for your convenience. For certain brands we apply essential internal
|
||||||
|
/// workarounds specific to that provider; it is important to configure this
|
||||||
|
/// field properly when a provider needs to be recognized (like GitHub for
|
||||||
|
/// example). Several configured providers can share the same brand name. It
|
||||||
|
/// is not case-sensitive.
|
||||||
|
#[serde(deserialize_with = "utils::string::de::to_lowercase")]
|
||||||
|
pub brand: String,
|
||||||
|
|
||||||
|
/// The ID of your OAuth application which the provider generates upon
|
||||||
|
/// registration. This ID then uniquely identifies this configuration
|
||||||
|
/// instance itself, becoming the identity provider's ID and must be unique
|
||||||
|
/// and remain unchanged.
|
||||||
|
pub client_id: String,
|
||||||
|
|
||||||
|
/// Secret key the provider generated for you along with the `client_id`
|
||||||
|
/// above. Unlike the `client_id`, the `client_secret` can be changed here
|
||||||
|
/// whenever the provider regenerates one for you.
|
||||||
|
pub client_secret: String,
|
||||||
|
|
||||||
|
/// The callback URL configured when registering the OAuth application with
|
||||||
|
/// the provider. Tuwunel's callback URL must be strictly formatted exactly
|
||||||
|
/// as instructed. The URL host must point directly at the matrix server and
|
||||||
|
/// use the following path:
|
||||||
|
/// `/_matrix/client/unstable/login/sso/callback/<client_id>` where
|
||||||
|
/// `<client_id>` is the same one configured for this provider above.
|
||||||
|
pub callback_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Optional display-name for this provider instance seen on the login page
|
||||||
|
/// by users. It defaults to `brand`. When configuring multiple providers
|
||||||
|
/// using the same `brand` this can be set to distinguish them.
|
||||||
|
pub name: Option<String>,
|
||||||
|
|
||||||
|
/// Optional icon for the provider. The canonical providers have a default
|
||||||
|
/// icon based on the `brand` supplied above when this is not supplied. Note
|
||||||
|
/// that it uses an MXC url which is curious in the auth-media era and may
|
||||||
|
/// not be reliable.
|
||||||
|
pub icon: Option<OwnedMxcUri>,
|
||||||
|
|
||||||
|
/// Optional list of scopes to authorize. An empty array does not impose any
|
||||||
|
/// restrictions from here, effectively defaulting to all scopes you
|
||||||
|
/// configured for the OAuth application at the provider. This setting
|
||||||
|
/// allows for restricting to a subset of those scopes for this instance.
|
||||||
|
/// Note the user can further restrict scopes during their authorization.
|
||||||
|
///
|
||||||
|
/// default: []
|
||||||
|
#[serde(default)]
|
||||||
|
pub scope: BTreeSet<String>,
|
||||||
|
|
||||||
|
/// List of userinfo claims which shape and restrict the way we compute a
|
||||||
|
/// Matrix UserId for new registrations. Reviewing Tuwunel's documentation
|
||||||
|
/// will be necessary for a complete description in detail. An empty array
|
||||||
|
/// imposes no restriction here, avoiding generated fallbacks as much as
|
||||||
|
/// possible. For simplicity we reserve a claim called "unique" which can be
|
||||||
|
/// listed alone to ensure *only* generated ID's are used for registrations.
|
||||||
|
///
|
||||||
|
/// default: []
|
||||||
|
#[serde(default)]
|
||||||
|
pub userid_claims: BTreeSet<String>,
|
||||||
|
|
||||||
|
/// Issuer URL the provider publishes for you. We have pre-supplied default
|
||||||
|
/// values for some of the canonical providers, making this field optional
|
||||||
|
/// based on the `brand` set above. Otherwise it is required for OIDC
|
||||||
|
/// discovery to acquire additional provider configuration, and it must be
|
||||||
|
/// correct to pass validations during various interactions.
|
||||||
|
pub issuer_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Extra path components after the issuer_url leading to the location of
|
||||||
|
/// the `.well-known` directory used for discovery. This will be empty for
|
||||||
|
/// specification-compliant providers. We have supplied any known values
|
||||||
|
/// based on `brand` (e.g. `/login/oauth` for GitHub).
|
||||||
|
pub base_path: Option<String>,
|
||||||
|
|
||||||
|
/// Overrides the `.well-known` location where the provider's OIDC
|
||||||
|
/// configuration is found. It is very unlikely you will need to set this;
|
||||||
|
/// available for developers or special purposes only.
|
||||||
|
pub discovery_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Overrides the authorize URL requested during the grant phase. This is
|
||||||
|
/// generally discovered or derived automatically, but may be required as a
|
||||||
|
/// workaround for any non-standard or undiscoverable provider.
|
||||||
|
pub authorization_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Overrides the access token URL; the same caveats apply as with the other
|
||||||
|
/// URL overrides.
|
||||||
|
pub token_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Overrides the revocation URL; the same caveats apply as with the other
|
||||||
|
/// URL overrides.
|
||||||
|
pub revocation_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Overrides the introspection URL; the same caveats apply as with the
|
||||||
|
/// other URL overrides.
|
||||||
|
pub introspection_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Overrides the userinfo URL; the same caveats apply as with the other URL
|
||||||
|
/// overrides.
|
||||||
|
pub userinfo_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Whether to perform discovery and adjust this provider's configuration
|
||||||
|
/// accordingly. This defaults to true. When true, it is an error when
|
||||||
|
/// discovery fails and authorizations will not be attempted to the
|
||||||
|
/// provider.
|
||||||
|
#[serde(default = "true_fn")]
|
||||||
|
pub discovery: bool,
|
||||||
|
|
||||||
|
/// The duration in seconds before a grant authorization session expires.
|
||||||
|
#[serde(default = "default_sso_grant_session_duration")]
|
||||||
|
pub grant_session_duration: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IdentityProvider {
|
||||||
|
#[must_use]
|
||||||
|
pub fn id(&self) -> &str { self.client_id.as_str() }
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Hash for IdentityProvider {
|
||||||
|
fn hash<H: Hasher>(&self, state: &mut H) { self.id().hash(state) }
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, Deserialize)]
|
#[derive(Clone, Debug, Default, Deserialize)]
|
||||||
#[config_example_generator(
|
#[config_example_generator(
|
||||||
filename = "tuwunel-example.toml",
|
filename = "tuwunel-example.toml",
|
||||||
@@ -3016,3 +3153,5 @@ fn default_one_time_key_limit() -> usize { 256 }
|
|||||||
fn default_max_make_join_attempts_per_join_attempt() -> usize { 48 }
|
fn default_max_make_join_attempts_per_join_attempt() -> usize { 48 }
|
||||||
|
|
||||||
fn default_max_join_attempts_per_join_request() -> usize { 3 }
|
fn default_max_join_attempts_per_join_request() -> usize { 3 }
|
||||||
|
|
||||||
|
fn default_sso_grant_session_duration() -> Option<u64> { Some(180) }
|
||||||
|
|||||||
@@ -113,6 +113,14 @@ pub(super) static MAPS: &[Descriptor] = &[
|
|||||||
name: "mediaid_user",
|
name: "mediaid_user",
|
||||||
..descriptor::RANDOM_SMALL
|
..descriptor::RANDOM_SMALL
|
||||||
},
|
},
|
||||||
|
Descriptor {
|
||||||
|
name: "oauthid_session",
|
||||||
|
..descriptor::RANDOM_SMALL
|
||||||
|
},
|
||||||
|
Descriptor {
|
||||||
|
name: "oauthuniqid_oauthid",
|
||||||
|
..descriptor::RANDOM_SMALL
|
||||||
|
},
|
||||||
Descriptor {
|
Descriptor {
|
||||||
name: "onetimekeyid_onetimekeys",
|
name: "onetimekeyid_onetimekeys",
|
||||||
..descriptor::RANDOM_SMALL
|
..descriptor::RANDOM_SMALL
|
||||||
@@ -403,6 +411,10 @@ pub(super) static MAPS: &[Descriptor] = &[
|
|||||||
name: "userid_masterkeyid",
|
name: "userid_masterkeyid",
|
||||||
..descriptor::RANDOM_SMALL
|
..descriptor::RANDOM_SMALL
|
||||||
},
|
},
|
||||||
|
Descriptor {
|
||||||
|
name: "userid_oauthid",
|
||||||
|
..descriptor::RANDOM_SMALL
|
||||||
|
},
|
||||||
Descriptor {
|
Descriptor {
|
||||||
name: "userid_origin",
|
name: "userid_origin",
|
||||||
..descriptor::RANDOM
|
..descriptor::RANDOM
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ ruma.workspace = true
|
|||||||
rustls.workspace = true
|
rustls.workspace = true
|
||||||
rustyline-async.workspace = true
|
rustyline-async.workspace = true
|
||||||
rustyline-async.optional = true
|
rustyline-async.optional = true
|
||||||
|
serde_html_form.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_yaml.workspace = true
|
serde_yaml.workspace = true
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ pub struct Service {
|
|||||||
pub sender: ClientLazylock,
|
pub sender: ClientLazylock,
|
||||||
pub appservice: ClientLazylock,
|
pub appservice: ClientLazylock,
|
||||||
pub pusher: ClientLazylock,
|
pub pusher: ClientLazylock,
|
||||||
|
pub oauth: ClientLazylock,
|
||||||
|
|
||||||
pub cidr_range_denylist: Vec<IPAddress>,
|
pub cidr_range_denylist: Vec<IPAddress>,
|
||||||
}
|
}
|
||||||
@@ -112,6 +113,11 @@ impl crate::Service for Service {
|
|||||||
.pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout))
|
.pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout))
|
||||||
.redirect(redirect::Policy::limited(2))),
|
.redirect(redirect::Policy::limited(2))),
|
||||||
|
|
||||||
|
oauth: create_client!(config, services; base(config)?
|
||||||
|
.dns_resolver2(Arc::clone(&services.resolver.resolver))
|
||||||
|
.redirect(redirect::Policy::limited(0))
|
||||||
|
.pool_max_idle_per_host(1)),
|
||||||
|
|
||||||
cidr_range_denylist: config
|
cidr_range_denylist: config
|
||||||
.ip_range_denylist
|
.ip_range_denylist
|
||||||
.iter()
|
.iter()
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ pub mod globals;
|
|||||||
pub mod key_backups;
|
pub mod key_backups;
|
||||||
pub mod media;
|
pub mod media;
|
||||||
pub mod membership;
|
pub mod membership;
|
||||||
|
pub mod oauth;
|
||||||
pub mod presence;
|
pub mod presence;
|
||||||
pub mod pusher;
|
pub mod pusher;
|
||||||
pub mod resolver;
|
pub mod resolver;
|
||||||
|
|||||||
265
src/service/oauth/mod.rs
Normal file
265
src/service/oauth/mod.rs
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
pub mod providers;
|
||||||
|
pub mod sessions;
|
||||||
|
pub mod user_info;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
|
||||||
|
use reqwest::{Method, header::ACCEPT};
|
||||||
|
use ruma::UserId;
|
||||||
|
use serde::Serialize;
|
||||||
|
use serde_json::Value as JsonValue;
|
||||||
|
use tuwunel_core::{
|
||||||
|
Err, Result, err, implement,
|
||||||
|
utils::{hash::sha256, result::LogErr},
|
||||||
|
};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
pub use self::{
|
||||||
|
providers::Provider,
|
||||||
|
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session},
|
||||||
|
user_info::UserInfo,
|
||||||
|
};
|
||||||
|
use self::{providers::Providers, sessions::Sessions};
|
||||||
|
use crate::SelfServices;
|
||||||
|
|
||||||
|
pub struct Service {
|
||||||
|
services: SelfServices,
|
||||||
|
pub providers: Arc<Providers>,
|
||||||
|
pub sessions: Arc<Sessions>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl crate::Service for Service {
|
||||||
|
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||||
|
let providers = Arc::new(Providers::build(args));
|
||||||
|
let sessions = Arc::new(Sessions::build(args, providers.clone()));
|
||||||
|
Ok(Arc::new(Self {
|
||||||
|
services: args.services.clone(),
|
||||||
|
sessions,
|
||||||
|
providers,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Service)]
|
||||||
|
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||||
|
pub async fn request_userinfo(
|
||||||
|
&self,
|
||||||
|
(provider, session): (&Provider, &Session),
|
||||||
|
) -> Result<UserInfo> {
|
||||||
|
let url = provider
|
||||||
|
.userinfo_url
|
||||||
|
.clone()
|
||||||
|
.ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
|
||||||
|
|
||||||
|
self.request((Some(provider), Some(session)), Method::GET, url)
|
||||||
|
.await
|
||||||
|
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||||
|
.log_err()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Service)]
|
||||||
|
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||||
|
pub async fn request_tokeninfo(
|
||||||
|
&self,
|
||||||
|
(provider, session): (&Provider, &Session),
|
||||||
|
) -> Result<UserInfo> {
|
||||||
|
let url = provider
|
||||||
|
.introspection_url
|
||||||
|
.clone()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
err!(Config("introspection_url", "Missing introspection URL in config"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
self.request((Some(provider), Some(session)), Method::GET, url)
|
||||||
|
.await
|
||||||
|
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||||
|
.log_err()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Service)]
|
||||||
|
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||||
|
pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct RevokeQuery<'a> {
|
||||||
|
client_id: &'a str,
|
||||||
|
client_secret: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
let query = RevokeQuery {
|
||||||
|
client_id: &provider.client_id,
|
||||||
|
client_secret: &provider.client_secret,
|
||||||
|
};
|
||||||
|
|
||||||
|
let query = serde_html_form::to_string(&query)?;
|
||||||
|
let url = provider
|
||||||
|
.revocation_url
|
||||||
|
.clone()
|
||||||
|
.map(|mut url| {
|
||||||
|
url.set_query(Some(&query));
|
||||||
|
url
|
||||||
|
})
|
||||||
|
.ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
|
||||||
|
|
||||||
|
self.request((Some(provider), Some(session)), Method::POST, url)
|
||||||
|
.await
|
||||||
|
.log_err()
|
||||||
|
.map(|_| ())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Service)]
|
||||||
|
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||||
|
pub async fn request_token(
|
||||||
|
&self,
|
||||||
|
(provider, session): (&Provider, &Session),
|
||||||
|
code: &str,
|
||||||
|
) -> Result<Session> {
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct TokenQuery<'a> {
|
||||||
|
client_id: &'a str,
|
||||||
|
client_secret: &'a str,
|
||||||
|
grant_type: &'a str,
|
||||||
|
code: &'a str,
|
||||||
|
code_verifier: Option<&'a str>,
|
||||||
|
redirect_uri: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let query = TokenQuery {
|
||||||
|
client_id: &provider.client_id,
|
||||||
|
client_secret: &provider.client_secret,
|
||||||
|
grant_type: "authorization_code",
|
||||||
|
code,
|
||||||
|
code_verifier: session.code_verifier.as_deref(),
|
||||||
|
redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
|
||||||
|
};
|
||||||
|
|
||||||
|
let query = serde_html_form::to_string(&query)?;
|
||||||
|
let url = provider
|
||||||
|
.token_url
|
||||||
|
.clone()
|
||||||
|
.map(|mut url| {
|
||||||
|
url.set_query(Some(&query));
|
||||||
|
url
|
||||||
|
})
|
||||||
|
.ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
|
||||||
|
|
||||||
|
self.request((Some(provider), Some(session)), Method::POST, url)
|
||||||
|
.await
|
||||||
|
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||||
|
.log_err()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a request to a provider; this is somewhat abstract since URL's are
|
||||||
|
/// formed prior to this call and could point at anything, however this function
|
||||||
|
/// uses the oauth-specific http client and is configured for JSON with special
|
||||||
|
/// casing for an `error` property in the response.
|
||||||
|
#[implement(Service)]
|
||||||
|
#[tracing::instrument(
|
||||||
|
name = "request",
|
||||||
|
level = "debug",
|
||||||
|
ret(level = "trace"),
|
||||||
|
skip(self)
|
||||||
|
)]
|
||||||
|
pub async fn request(
|
||||||
|
&self,
|
||||||
|
(provider, session): (Option<&Provider>, Option<&Session>),
|
||||||
|
method: Method,
|
||||||
|
url: Url,
|
||||||
|
) -> Result<JsonValue> {
|
||||||
|
let mut request = self
|
||||||
|
.services
|
||||||
|
.client
|
||||||
|
.oauth
|
||||||
|
.request(method, url)
|
||||||
|
.header(ACCEPT, "application/json");
|
||||||
|
|
||||||
|
if let Some(session) = session {
|
||||||
|
if let Some(access_token) = session.access_token.clone() {
|
||||||
|
request = request.bearer_auth(access_token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response: JsonValue = request
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.error_for_status()?
|
||||||
|
.json()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some(response) = response.as_object().as_ref()
|
||||||
|
&& let Some(error) = response.get("error").and_then(JsonValue::as_str)
|
||||||
|
{
|
||||||
|
let description = response
|
||||||
|
.get("error_description")
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.unwrap_or("(no description)");
|
||||||
|
|
||||||
|
return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Service)]
|
||||||
|
#[tracing::instrument(level = "debug", skip(self))]
|
||||||
|
pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> {
|
||||||
|
let session = self.sessions.get_by_user(user_id).await?;
|
||||||
|
let provider = self.sessions.provider(&session).await?;
|
||||||
|
|
||||||
|
Ok((provider, session))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
|
||||||
|
unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
|
||||||
|
unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
|
||||||
|
unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
|
||||||
|
let hash = sha256::delimited([iss, sub].iter());
|
||||||
|
let b64 = b64encode.encode(hash);
|
||||||
|
|
||||||
|
Ok(b64)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unique_id_parts<'a>(
|
||||||
|
(provider, session): (&'a Provider, &'a Session),
|
||||||
|
) -> Result<(&'a str, &'a str)> {
|
||||||
|
provider
|
||||||
|
.issuer_url
|
||||||
|
.as_ref()
|
||||||
|
.map(Url::as_str)
|
||||||
|
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
|
||||||
|
.and_then(|iss| unique_id_iss_parts((iss, session)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unique_id_sub_parts<'a>(
|
||||||
|
(provider, sub): (&'a Provider, &'a str),
|
||||||
|
) -> Result<(&'a str, &'a str)> {
|
||||||
|
provider
|
||||||
|
.issuer_url
|
||||||
|
.as_ref()
|
||||||
|
.map(Url::as_str)
|
||||||
|
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
|
||||||
|
.map(|iss| (iss, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
|
||||||
|
session
|
||||||
|
.user_info
|
||||||
|
.as_ref()
|
||||||
|
.map(|user_info| user_info.sub.as_str())
|
||||||
|
.ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
|
||||||
|
.map(|sub| (iss, sub))
|
||||||
|
}
|
||||||
246
src/service/oauth/providers.rs
Normal file
246
src/service/oauth/providers.rs
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use serde_json::{Map as JsonObject, Value as JsonValue};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
pub use tuwunel_core::config::IdentityProvider as Provider;
|
||||||
|
use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
use crate::SelfServices;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct Providers {
|
||||||
|
services: SelfServices,
|
||||||
|
providers: RwLock<BTreeMap<String, Provider>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Providers)]
|
||||||
|
pub(super) fn build(args: &crate::Args<'_>) -> Self {
|
||||||
|
Self {
|
||||||
|
services: args.services.clone(),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the Provider configuration after any discovery and adjustments
|
||||||
|
/// made on top of the admin's configuration. This incurs network-based
|
||||||
|
/// discovery on the first call but responds from cache on subsequent calls.
|
||||||
|
#[implement(Providers)]
|
||||||
|
#[tracing::instrument(level = "debug", skip(self))]
|
||||||
|
pub async fn get(&self, id: &str) -> Result<Provider> {
|
||||||
|
if let Some(provider) = self.providers.read().await.get(id).cloned() {
|
||||||
|
return Ok(provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
let config = self.get_config(id)?;
|
||||||
|
let mut map = self.providers.write().await;
|
||||||
|
let config = self.configure(config).await?;
|
||||||
|
|
||||||
|
debug!(?id, ?config);
|
||||||
|
_ = map.insert(id.into(), config.clone());
|
||||||
|
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the admin-configured Provider which exists prior to any
|
||||||
|
/// reconciliation with the well-known discovery (the server's config is
|
||||||
|
/// immutable); though it is important to note the server config can be
|
||||||
|
/// reloaded. This will Err NotFound for a non-existent idp.
|
||||||
|
#[implement(Providers)]
|
||||||
|
pub fn get_config(&self, id: &str) -> Result<Provider> {
|
||||||
|
self.services
|
||||||
|
.config
|
||||||
|
.identity_provider
|
||||||
|
.iter()
|
||||||
|
.find(|config| config.id() == id)
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| err!(Request(NotFound("Unrecognized Identity Provider"))))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configure an identity provider; takes the admin-configured instance from the
|
||||||
|
/// server's config, queries the provider for discovery, and then returns an
|
||||||
|
/// updated config based on the proper reconciliation. This final config is then
|
||||||
|
/// cached in memory to avoid repeating this process.
|
||||||
|
#[implement(Providers)]
|
||||||
|
#[tracing::instrument(
|
||||||
|
level = INFO_SPAN_LEVEL,
|
||||||
|
ret(level = "debug"),
|
||||||
|
skip(self),
|
||||||
|
)]
|
||||||
|
async fn configure(&self, mut provider: Provider) -> Result<Provider> {
|
||||||
|
_ = provider
|
||||||
|
.name
|
||||||
|
.get_or_insert_with(|| provider.brand.clone());
|
||||||
|
|
||||||
|
if provider.issuer_url.is_none() {
|
||||||
|
_ = provider
|
||||||
|
.issuer_url
|
||||||
|
.replace(match provider.brand.as_str() {
|
||||||
|
| "github" => "https://github.com".try_into()?,
|
||||||
|
| "gitlab" => "https://gitlab.com".try_into()?,
|
||||||
|
| "google" => "https://accounts.google.com".try_into()?,
|
||||||
|
| _ => return Err!(Config("issuer_url", "Required for this provider.")),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider.base_path.is_none() {
|
||||||
|
provider.base_path = match provider.brand.as_str() {
|
||||||
|
| "github" => Some("/login/oauth".to_owned()),
|
||||||
|
| _ => None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.discover(&provider)
|
||||||
|
.await
|
||||||
|
.and_then(|response| {
|
||||||
|
response.as_object().cloned().ok_or_else(|| {
|
||||||
|
err!(Request(NotJson("Expecting JSON object for discovery response")))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.and_then(|response| check_issuer(response, &provider))?;
|
||||||
|
|
||||||
|
if provider.authorization_url.is_none() {
|
||||||
|
response
|
||||||
|
.get("authorization_endpoint")
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.map(Url::parse)
|
||||||
|
.transpose()?
|
||||||
|
.or_else(|| make_url(&provider, "/authorize").ok())
|
||||||
|
.map(|url| provider.authorization_url.replace(url));
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider.revocation_url.is_none() {
|
||||||
|
response
|
||||||
|
.get("revocation_endpoint")
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.map(Url::parse)
|
||||||
|
.transpose()?
|
||||||
|
.or_else(|| make_url(&provider, "/revocation").ok())
|
||||||
|
.map(|url| provider.revocation_url.replace(url));
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider.introspection_url.is_none() {
|
||||||
|
response
|
||||||
|
.get("introspection_endpoint")
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.map(Url::parse)
|
||||||
|
.transpose()?
|
||||||
|
.or_else(|| make_url(&provider, "/introspection").ok())
|
||||||
|
.map(|url| provider.introspection_url.replace(url));
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider.userinfo_url.is_none() {
|
||||||
|
response
|
||||||
|
.get("userinfo_endpoint")
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.map(Url::parse)
|
||||||
|
.transpose()?
|
||||||
|
.or_else(|| match provider.brand.as_str() {
|
||||||
|
| "github" => "https://api.github.com/user".try_into().ok(),
|
||||||
|
| _ => make_url(&provider, "/userinfo").ok(),
|
||||||
|
})
|
||||||
|
.map(|url| provider.userinfo_url.replace(url));
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider.token_url.is_none() {
|
||||||
|
response
|
||||||
|
.get("token_endpoint")
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.map(Url::parse)
|
||||||
|
.transpose()?
|
||||||
|
.or_else(|| {
|
||||||
|
let path = if provider.brand == "github" {
|
||||||
|
"/access_token"
|
||||||
|
} else {
|
||||||
|
"/token"
|
||||||
|
};
|
||||||
|
|
||||||
|
make_url(&provider, path).ok()
|
||||||
|
})
|
||||||
|
.map(|url| provider.token_url.replace(url));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Providers)]
|
||||||
|
#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
|
||||||
|
pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
|
||||||
|
self.services
|
||||||
|
.client
|
||||||
|
.oauth
|
||||||
|
.get(discovery_url(provider)?)
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.error_for_status()?
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(Into::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn discovery_url(provider: &Provider) -> Result<Url> {
|
||||||
|
let default_url = provider
|
||||||
|
.discovery
|
||||||
|
.then(|| make_url(provider, "/.well-known/openid-configuration"))
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let Some(url) = provider
|
||||||
|
.discovery_url
|
||||||
|
.clone()
|
||||||
|
.filter(|_| provider.discovery)
|
||||||
|
.or(default_url)
|
||||||
|
else {
|
||||||
|
return Err!(Config(
|
||||||
|
"discovery_url",
|
||||||
|
"Failed to determine URL for discovery of provider {}",
|
||||||
|
provider.id()
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate that the locally configured `issuer_url` matches the issuer claimed
|
||||||
|
/// in any response. todo: cryptographic validation is not yet implemented here.
|
||||||
|
fn check_issuer(
|
||||||
|
response: JsonObject<String, JsonValue>,
|
||||||
|
provider: &Provider,
|
||||||
|
) -> Result<JsonObject<String, JsonValue>> {
|
||||||
|
let expected = provider
|
||||||
|
.issuer_url
|
||||||
|
.as_ref()
|
||||||
|
.map(Url::as_str)
|
||||||
|
.map(|url| url.trim_end_matches('/'));
|
||||||
|
|
||||||
|
let responded = response
|
||||||
|
.get("issuer")
|
||||||
|
.and_then(JsonValue::as_str)
|
||||||
|
.map(|url| url.trim_end_matches('/'));
|
||||||
|
|
||||||
|
if expected != responded {
|
||||||
|
return Err!(Request(Unauthorized(
|
||||||
|
"Configured issuer_url {expected:?} does not match discovered {responded:?}",
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a full URL for a request to the idp based on the idp's derived
|
||||||
|
/// configuration.
|
||||||
|
fn make_url(provider: &Provider, path: &str) -> Result<Url> {
|
||||||
|
let mut suffix = provider.base_path.clone().unwrap_or_default();
|
||||||
|
|
||||||
|
suffix.push_str(path);
|
||||||
|
let url = provider
|
||||||
|
.issuer_url
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
let id = &provider.client_id;
|
||||||
|
err!(Config("issuer_url", "Provider {id:?} required field"))
|
||||||
|
})?
|
||||||
|
.join(&suffix)?;
|
||||||
|
|
||||||
|
Ok(url)
|
||||||
|
}
|
||||||
234
src/service/oauth/sessions.rs
Normal file
234
src/service/oauth/sessions.rs
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
use std::{sync::Arc, time::SystemTime};
|
||||||
|
|
||||||
|
use futures::{Stream, StreamExt, TryFutureExt};
|
||||||
|
use ruma::{OwnedUserId, UserId};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect};
|
||||||
|
use tuwunel_database::{Cbor, Deserialized, Map};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
use super::{Provider, Providers, UserInfo, unique_id};
|
||||||
|
use crate::SelfServices;
|
||||||
|
|
||||||
|
pub struct Sessions {
|
||||||
|
_services: SelfServices,
|
||||||
|
providers: Arc<Providers>,
|
||||||
|
db: Data,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Data {
|
||||||
|
oauthid_session: Arc<Map>,
|
||||||
|
oauthuniqid_oauthid: Arc<Map>,
|
||||||
|
userid_oauthid: Arc<Map>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Session ultimately represents an OAuth authorization session yielding an
|
||||||
|
/// associated matrix user registration.
|
||||||
|
///
|
||||||
|
/// Mixed-use structure capable of deserializing response values, maintaining
|
||||||
|
/// the state between authorization steps, and maintaining the association to
|
||||||
|
/// the matrix user until deactivation or revocation.
|
||||||
|
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||||
|
pub struct Session {
|
||||||
|
/// Identity Provider ID (the `client_id` in the configuration) associated
|
||||||
|
/// with this session.
|
||||||
|
pub idp_id: Option<String>,
|
||||||
|
|
||||||
|
/// Session ID used as the index key for this session itself.
|
||||||
|
pub sess_id: Option<String>,
|
||||||
|
|
||||||
|
/// Token type (bearer, mac, etc).
|
||||||
|
pub token_type: Option<String>,
|
||||||
|
|
||||||
|
/// Access token to the provider.
|
||||||
|
pub access_token: Option<String>,
|
||||||
|
|
||||||
|
/// Duration in seconds the access_token is valid for.
|
||||||
|
pub expires_in: Option<u64>,
|
||||||
|
|
||||||
|
/// Point in time that the access_token expires.
|
||||||
|
pub expires_at: Option<SystemTime>,
|
||||||
|
|
||||||
|
/// Token used to refresh the access_token.
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
|
|
||||||
|
/// Duration in seconds the refresh_token is valid for
|
||||||
|
pub refresh_token_expires_in: Option<u64>,
|
||||||
|
|
||||||
|
/// Point in time that the refresh_token expires.
|
||||||
|
pub refresh_token_expires_at: Option<SystemTime>,
|
||||||
|
|
||||||
|
/// Access scope actually granted (if supported).
|
||||||
|
pub scope: Option<String>,
|
||||||
|
|
||||||
|
/// Redirect URL
|
||||||
|
pub redirect_url: Option<Url>,
|
||||||
|
|
||||||
|
/// Challenge preimage
|
||||||
|
pub code_verifier: Option<String>,
|
||||||
|
|
||||||
|
/// Random string passed exclusively in the grant session cookie.
|
||||||
|
pub cookie_nonce: Option<String>,
|
||||||
|
|
||||||
|
/// Random single-use string passed in the provider redirect.
|
||||||
|
pub query_nonce: Option<String>,
|
||||||
|
|
||||||
|
/// Point in time the authorization grant session expires.
|
||||||
|
pub authorize_expires_at: Option<SystemTime>,
|
||||||
|
|
||||||
|
/// Associated User Id registration.
|
||||||
|
pub user_id: Option<OwnedUserId>,
|
||||||
|
|
||||||
|
/// Last userinfo response persisted here.
|
||||||
|
pub user_info: Option<UserInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of characters generated for our code_verifier. The code_verifier is a
|
||||||
|
/// random string which must be between 43 and 128 characters.
|
||||||
|
pub const CODE_VERIFIER_LENGTH: usize = 64;
|
||||||
|
|
||||||
|
/// Number of characters we will generate for the Session ID.
|
||||||
|
pub const SESSION_ID_LENGTH: usize = 32;
|
||||||
|
|
||||||
|
#[implement(Sessions)]
|
||||||
|
pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
|
||||||
|
Self {
|
||||||
|
_services: args.services.clone(),
|
||||||
|
providers,
|
||||||
|
db: Data {
|
||||||
|
oauthid_session: args.db["oauthid_session"].clone(),
|
||||||
|
oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
|
||||||
|
userid_oauthid: args.db["userid_oauthid"].clone(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Sessions)]
|
||||||
|
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
|
||||||
|
self.db
|
||||||
|
.userid_oauthid
|
||||||
|
.keys()
|
||||||
|
.expect_ok()
|
||||||
|
.map(UserId::to_owned)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete database state for the session.
|
||||||
|
#[implement(Sessions)]
|
||||||
|
#[tracing::instrument(level = "debug", skip(self))]
|
||||||
|
pub async fn delete(&self, sess_id: &str) {
|
||||||
|
let Ok(session) = self.get(sess_id).await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check the user_id still points to this sess_id before deleting. If not, the
|
||||||
|
// association was updated to a newer session.
|
||||||
|
if let Some(user_id) = session.user_id.as_deref() {
|
||||||
|
if let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await {
|
||||||
|
if assoc_id == sess_id {
|
||||||
|
self.db.userid_oauthid.remove(user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the unique identity still points to this sess_id before deleting. If
|
||||||
|
// not, the association was updated to a newer session.
|
||||||
|
if let Some(idp_id) = session.idp_id.as_ref() {
|
||||||
|
if let Ok(provider) = self.providers.get(idp_id).await {
|
||||||
|
if let Ok(unique_id) = unique_id((&provider, &session)) {
|
||||||
|
if let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await {
|
||||||
|
if assoc_id == sess_id {
|
||||||
|
self.db.oauthuniqid_oauthid.remove(&unique_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.db.oauthid_session.remove(sess_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create or overwrite database state for the session.
|
||||||
|
#[implement(Sessions)]
|
||||||
|
#[tracing::instrument(level = "info", skip(self))]
|
||||||
|
pub async fn put(&self, sess_id: &str, session: &Session) {
|
||||||
|
self.db
|
||||||
|
.oauthid_session
|
||||||
|
.raw_put(sess_id, Cbor(session));
|
||||||
|
|
||||||
|
if let Some(user_id) = session.user_id.as_deref() {
|
||||||
|
self.db.userid_oauthid.insert(user_id, sess_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(idp_id) = session.idp_id.as_ref() {
|
||||||
|
if let Ok(provider) = self.providers.get(idp_id).await {
|
||||||
|
if let Ok(unique_id) = unique_id((&provider, session)) {
|
||||||
|
self.db
|
||||||
|
.oauthuniqid_oauthid
|
||||||
|
.insert(&unique_id, sess_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch database state for a session from its associated `sub`, in case
|
||||||
|
/// `sess_id` is not known.
|
||||||
|
#[implement(Sessions)]
|
||||||
|
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||||
|
pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
|
||||||
|
self.get_sess_id_by_unique_id(unique_id)
|
||||||
|
.and_then(async |sess_id| self.get(&sess_id).await)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch database state for a session from its associated `user_id`, in case
|
||||||
|
/// `sess_id` is not known.
|
||||||
|
#[implement(Sessions)]
|
||||||
|
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||||
|
pub async fn get_by_user(&self, user_id: &UserId) -> Result<Session> {
|
||||||
|
self.get_sess_id_by_user(user_id)
|
||||||
|
.and_then(async |sess_id| self.get(&sess_id).await)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch database state for a session from its `sess_id`.
|
||||||
|
#[implement(Sessions)]
|
||||||
|
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||||
|
pub async fn get(&self, sess_id: &str) -> Result<Session> {
|
||||||
|
self.db
|
||||||
|
.oauthid_session
|
||||||
|
.get(sess_id)
|
||||||
|
.await
|
||||||
|
.deserialized::<Cbor<_>>()
|
||||||
|
.map(at!(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve the `sess_id` from an associated `user_id`.
|
||||||
|
#[implement(Sessions)]
|
||||||
|
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||||
|
pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result<String> {
|
||||||
|
self.db
|
||||||
|
.userid_oauthid
|
||||||
|
.get(user_id)
|
||||||
|
.await
|
||||||
|
.deserialized()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve the `sess_id` from an associated provider subject id.
|
||||||
|
#[implement(Sessions)]
|
||||||
|
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||||
|
pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
|
||||||
|
self.db
|
||||||
|
.oauthuniqid_oauthid
|
||||||
|
.get(unique_id)
|
||||||
|
.await
|
||||||
|
.deserialized()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[implement(Sessions)]
|
||||||
|
pub async fn provider(&self, session: &Session) -> Result<Provider> {
|
||||||
|
let Some(idp_id) = session.idp_id.as_deref() else {
|
||||||
|
return Err!(Request(NotFound("No provider for this session")));
|
||||||
|
};
|
||||||
|
|
||||||
|
self.providers.get(idp_id).await
|
||||||
|
}
|
||||||
39
src/service/oauth/user_info.rs
Normal file
39
src/service/oauth/user_info.rs
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Selection of userinfo response claims.
|
||||||
|
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||||
|
pub struct UserInfo {
|
||||||
|
/// Unique identifier number or login username. Usually a number on most
|
||||||
|
/// services. We consider a concatenation of the `iss` and `sub` to be a
|
||||||
|
/// universally unique identifier for some user/identity; we index that in
|
||||||
|
/// `oauthidpsub_oauthid`.
|
||||||
|
///
|
||||||
|
/// Considered for user mxid only if none of the better fields are defined.
|
||||||
|
/// `login` alias intended for github.
|
||||||
|
#[serde(alias = "login")]
|
||||||
|
pub sub: String,
|
||||||
|
|
||||||
|
/// The login username we first consider when defined.
|
||||||
|
pub preferred_username: Option<String>,
|
||||||
|
|
||||||
|
/// The login username considered if none preferred.
|
||||||
|
pub nickname: Option<String>,
|
||||||
|
|
||||||
|
/// Full name.
|
||||||
|
pub name: Option<String>,
|
||||||
|
|
||||||
|
/// First name.
|
||||||
|
pub given_name: Option<String>,
|
||||||
|
|
||||||
|
/// Last name.
|
||||||
|
pub family_name: Option<String>,
|
||||||
|
|
||||||
|
/// Email address (`email` scope).
|
||||||
|
pub email: Option<String>,
|
||||||
|
|
||||||
|
/// URL to pfp (github/gitlab)
|
||||||
|
pub avatar_url: Option<String>,
|
||||||
|
|
||||||
|
/// URL to pfp (google)
|
||||||
|
pub picture: Option<String>,
|
||||||
|
}
|
||||||
@@ -12,7 +12,7 @@ use crate::{
|
|||||||
account_data, admin, appservice, client, config, deactivate, emergency, federation, globals,
|
account_data, admin, appservice, client, config, deactivate, emergency, federation, globals,
|
||||||
key_backups,
|
key_backups,
|
||||||
manager::Manager,
|
manager::Manager,
|
||||||
media, membership, presence, pusher, resolver, rooms, sending, server_keys,
|
media, membership, oauth, presence, pusher, resolver, rooms, sending, server_keys,
|
||||||
service::{Args, Service},
|
service::{Args, Service},
|
||||||
sync, transaction_ids, uiaa, users,
|
sync, transaction_ids, uiaa, users,
|
||||||
};
|
};
|
||||||
@@ -58,6 +58,7 @@ pub struct Services {
|
|||||||
pub users: Arc<users::Service>,
|
pub users: Arc<users::Service>,
|
||||||
pub membership: Arc<membership::Service>,
|
pub membership: Arc<membership::Service>,
|
||||||
pub deactivate: Arc<deactivate::Service>,
|
pub deactivate: Arc<deactivate::Service>,
|
||||||
|
pub oauth: Arc<oauth::Service>,
|
||||||
|
|
||||||
manager: Mutex<Option<Arc<Manager>>>,
|
manager: Mutex<Option<Arc<Manager>>>,
|
||||||
pub server: Arc<Server>,
|
pub server: Arc<Server>,
|
||||||
@@ -115,6 +116,7 @@ pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
|
|||||||
users: users::Service::build(&args)?,
|
users: users::Service::build(&args)?,
|
||||||
membership: membership::Service::build(&args)?,
|
membership: membership::Service::build(&args)?,
|
||||||
deactivate: deactivate::Service::build(&args)?,
|
deactivate: deactivate::Service::build(&args)?,
|
||||||
|
oauth: oauth::Service::build(&args)?,
|
||||||
|
|
||||||
manager: Mutex::new(None),
|
manager: Mutex::new(None),
|
||||||
server,
|
server,
|
||||||
@@ -173,6 +175,7 @@ pub(crate) fn services(&self) -> impl Iterator<Item = Arc<dyn Service>> + Send {
|
|||||||
cast!(self.users),
|
cast!(self.users),
|
||||||
cast!(self.membership),
|
cast!(self.membership),
|
||||||
cast!(self.deactivate),
|
cast!(self.deactivate),
|
||||||
|
cast!(self.oauth),
|
||||||
]
|
]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ use ruma::{
|
|||||||
events::{GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent},
|
events::{GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent},
|
||||||
};
|
};
|
||||||
use tuwunel_core::{
|
use tuwunel_core::{
|
||||||
Err, Result, debug_warn, err, is_equal_to,
|
Err, Result, debug_warn, err,
|
||||||
pdu::PduBuilder,
|
pdu::PduBuilder,
|
||||||
trace,
|
trace,
|
||||||
utils::{self, ReadyExt, stream::TryIgnore},
|
utils::{self, ReadyExt, result::LogErr, stream::TryIgnore},
|
||||||
warn,
|
warn,
|
||||||
};
|
};
|
||||||
use tuwunel_database::{Deserialized, Json, Map};
|
use tuwunel_database::{Deserialized, Json, Map};
|
||||||
@@ -132,6 +132,16 @@ impl Service {
|
|||||||
|
|
||||||
/// Deactivate account
|
/// Deactivate account
|
||||||
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
|
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
|
||||||
|
// Revoke any SSO authorizations
|
||||||
|
if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await {
|
||||||
|
self.services
|
||||||
|
.oauth
|
||||||
|
.revoke_token((&provider, &session))
|
||||||
|
.await
|
||||||
|
.log_err()
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
// Remove all associated devices
|
// Remove all associated devices
|
||||||
self.all_device_ids(user_id)
|
self.all_device_ids(user_id)
|
||||||
.for_each(|device_id| self.remove_device(user_id, device_id))
|
.for_each(|device_id| self.remove_device(user_id, device_id))
|
||||||
@@ -227,18 +237,15 @@ impl Service {
|
|||||||
// Cannot change the password of a LDAP user. There are two special cases :
|
// Cannot change the password of a LDAP user. There are two special cases :
|
||||||
// - a `None` password can be used to deactivate a LDAP user
|
// - a `None` password can be used to deactivate a LDAP user
|
||||||
// - a "*" password is used as the default password of an active LDAP user
|
// - a "*" password is used as the default password of an active LDAP user
|
||||||
if cfg!(feature = "ldap")
|
//
|
||||||
&& password.is_some()
|
// The above now applies to all non-password origin users including SSO
|
||||||
|
// users. Note that users with no origin are also password-origin users.
|
||||||
|
if password.is_some()
|
||||||
&& password != Some("*")
|
&& password != Some("*")
|
||||||
&& self
|
&& let Ok(origin) = self.origin(user_id).await
|
||||||
.db
|
&& origin != "password"
|
||||||
.userid_origin
|
|
||||||
.get(user_id)
|
|
||||||
.await
|
|
||||||
.deserialized::<String>()
|
|
||||||
.is_ok_and(is_equal_to!("ldap"))
|
|
||||||
{
|
{
|
||||||
return Err!(Request(InvalidParam("Cannot change password of a LDAP user")));
|
return Err!(Request(InvalidParam("Cannot change password of an {origin:?} user.")));
|
||||||
}
|
}
|
||||||
|
|
||||||
password
|
password
|
||||||
|
|||||||
@@ -1834,6 +1834,10 @@
|
|||||||
#
|
#
|
||||||
#one_time_key_limit = 256
|
#one_time_key_limit = 256
|
||||||
|
|
||||||
|
# This item is undocumented. Please contribute documentation for it.
|
||||||
|
#
|
||||||
|
#sso_aware_preferred = false
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[global.tls]
|
#[global.tls]
|
||||||
@@ -2097,6 +2101,131 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#[[global.identity_provider]]
|
||||||
|
|
||||||
|
# The brand-name of the service (e.g. Apple, Facebook, GitHub, GitLab,
|
||||||
|
# Google) or the software (e.g. keycloak, MAS) providing the identity.
|
||||||
|
# When a brand is recognized we apply certain defaults to this config
|
||||||
|
# for your convenience. For certain brands we apply essential internal
|
||||||
|
# workarounds specific to that provider; it is important to configure this
|
||||||
|
# field properly when a provider needs to be recognized (like GitHub for
|
||||||
|
# example). Several configured providers can share the same brand name. It
|
||||||
|
# is not case-sensitive.
|
||||||
|
#
|
||||||
|
#brand =
|
||||||
|
|
||||||
|
# The ID of your OAuth application which the provider generates upon
|
||||||
|
# registration. This ID then uniquely identifies this configuration
|
||||||
|
# instance itself, becoming the identity provider's ID and must be unique
|
||||||
|
# and remain unchanged.
|
||||||
|
#
|
||||||
|
#client_id =
|
||||||
|
|
||||||
|
# Secret key the provider generated for you along with the `client_id`
|
||||||
|
# above. Unlike the `client_id`, the `client_secret` can be changed here
|
||||||
|
# whenever the provider regenerates one for you.
|
||||||
|
#
|
||||||
|
#client_secret =
|
||||||
|
|
||||||
|
# The callback URL configured when registering the OAuth application with
|
||||||
|
# the provider. Tuwunel's callback URL must be strictly formatted exactly
|
||||||
|
# as instructed. The URL host must point directly at the matrix server and
|
||||||
|
# use the following path:
|
||||||
|
# `/_matrix/client/unstable/login/sso/callback/<client_id>` where
|
||||||
|
# `<client_id>` is the same one configured for this provider above.
|
||||||
|
#
|
||||||
|
#callback_url =
|
||||||
|
|
||||||
|
# Optional display-name for this provider instance seen on the login page
|
||||||
|
# by users. It defaults to `brand`. When configuring multiple providers
|
||||||
|
# using the same `brand` this can be set to distinguish them.
|
||||||
|
#
|
||||||
|
#name =
|
||||||
|
|
||||||
|
# Optional icon for the provider. The canonical providers have a default
|
||||||
|
# icon based on the `brand` supplied above when this is not supplied. Note
|
||||||
|
# that it uses an MXC url which is curious in the auth-media era and may
|
||||||
|
# not be reliable.
|
||||||
|
#
|
||||||
|
#icon =
|
||||||
|
|
||||||
|
# Optional list of scopes to authorize. An empty array does not impose any
|
||||||
|
# restrictions from here, effectively defaulting to all scopes you
|
||||||
|
# configured for the OAuth application at the provider. This setting
|
||||||
|
# allows for restricting to a subset of those scopes for this instance.
|
||||||
|
# Note the user can further restrict scopes during their authorization.
|
||||||
|
#
|
||||||
|
#scope = []
|
||||||
|
|
||||||
|
# List of userinfo claims which shape and restrict the way we compute a
|
||||||
|
# Matrix UserId for new registrations. Reviewing Tuwunel's documentation
|
||||||
|
# will be necessary for a complete description in detail. An empty array
|
||||||
|
# imposes no restriction here, avoiding generated fallbacks as much as
|
||||||
|
# possible. For simplicity we reserve a claim called "unique" which can be
|
||||||
|
# listed alone to ensure *only* generated ID's are used for registrations.
|
||||||
|
#
|
||||||
|
#userid_claims = []
|
||||||
|
|
||||||
|
# Issuer URL the provider publishes for you. We have pre-supplied default
|
||||||
|
# values for some of the canonical providers, making this field optional
|
||||||
|
# based on the `brand` set above. Otherwise it is required for OIDC
|
||||||
|
# discovery to acquire additional provider configuration, and it must be
|
||||||
|
# correct to pass validations during various interactions.
|
||||||
|
#
|
||||||
|
#issuer_url =
|
||||||
|
|
||||||
|
# Extra path components after the issuer_url leading to the location of
|
||||||
|
# the `.well-known` directory used for discovery. This will be empty for
|
||||||
|
# specification-compliant providers. We have supplied any known values
|
||||||
|
# based on `brand` (e.g. `/login/oauth` for GitHub).
|
||||||
|
#
|
||||||
|
#base_path =
|
||||||
|
|
||||||
|
# Overrides the `.well-known` location where the provider's OIDC
|
||||||
|
# configuration is found. It is very unlikely you will need to set this;
|
||||||
|
# available for developers or special purposes only.
|
||||||
|
#
|
||||||
|
#discovery_url =
|
||||||
|
|
||||||
|
# Overrides the authorize URL requested during the grant phase. This is
|
||||||
|
# generally discovered or derived automatically, but may be required as a
|
||||||
|
# workaround for any non-standard or undiscoverable provider.
|
||||||
|
#
|
||||||
|
#authorization_url =
|
||||||
|
|
||||||
|
# Overrides the access token URL; the same caveats apply as with the other
|
||||||
|
# URL overrides.
|
||||||
|
#
|
||||||
|
#token_url =
|
||||||
|
|
||||||
|
# Overrides the revocation URL; the same caveats apply as with the other
|
||||||
|
# URL overrides.
|
||||||
|
#
|
||||||
|
#revocation_url =
|
||||||
|
|
||||||
|
# Overrides the introspection URL; the same caveats apply as with the
|
||||||
|
# other URL overrides.
|
||||||
|
#
|
||||||
|
#introspection_url =
|
||||||
|
|
||||||
|
# Overrides the userinfo URL; the same caveats apply as with the other URL
|
||||||
|
# overrides.
|
||||||
|
#
|
||||||
|
#userinfo_url =
|
||||||
|
|
||||||
|
# Whether to perform discovery and adjust this provider's configuration
|
||||||
|
# accordingly. This defaults to true. When true, it is an error when
|
||||||
|
# discovery fails and authorizations will not be attempted to the
|
||||||
|
# provider.
|
||||||
|
#
|
||||||
|
#discovery = true
|
||||||
|
|
||||||
|
# The duration in seconds before a grant authorization session expires.
|
||||||
|
#
|
||||||
|
#grant_session_duration =
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[global.appservice.<ID>]
|
#[global.appservice.<ID>]
|
||||||
|
|
||||||
# The URL for the application service.
|
# The URL for the application service.
|
||||||
|
|||||||
Reference in New Issue
Block a user